TransformerのSelf-Attention機構

Tech

TransformerのSelf-Attention機構

要点(3行)

  • Self-Attention機構は、Transformerが系列データの長距離依存関係を効率的に捉える中核技術であり、各トークン間の関連性を動的に学習します。

  • その計算量は系列長Lに対してO(L^2)ですが、FlashAttention-2によるIO最適化や、GQA、Sparse-FlashAttentionといった手法により実効速度とメモリ効率が大幅に改善されています。

  • 大規模言語モデルの運用では、KVキャッシュがメモリボトルネックとなるため、これらの効率化手法やMambaのようなAttention代替アーキテクチャの理解が不可欠です。

背景(課題/先行研究/最新動向)

Transformerモデルの中核をなすSelf-Attention機構は、自然言語処理分野に革命をもたらしました。従来のリカレントニューラルネットワーク(RNN)や畳み込みニューラルネットワーク(CNN)が抱えていた、長距離依存関係の学習困難さや並列計算の非効率性といった課題を克服するために、2017年6月12日に発表された論文「Attention Is All You Need」[1]によって導入されました。この機構により、系列内の任意の2つのトークン間の関連性を直接計算できるようになり、モデルの表現能力と訓練効率が飛躍的に向上しました。

しかし、Self-Attentionの計算量は系列長Lに対してO(L^2)であるため、特に非常に長いコンテキストを扱う際には、計算時間とGPUメモリ消費が大きなボトルネックとなります。この課題を解決するため、様々な効率化手法が提案され、現在も活発に研究開発が進められています。

最新動向(直近90日、2024年4月22日以降):

  • Sparse-FlashAttentionによる長文コンテキスト処理の効率化:2024年6月14日に発表された研究では、疎なAttentionメカニズムとFlashAttentionを組み合わせることで、長文コンテキストにおける計算量を線形に削減し、準線形メモリ利用を実現するSparse-FlashAttentionが提案されました。これは、特にKVキャッシュの管理を効率化します[6]。

  • データ並列AttentionによるLLM効率化:2024年5月10日には、分散学習環境で大規模言語モデル(LLM)が非常に長いコンテキストを効率的に処理するための、キー・バリュー(KV)キャッシュを共有するデータ並列Attention手法が発表されました。これはGPU間の通信を最適化し、推論効率を高めます[9]。

提案手法 / モデル構造

Self-Attention機構の核は、入力系列内の各要素(トークン)が、その系列内の他の全ての要素とどれだけ関連しているかを学習し、関連性の高い情報に「注意を払う」ことで、より豊かな表現を生成する点にあります。このプロセスは、Query(Q)、Key(K)、Value(V)という3つの概念に基づいて行われます。

  1. 線形変換: 入力されたトークン埋め込み(または前の層の出力)は、3つの異なる線形変換層を通ってQ、K、V行列に変換されます。

  2. スコア計算: Query行列とKey行列の転置との内積を計算することで、各クエリが各キーに対してどれだけ類似しているか、つまり「注意度」のスコアが算出されます。

  3. スケーリングとSoftmax: スコアはキーの次元数d_kの平方根でスケーリングされ、勾配の消失を防ぎます。その後、Softmax関数を適用することで、スコアが0から1の範囲の確率分布(Attention重み)に変換されます。

  4. 出力の結合: このAttention重みとValue行列の内積を計算することで、各トークンが他のトークンから受け取る情報の加重平均が算出され、Attention機構の出力となります。

Multi-Head Attention (MHA) は、このSelf-Attention機構を複数並列に実行するものです。それぞれの「ヘッド」が異なるQ, K, Vの線形変換を持ち、異なる「注意の仕方」を学習することで、モデルが多様な側面から情報を捉え、表現能力を向上させます。各ヘッドの出力は結合され、最終的な線形変換を経て出力されます。

Self-Attentionの計算フロー

graph TD
    A["入力トークン埋め込み"] --> B("線形変換層")
    B --> Q_L["Query行列 (Q)"]
    B --> K_L["Key行列 (K)"]
    B --> V_L["Value行列 (V)"]
    Q_L --- D1(Q)
    K_L --- D2(K)
    V_L --- D3(V)
    D1 -- 行列積 --> E["Q * K^T"] |類似度計算|
    E -- スケーリング --> F["Scaled Scores"] |安定化|
    F -- Softmax --> G["Attention Weights"] |確率分布化|
    G -- 行列積 --> H["Attention Output"] |加重平均|
    H --> I["出力"]

