プロジェクト概要
DeepGEMM は、DeepSeek チームによって開発されたオープンソースの CUDA ライブラリで、FP8 型の一般行列乗算(GEMM)を効率的に実装することに専念しています。このライブラリは、深層学習における行列演算をサポートし、高性能を提供することを目指しています。DeepGEMM は、普通の行列乗算と混合専門家モデル(MoE)におけるグループ行列乗算を含む、複数の行列乗算シーンをサポートしています。このライブラリは、NVIDIA Hopper アーキテクチャの Tensor Memory Accelerator(TMA)と warp 級の最適化技術を利用して、顕著な性能向上を実現しています。
コア機能
FP8 行列乗算
DeepGEMM は、FP8 型に基づく行列乗算機能を提供し、入力行列を FP8 形式に変換して計算を行い、出力結果は BF16 形式で表されます。この設計は、計算精度を保証しながら、FP8 の低精度の利点を十分に活用し、計算効率を大幅に向上させます。
グループ行列乗算
DeepGEMM は、連続グループモードとマスク付きグループモードを含む、複数のグループ行列乗算モードをサポートしています。これらのモードは、大量の小行列をグループ化して計算する必要がある混合専門家モデル(MoE)などのシーンに特に適しています。複数の行列をグループ化してパッケージ化することで、DeepGEMM は、呼び出しの頻繁な開発コストを削減し、計算効率を向上させます。
TMA 読み込みと書き込み
DeepGEMM は、Hopper アーキテクチャにおける TMA の特性を十分に活用し、グローバルメモリから共有メモリへのデータ搬送を加速します。TMA のマルチキャスト機能により、複数のスレッドブロックが同じ LHS データを共有することができ、データ伝送の帯域幅消費を減らします。
精密なスケーリング係数
FP8 型は、数値範囲を保証するためのスケーリング係数を必要とします。DeepGEMM は、LHS と RHS にそれぞれ異なる次元のスケーリング係数を使用し、列優先/行優先の TMA アライメント要求を組み込んでおり、計算の正確性和効率性を保証します。
革新と特徴
完全 JIT コンパイルメカニズム
DeepGEMM は、完全な Just-In-Time(JIT)コンパイルメカニズムを採用しています。カーネルの CUDA コードは、実行時に具体的な行列の寸法に応じて動的に生成、コンパイルされ、ロードされます。このメカニズムは、コンパイラが異なる形状の GEMM に対してより多くの最適化を行うことを許可し、同時にテンプレートやマクロの複雑さを減らします。
TMA マルチキャスト機能
DeepGEMM は、TMA 読み込みでマルチキャスト機能をサポートしており、同じ SM 内の複数のスレッドブロックが LHS データを共有することができるため、データの重複読み込みを避けることができます。このメカニズムは、複数のスレッドブロックが同じ LHS データを必要とする場合に帯域幅消費を大幅に最適化します。
FFMA SASS 後処理
DeepGEMM は、コンパイルされた SASS の FFMA(Fused FMA)命令に対してバイナリレベルの置換を行い、一部のフラグビットを再設定することで、warp が他の命令と効果的にパイプライン化できるようにしています。この最適化は、特定のケースで顕著な性能向上をもたらします。
複数の GEMM モード
DeepGEMM は、普通の GEMM、連続グループ GEMM、マスク付きグループ GEMM の複数の GEMM モードを提供しています。これらのモードは、異なるシーンで異なるインターフェースを使用していますが、内部のコアロジックは同じ底层カーネル上でスケジュールされています。
自動ブロックサイズとパイプラインステージ
DeepGEMM は、行列の規模に応じて自動的に適切なブロックサイズとパイプラインステージ数を選択することができるので、レジスタの使用と並列スケールを最適にバランスさせることができます。この自己適応メカニズムは、異なる行列形状で最適な性能を実現することを保証します。
パフォーマンス
DeepGEMM は、異なる行列形状におけるパフォーマンスが優れており、内部の最適化された CUTLASS 3.6 の実装と比較して、速度向上が顕著です。例えば、H800 GPU 上では、特定の行列形状に対して、DeepGEMM の計算速度は 2.7 倍の向上を達成しています。
開発と使用
開発環境の要件
-
Hopper アーキテクチャ GPU、sm_90a をサポート
-
Python 3.8 またはそれ以降のバージョン
-
CUDA 12.3 またはそれ以降のバージョン(12.8 またはそれ以降のバージョンを使用することをお勧めします)
-
PyTorch 2.1 またはそれ以降のバージョン
-
CUTLASS 3.6 またはそれ以降のバージョン
インストールとテスト
プロジェクトをクローンし、サブモジュールを初期化します。
bash
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git シンボリックリンクを作成し、インストールします。
bash
python setup.py develop JIT コンパイルをテストします。
bash
python tests/test_jit.py すべての GEMM の実装をテストします。
bash
python tests/test_core.py
使用例
Python プロジェクトで deep_gemm をインポートし、提供される GEMM 関数を使用します。例えば、普通の GEMM 操作を実行するには:
Python
import deep_gemm result = deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, lhs_scale, rhs_scale)
最適化技術
持続的な warp 特化
DeepGEMM のカーネル設計では、データ伝送、テンソルコア MMA 命令、CUDA コアの拡張が並列に実行されるように warp 特化が採用されており、計算効率が向上します。
TMA 特性の活用
DeepGEMM は、Hopper アーキテクチャにおける TMA の特性を十分に活用し、TMA 読み込み、書き込み、マルチキャスト、ディスクリプタプリフェッチを含め、データ搬送の効率を大幅に向上させます。
統一されたブロックスケジューラ
DeepGEMM は、普通とグループ GEMM カーネルを処理するための統一されたブロックスケジューラを使用し、グリッド化されたアクセスパターンにより L2 キャッシュの利用率が高くなります。
完全な JIT デザイン
DeepGEMM の完全な JIT デザインにより、実行時に行列の形状に応じて動的にカーネルを生成してコンパイルすることができ、テンプレートやマクロの複雑さを減らし、同時にコンパイラがより多くの最適化を行うことができます。
非アライメントブロックサイズ
DeepGEMM は、SM リソースを十分に活用し、計算効率を向上させるために非アライメントブロックサイズをサポートしています。
FFMA 命令のインターリーブ
DeepGEMM は、コンパイルされた FFMA 命令に対して後処理を行い、フラグビットを再設定することで、warp が他の命令と効果的にパイプライン化できるようにし、性能を向上させます。
まとめ
DeepGEMM は、深層学習における行列演算に特に適した効率的で柔軟で使いやすい FP8 行列乗算ライブラリです。このライブラリは、JIT コンパイル、TMA 読み込みと書き込み、FFMA 命令のインターリーブなどの技術を活用して、顕著な性能向上を実現しています。DeepGEMM のコードスタイルは明瞭で、主なロジックは fp8_gemm.cuh と jit_kernels サブモジュールに集中しているため、Hopper アーキテクチャ下での高性能 GEMM の実装に興味がある開発者は、深く研究する価値があります。
aiスピーキング
ドルフィンAIは言語学習アプリケーションのためのプロフェッショナルな発音評価API(pronunciation assessment api)ソリューションを提供します。音素、単語、文章、チャプター、発音矯正、単語矯正、クイズ、フリーダイアログ、多肢選択問題など幅広く提供しています。当社の発音評価製品(pronunciation assessment)は、英語と中国語、クラウドAPI、オンプレミス、オフラインデバイスの展開をサポートしています。当社の発音評価API(pronunciation assessment api)は、正確性、流暢性、完全性、リズムの次元をカバーする豊富な評価指標を提供し、音素、単語、文の異なるレベルの評価スコアも提供します。また、音素、単語、文の異なるレベルでの評価スコアも提供します。数千万人のユーザーに安定した効率的で安全なサービスを提供しています。ドルフィンAIの発音評価製品(pronunciation assessment)を試してみませんか?