Self-Attentionの擬似コード

# Self-Attention Calculation (Pseudo Code)


# 入力: Q (Query行列), K (Key行列), V (Value行列), d_k (Keyの次元数)


# 出力: Attentionの出力行列


# 前提: Q, K, Vは同じ系列長 L とヘッド次元 d_head を持つ (L, d_head) の形式。


#       Multi-Head Attentionでは、各ヘッドに分割されたQ, K, Vが入力される。


# 計算量: L = シーケンス長, d_head = 各ヘッドの次元


#   - QK^T: O(L^2 * d_head)


#   - Softmax: O(L^2)


#   - Attention @ V: O(L^2 * d_head)


# 全体: O(L^2 * d_head)


# メモリ: Attention Weights行列のために O(L^2) のメモリが必要。

def self_attention(Q, K, V, d_k, mask=None):

    # 1. QとKの転置の内積を計算 (類似度スコア)

    scores = matmul(Q, transpose(K)) # サイズ: (L, L)

    # 2. スケーリング


    # d_kの平方根で割ることで、内積の値を安定化させ、Softmaxの勾配消失を防ぐ

    scores = scores / sqrt(d_k)

    # 3. マスク適用 (オプション: 例として未来のトークンを見せないための因果マスク)


    # デコーダーでは、現在のトークンが未来のトークンを参照しないようにマスクを適用

    if mask is not None:

        # マスク位置に非常に小さな負の値を加える (Softmaxでほぼ0になるように)

        scores = scores + mask * (-1e9)

    # 4. Softmax関数を適用


    # スコアを確率分布 (Attention重み) に変換。各行の和が1になる

    attention_weights = softmax(scores, axis=-1) # サイズ: (L, L)

    # 5. Attention重みとV行列の内積を計算


    # 各出力トークンは、Vの重み付き和として表現される

    output = matmul(attention_weights, V) # サイズ: (L, d_head)

    return output

計算量/メモリ/スケーリング

標準的なSelf-Attention機構の計算量とメモリ消費は、そのスケーラビリティにおける主要な課題です。

  • 計算量: 系列長Lとモデルの隠れ次元d_modelに対して、QueryとKeyの内積計算(QK^T)がO(L^2 * d_model)、Attention重みとValueの内積計算(Attn * V)もO(L^2 * d_model)となるため、全体としてO(L^2 * d_model)の計算量を持ちます。

  • メモリ消費: 特にAttention重み行列(L x L)を保持するためにO(L^2)のメモリが必要となります。さらに、大規模言語モデルの推論時には、以前のトークンのKeyとValueをキャッシュするKVキャッシュが重要なメモリボトルネックとなります。これはO(L * d_model * n_layers * n_heads)に比例し、Lが長くなるほど急増します。

これらの課題に対処するため、様々な効率化技術が開発されています。

  • FlashAttention-2 [4]:2023年7月12日に発表されたFlashAttention-2は、O(L^2 * d_model)という理論的な計算量は変えずに、GPUメモリの階層構造(SRAMとHBM)を最大限に活用するIO最適化されたCUDAカーネルによって、実効速度を大幅に向上させ、中間結果のHBMへの書き込みを削減することでメモリ効率を高めます。

  • Grouped-Query Attention (GQA) [5]:2023年5月26日に発表されたGQAは、Multi-Head Attention (MHA) と Multi-Query Attention (MQA) の中間的なアプローチです。MHAでは全てのQueryヘッドが独立したKey/Valueヘッドを持つ一方、MQAでは全てのQueryヘッドが単一のKey/Valueヘッドを共有します。GQAは複数のQueryヘッドが少数のKey/Valueヘッドを共有することで、推論時のKVキャッシュのメモリフットプリントを大幅に削減し、高速化と品質のバランスを取ります。

  • Performer (線形Attention) [2]:2020年9月23日に提案されたPerformerは、カーネル近似を用いることでAttentionの計算量をO(L * d_model^2)d_modelを一定と見なせばO(L))に線形化します。これにより、理論上は長系列に対応できますが、近似による品質低下のトレードオフが存在します。

  • Ring Attention [7]:2023年10月24日に発表されたRing Attentionは、分散学習環境において非常に長いシーケンスを処理するための手法です。各デバイスがシーケンスの一部を計算し、Attention計算に必要なKey/Valueをリング状に効率的に交換することで、単一GPUのメモリ制約を克服し、最大100万トークンを超えるコンテキストウィンドウを可能にします。

  • Sparse-FlashAttention [6]:2024年6月14日に発表されたSparse-FlashAttentionは、特定のパターンで疎なAttentionのみを計算することで、Attentionの計算量を線形に削減しつつ、FlashAttentionの高速性を活用します。

実験設定/再現性

Transformerモデルの学習において、Self-Attention機構の性能を最大限に引き出すためには、適切な実験設定と再現性の確保が重要です。

  • オプティマイザと学習率スケジューリング: 一般的にAdamWオプティマイザが広く用いられ、学習率にはWarmup期間を持つコサイン減衰などのスケジューリングが適用されます。これにより、学習初期の不安定性を回避し、モデルが最適な学習パスを見つけやすくなります。

  • 乱数シードの固定: 実験の再現性を保証するために、全ての乱数生成箇所(モデルの初期化、データシャッフル、ドロップアウトなど)でシードを固定することが不可欠です。

  • 環境と依存関係: FlashAttention-2やSparse-FlashAttentionといった高速化ライブラリは、特定のCUDAバージョンやPyTorchバージョンに依存するカスタムカーネルを使用することが多いため、実験環境の正確な記述と再現可能なパッケージ管理(例: condapipのrequirements.txt)が求められます。

  • ベンチマーク: 性能評価には、シーケンス長、バッチサイズ、ヘッド数などのパラメータを変化させ、Flops(浮動小数点演算数)、スループット(トークン/秒)、メモリ消費量などを測定します。

結果(表)

Self-Attentionの主要な効率化手法を比較した結果を以下の表に示します。

特徴 標準Self-Attention FlashAttention-2 [4] GQA (8クエリ/1KVヘッド) [5] Sparse-FlashAttention [6] Performer [2] Mamba [8]
カテゴリ ベースライン IO最適化 KVキャッシュ最適化 疎なAttention 線形近似 状態空間モデル
計算量 (理論) O(L²d) O(L²d) (IO最適化) O(L²d_Q + Ld_K*n_KV_heads) O(Ld) (近似) O(Ld²) (近似) O(Ld)
メモリ消費 (訓練) O(L² + Ld) O(Ld) (FlashAttention) O(L² + Ld) O(Ld) O(Ld) O(Ld)
メモリ消費 (推論/KV) O(Ldn_layers) O(Ldn_layers) O(Ld/n_groups * n_layers) (削減) O(Ldn_layers) O(Ldn_layers) O(Ldn_layers) (効率的)
実効レイテンシ Baseline 約2-4倍高速 (訓練) MQAより高性能、MHAより高速 高速 高速 Attention比で高速
モデル品質 Baseline 同等 MHAに近い (MQAより高) 同等 (要検証) 通常は若干低下 (近似のため) Attentionに匹敵または上回る
備考 全てのトークンを考慮 CUDAカーネルレベルの最適化 KVキャッシュ効率化 長文向け、疎なAttention カーネル近似、線形スケール 並列・線形スケーリング

*注釈: d = d_modelまたはd_headL = シーケンス長。n_layers = 層数。n_groups = GQAにおけるKVヘッドグループ数。d_Qはクエリ次元、d_Kはキー次元。

考察(仮説と根拠を分離)

Self-Attention機構は、その非局所的な情報集約能力により、RNNやCNNでは困難であった長距離依存関係のモデリングを可能にしました。これはTransformerが多くのタスクで高性能を示す主要な理由です。

強み:

  • 長距離依存関係の直接モデリング: 入力系列内の任意の2つのトークン間で直接的な関連性を計算できるため、遠く離れた単語間の意味的・文法的関係を効果的に捉えられます[1]。

  • 高い並列性: 各トークンのAttention計算が独立して行えるため、GPUなどの並列計算デバイスでの効率的な訓練が可能です。

課題と対策:

  • O(L^2)の計算量とメモリ消費: 系列長Lが長くなると、計算時間とメモリ使用量が急増します。このため、大規模なデータセットや非常に長いコンテキストの処理には制約が生じます。

    • IO最適化: FlashAttention-2 [4] は、Attention計算のIOボトルネックを解消し、実効速度とメモリ効率を向上させました。これは理論計算量を維持しつつ、GPUハードウェアに最適化することで性能を向上させるアプローチです。

    • KVキャッシュ効率化: 推論時のメモリボトルネックであるKVキャッシュを削減するため、GQA [5] はMHAとMQAの中間的なアプローチを取り、複数のQueryヘッドで少数のKey/Valueヘッドを共有することで、メモリ消費と速度のバランスを改善しています。

    • 近似と疎化: Performer [2] のような線形Attentionは、Attention行列を近似することで計算量を線形に削減します。また、Sparse-FlashAttention [6] は、Attention行列の大部分が不要であることを利用して、計算対象を限定し線形計算量を実現します。

    • 分散処理: Ring Attention [7] は、超長文コンテキストを分散環境で効率的に処理するためのアプローチであり、単一デバイスのメモリ限界を突破します。

  • Attentionの代替: Mamba [8] は、Attention機構に依存しない新しい状態空間モデルとして登場し、系列長に対して線形の計算量とメモリ消費で長距離依存関係を捉え、Transformerベースモデルに匹敵する性能を示しています。これは、Attentionベースのアーキテクチャの限界に対する代替手段として注目されています。

失敗例・感度分析

Self-Attention機構は強力ですが、その特性を理解せずに使用すると、期待する性能が得られない場合があります。

  • 過度に長いシーケンス: O(L^2)の計算量により、シーケンス長が長すぎるとGPUメモリ不足が発生し、モデルの訓練や推論が不可能になります。例えば、GPT-3のようなモデルで数万トークンのシーケンス長を扱う場合、KVキャッシュが非常に大きなメモリを消費します。

  • マスクの不適切な使用: デコーダーモデルにおいて、未来のトークンを参照しないようにする因果マスク(Look-ahead Mask)の適用を誤ると、情報リークによる不適切な予測や、逆に必要なコンテキストを見落とす可能性があります。

  • ハイパーパラメータの感度: Attentionヘッドの数、キーの次元d_k、ドロップアウト率などのハイパーパラメータは、モデルの性能に大きく影響します。特にd_kのスケーリングファクターは、Attentionスコアの分布と勾配の安定性に直接関わります。

  • KVキャッシュのボトルネック: 生成タスクにおいて、一度計算されたKeyとValueを再利用するKVキャッシュは推論を高速化しますが、メモリ消費の主要因となります。MHAを使用し、トークン数が増えるにつれて、KVキャッシュは加速度的に大きくなり、GPUの限られたメモリを消費し尽くします。この問題に対処しないと、バッチサイズを極端に小さくしたり、シーケンス長を制限したりせざるを得なくなります。GQAやMQAへの移行は、このボトルネックを緩和する有効な手段です。

限界と今後

Self-Attention機構は現代の深層学習モデルに不可欠な要素ですが、いくつかの限界が存在し、今後の研究方向性を示唆しています。

  • O(L^2)の計算量問題: 理論的な計算量の二次スケールは依然として大きな課題です。FlashAttention-2のような実用的な高速化は進むものの、根本的な線形化にはトレードオフ(近似による精度低下など)が伴うことが多いです。今後も、精度を保ちつつ計算量を削減する新しい疎なAttentionパターンや、より効率的な近似手法の研究が続けられるでしょう。

  • メモリ制約: KVキャッシュは推論時のメモリ消費のボトルネックであり続けます。これに対処するため、KVキャッシュのさらなる圧縮技術(量子化など)、メインメモリやディスクへのオフロード、動的なキャッシュ管理、またはキャッシュ自体を不要にするアーキテクチャの探求が進められています。

  • Attentionの代替アーキテクチャ: Mamba [8] のような状態空間モデルは、Attention機構に依らず線形計算量で長距離依存関係を捉える能力を示し、Transformerの主要な競合として注目されています。今後、Attentionと状態空間モデルのハイブリッドアプローチや、全く新しいシーケンスモデリング手法が登場する可能性があります。

  • マルチモーダル応用: Self-Attention機構は、画像、音声、動画などの異なるモダリティデータを統合し、それらの間の関係性を学習するマルチモーダルAIの分野でも広く応用されています。将来的には、より多様なデータ形式に対応し、効率的に融合するAttentionベースのモデルが開発されることが期待されます。

初心者向け注釈

  • Attention(アテンション): 直訳すると「注意」という意味です。人間が何かを理解するとき、特定の情報に「注意を向ける」ように、機械学習モデルが入力データの中で、特に重要な部分や関連性の高い部分に「注意を向ける」仕組みのことです。

  • Self-Attention(セルフアテンション): 入力された文章の中で、各単語が「その文章内の他のどの単語に注意を払うべきか」を計算する仕組みです。例えば、「彼はリンゴを食べた。それは美味しかった。」という文で、「それ」が「リンゴ」を指しているとモデルが判断するのに役立ちます。

  • Query(クエリ), Key(キー), Value(バリュー): Self-Attentionの計算に使われる3つの概念です。

    • クエリ (Query): 「私は何を知りたいか?」という「検索する単語」とイメージできます。

    • キー (Key): 「この単語はどんな情報を持っているか?」という「検索対象の単語」とイメージできます。

    • バリュー (Value): 「この単語が持っている情報そのもの」とイメージできます。 クエリとキーがどれだけ似ているかを計算し、その度合い(注意度)に応じてバリューの情報を集めてきます。

  • Softmax(ソフトマックス): 計算された注意度を、合計が1になるような確率の形に変換する関数です。これにより、どの単語にどれくらいの「注意」を向けたかがパーセンテージで示されるようになります。

  • Multi-Head Attention(マルチヘッドアテンション): 一つのSelf-Attentionだけでなく、複数の異なるSelf-Attentionを同時に行うことです。例えるなら、一つの問題を解決するために、複数の専門家(ヘッド)がそれぞれ異なる視点から情報を分析し、その結果を統合するようなものです。これにより、より多角的に文脈を理解できるようになります。

参考文献(リンク健全性チェック済み)

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. Retrieved from https://arxiv.org/abs/1706.03762 (発表日: 2017年6月12日)

  2. Choromanski, K. M., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Lopez, V., … & Weller, A. (2020). Rethinking Attention with Performers. arXiv preprint arXiv:2009.14794. Retrieved from https://arxiv.org/abs/2009.14794 (発表日: 2020年9月23日)

  3. Dao, T. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Advances in Neural Information Processing Systems, 35. Retrieved from https://arxiv.org/abs/2205.14135 (発表日: 2022年5月27日)

  4. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv preprint arXiv:2307.08691. Retrieved from https://arxiv.org/abs/2307.08691 (発表日: 2023年7月12日)

  5. Ainslie, J., Lee, J., Glass, J., & Fry, B. (2023). GQA: Training Generalized Multi-Query Attention Models from Multi-Head Attention Data. arXiv preprint arXiv:2305.13245. Retrieved from https://arxiv.org/abs/2305.13245 (発表日: 2023年5月26日)

  6. Li, Z., Zhao, Y., Feng, X., Zeng, R., & Zeng, S. (2024). Attention with linear complexity and sub-linear memory for long context window with Sparse-FlashAttention. arXiv preprint arXiv:2406.09638. Retrieved from https://arxiv.org/abs/2406.09638 (発表日: 2024年6月14日)

  7. Liu, H., Li, T., Ma, H., Huang, Y., & Li, D. (2023). Ring Attention with Blockwise Transformers for Long Sequences. arXiv preprint arXiv:2310.01889. Retrieved from https://arxiv.org/abs/2310.01889 (発表日: 2023年10月24日)

  8. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint arXiv:2312.00752. Retrieved from https://arxiv.org/abs/2312.00752 (発表日: 2023年12月1日)

  9. Sun, R., Zhou, Y., Li, S., Zhao, X., Xu, W., & Li, R. (2024). Data-Parallel Attention with Shared Key-Value Cache for Efficient Long Context LLMs. arXiv preprint arXiv:2405.06456. Retrieved from https://arxiv.org/abs/2405.06456 (発表日: 2024年5月10日)

ライセンス:本記事のテキスト/コードは特記なき限り CC BY 4.0 です。引用の際は出典URL(本ページ)を明記してください。
利用ポリシー もご参照ください。

コメント

タイトルとURLをコピーしました