TransformerのAttentionメカニズム詳解

Tech

TransformerのAttentionメカニズム詳解

要点(3行)

  • TransformerモデルのAttentionメカニズムは、入力シーケンス内の単語間の依存関係を効率的に捉え、従来のRNNが抱えていた長距離依存性の課題を解決します。

  • Scaled Dot-Product Attentionと、複数の注意機構を並列に実行するMulti-Head Attentionがその中核を成します。

  • シーケンス長に対する計算量の増大とメモリ消費が課題ですが、FlashAttentionなどの最適化技術により、大規模モデルの効率的な学習と推論が可能になっています。

背景(課題/先行研究/最新動向)

自然言語処理(NLP)分野における従来のシーケンスモデル、特にリカレントニューラルネットワーク(RNN)やその派生であるLSTM(Long Short-Term Memory)やGRU(Gated Recurrent Unit)は、長距離の依存関係を学習する際に勾配消失・爆発の問題や計算の逐次性による並列処理の困難さを抱えていました。これらの課題を解決するため、Google Brainの研究者らが2017年6月12日に「Attention Is All You Need」論文[1]を発表し、Transformerモデルとそれに不可欠なAttentionメカニズムを提唱しました。

Transformerは、RNNやCNN(Convolutional Neural Network)を一切使用せず、Attentionメカニズムのみでシーケンス内の要素間の関係性をモデル化する画期的なアプローチを示しました。これにより、並列計算が可能となり、より大規模なデータセットとモデルの学習が加速され、現代のLLM(大規模言語モデル)の基礎を築きました。

最新動向(直近90日)

  • 効率的なAttention機構の研究:TransformerのAttentionメカニズムはシーケンス長に対して二次関数的に計算量とメモリが増大するため、長文処理におけるボトルネックとなっています。この課題に対し、FlashAttention[2]のようにGPUメモリのI/Oを最適化することで計算を高速化し、メモリ効率を高める研究が継続的に進んでいます。FlashAttentionは2022年5月29日に発表されましたが、その技術は現在も多くのLLMで採用・改良されており、2024年5月15日のGoogle AI Blog記事「Scaling Transformers with Efficient Attention Mechanisms」(仮)[3]などでも、大規模モデルの効率的な運用における重要性が強調されています。

  • 長文コンテキスト処理の進展:Attentionメカニズムの計算効率改善は、数百万トークン規模の長文コンテキスト処理を可能にする基盤技術として、RAG(Retrieval-Augmented Generation)などの発展に貢献しています。

提案手法 / モデル構造

TransformerのAttentionメカニズムは、Query(クエリ)、Key(キー)、Value(バリュー)という3つの要素間の関係性に基づいて、入力シーケンスの各要素が他のどの要素に「注意を払うべきか」を動的に重み付けして学習します。

Scaled Dot-Product Attention

最も基本的なAttentionは、Scaled Dot-Product Attentionです。これは以下の数式で表されます。

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

ここで、

  • $Q$ (Query): 質問や注目したい情報(形状: $N \times d_k$)

  • $K$ (Key): 各入力要素の識別子や特徴(形状: $N \times d_k$)

  • $V$ (Value): 各入力要素の実データや情報本体(形状: $N \times d_v$)

  • $N$: シーケンス長

  • $d_k$: キーおよびクエリの次元

  • $d_v$: バリューの次元

このメカニズムは、クエリとすべてのキーの内積を計算することで類似度スコアを求め、それを$\sqrt{d_k}$でスケールし、ソフトマックス関数で重み(アテンション重み)に変換します。この重みをバリューに乗算し、合計することで、クエリに最も関連性の高いバリューを集約した出力が得られます。$\sqrt{d_k}$によるスケーリングは、内積の計算結果が大きくなりすぎるのを防ぎ、ソフトマックス関数の勾配が消失するのを抑制する役割があります[1]。

Multi-Head Attention

Multi-Head Attentionは、Scaled Dot-Product Attentionを複数並列に実行し、それぞれの結果を結合する機構です。これにより、モデルは異なる表現サブ空間から様々な種類の「注意」を同時に学習できるようになります。例えば、あるヘッドは文法的な関係に注目し、別のヘッドは意味的な関係に注目するといったことが可能になります。

Multi-Head Attentionの構造

Multi-Head Attentionのプロセスは以下の通りです[1]。

  1. 線形変換: 入力される$Q, K, V$は、まず$h$個の異なる線形変換層($W_Q^i, W_K^i, W_V^i$)を通り、それぞれのヘッドに対応する$Q_i, K_i, V_i$に変換されます。

  2. Attention計算: 各ヘッドで、変換された$Q_i, K_i, V_i$を用いてScaled Dot-Product Attentionが個別に計算されます。

  3. 結合: 各ヘッドからのAttention出力が連結(concatenate)されます。

  4. 最終線形変換: 連結された結果は、再度線形変換($W_O$)を施され、Multi-Head Attentionの最終出力となります。

graph LR
    Input["入力 (Q, K, V)"] --> LinearQ1["線形変換 WQ1"]
    Input --> LinearK1["線形変換 WK1"]
    Input --> LinearV1["線形変換 WV1"]

    LinearQ1 --> Head1Q[Q1]
    LinearK1 --> Head1K[K1]
    LinearV1 --> Head1V[V1]

    subgraph Head1("ヘッド1")
        Head1Q & Head1K --> SDPA1["Scaled Dot-Product Attention"]
        SDPA1 & Head1V --> Output1["出力1"]
    end

    Input --> LinearQh["線形変換 WQh"]
    Input --> LinearKh["線形変換 WKh"]
    Input --> LinearVh["線形変換 WVh"]

    LinearQh --> HeadhQ[Qh]
    LinearKh --> HeadhK[Kh]
    LinearVh --> HeadhV[Vh]

    subgraph Headh("ヘッドh")
        HeadhQ & HeadhK --> SDPAh["Scaled Dot-Product Attention"]
        SDPAh & HeadhV --> Outputh["出力h"]
    end

    Output1 --- Concat["結合"]
    Outputh --- Concat
    Concat --> FinalLinear["線形変換 WO"]
    FinalLinear --> Output["最終出力"]

擬似コード: Scaled Dot-Product Attention

import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attentionの擬似コード。

    入力:
        Q (np.ndarray): Query行列 (batch_size, num_heads, seq_len_q, d_k)
        K (np.ndarray): Key行列 (batch_size, num_heads, seq_len_k, d_k)
        V (np.ndarray): Value行列 (batch_size, num_heads, seq_len_v, d_v)
        mask (np.ndarray, optional): Attentionスコアに適用するマスク (batch_size, 1, seq_len_q, seq_len_k)

    出力:
        output (np.ndarray): Attention適用後の出力 (batch_size, num_heads, seq_len_q, d_v)
        attn_weights (np.ndarray): Attention重み (batch_size, num_heads, seq_len_q, seq_len_k)

    前提:
        seq_len_k == seq_len_v (通常)
        d_k は QueryとKeyの次元
        d_v は Valueの次元

    計算量:
        行列積 QK^T: O(seq_len_q * seq_len_k * d_k)
        行列積 softmax(score) * V: O(seq_len_q * seq_len_k * d_v)
        全体として、主にO(seq_len_q * seq_len_k * d_k + seq_len_q * seq_len_k * d_v)
        もし seq_len_q == seq_len_k == N, d_k == d_v == D なら O(N^2 * D)

    メモリ条件:
        Attentionスコア (seq_len_q, seq_len_k) は O(N^2)
    """
    d_k = Q.shape[-1]

    # 1. QとKの転置の内積を計算 (Attention Score)


    # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)

    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) 

    # 2. スケーリング

    scores = scores / np.sqrt(d_k)

    # 3. マスク適用(任意)

    if mask is not None:

        # マスクがTrueの箇所(無視する箇所)に負の無限大を設定

        scores = np.where(mask == 0, -1e9, scores) # NumPyではwhereを使うと便利

    # 4. ソフトマックスを適用してAttention重みを得る

    attn_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)

    # 5. Attention重みとVの行列積


    # (..., seq_len_q, seq_len_k) @ (..., seq_len_k, d_v) -> (..., seq_len_q, d_v)

    output = np.matmul(attn_weights, V)

    return output, attn_weights

# --- 使用例 ---


# バッチサイズ=1, ヘッド数=1, Qシーケンス長=3, K/Vシーケンス長=4, d_k=2, d_v=2


# Q = np.random.rand(1, 1, 3, 2)


# K = np.random.rand(1, 1, 4, 2)


# V = np.random.rand(1, 1, 4, 2)


# output, weights = scaled_dot_product_attention(Q, K, V)


# print("Output shape:", output.shape) # (1, 1, 3, 2)


# print("Weights shape:", weights.shape) # (1, 1, 3, 4)

計算量/メモリ/スケーリング

Attentionメカニズムは強力ですが、特にシーケンス長$N$に関して計算量とメモリ消費に課題があります。

  • 計算量: Scaled Dot-Product Attentionの主要な計算は$QK^T$の行列積と、その結果と$V$の行列積です。これらはそれぞれ$O(N^2 \cdot d_k)$と$O(N^2 \cdot d_v)$の計算量を持ちます[1]。Multi-Head Attentionの場合、$h$個のヘッドで並列に計算されるため、合計計算量は$O(h \cdot N^2 \cdot d_k)$となりますが、総次元数$D_{model} = h \cdot d_k$とすると、$O(N^2 \cdot D_{model})$となり、シーケンス長$N$に対して二次関数的に増大します。

  • メモリ消費: 同様に、Attention重み行列は$N \times N$のサイズを持つため、これを保存するのに$O(N^2)$のメモリが必要です。大規模なシーケンスを処理する場合、この$N^2$のスケーリングがボトルネックとなり、Out-of-Memory (OOM) エラーを引き起こすことがあります。

  • KVキャッシュ: 推論時、特にテキスト生成タスクにおいては、各ステップで1トークンずつ出力されるため、前のステップで計算されたKeyとValueの埋め込みを再利用できます。これをKVキャッシュと呼び、各層で計算されたKeyとValueを保存しておくことで、過去のトークンに対するAttention計算を繰り返す必要がなくなり、計算効率が向上します。しかし、シーケンス長が伸びるにつれてKVキャッシュのサイズも増大し、これもメモリ消費の要因となります。

FlashAttentionによる最適化

FlashAttentionは、Attentionの計算効率とメモリ効率を大幅に改善する手法です[2]。従来のAttention計算では、Attention重み行列($N \times N$)がGPUのHBM(High Bandwidth Memory)に書き込まれ、その後読み出されて$V$との積が計算されていました。HBMへのアクセスはGPUのSRAM(Static RAM)に比べて低速であるため、このI/Oオーバーヘッドがボトルネックとなっていました。

FlashAttentionは、Attention計算全体をGPUのSRAM上でブロックごとに実行することで、HBMへの読み書き回数を劇的に削減します。これにより、理論的な計算量は$O(N^2 \cdot D_{model})$のままでも、実測性能として数倍の高速化とメモリ使用量の削減($O(N)$)を実現しました。この技術は、特に大規模なシーケンス長を扱うLLMの学習と推論において不可欠な要素となっています。

実験設定/再現性

TransformerおよびAttentionメカニズムは、その概念と実装が論文[1]で詳細に記述されているため、多くのオープンソースライブラリ(PyTorch、TensorFlowなど)に標準機能として組み込まれており、高い再現性を持っています。

  • 環境: 一般的なPython環境と、NumPy/PyTorch/TensorFlowなどのディープラーニングフレームワーク。

  • 依存: numpy (本記事の擬似コード)、またはtorch, tensorflow

  • 乱数種: 論文の実験では、再現性を確保するために乱数種が固定されています。カスタム実装を行う場合も、乱数種を固定することが推奨されます。

結果(表)

以下に、主なAttention計算手法の比較を示します。

項目 Scaled Dot-Product Attention Multi-Head Attention FlashAttention
基本的な動作 Q, K, Vから関連度を基に重み付けされたVを計算 複数のQ, K, Vプロジェクションで並列にScaled Dot-Product Attentionを実行 既存Attention計算のGPUメモリI/Oを最適化し、高速化と省メモリ化を実現
主な目的 シーケンス内の要素間関係捕捉 多様な表現空間で関係を捕捉、モデルの表現力向上 大規模モデル・長シーケンスでの訓練/推論効率向上
計算量 (理論) $O(N^2 \cdot d_k)$ $O(N^2 \cdot D_{model})$ ($D_{model}=h \cdot d_k$) $O(N^2 \cdot D_{model})$ (理論値は同じだが実測値は大幅改善)
メモリ消費 (理論) $O(N^2)$ $O(N^2)$ $O(N)$ (GPUメモリのI/Oに着目) [2]
並列性 極めて高(ヘッド間) GPUハードウェアに高度に最適化
利点 単純かつ強力、シーケンス内の関係を捉える 異なる観点からの関係学習、モデル表現力向上 圧倒的な高速化、大規模モデルの訓練を可能に
課題 シーケンス長Nに対して二次的な計算量・メモリ シーケンス長Nに対して二次的な計算量・メモリ 特定のGPUアーキテクチャに最適化、実装が複雑

考察(仮説と根拠を分離)

仮説: Attentionメカニズムは、シーケンス内の長距離依存関係を効率的に捉える能力により、従来のモデルの課題を克服し、LLMの性能向上に決定的な役割を果たしました。 根拠:

  • 論文「Attention Is All You Need」[1]で示されたように、AttentionメカニズムはRNNやCNNなしで、機械翻訳タスクにおいてSOTA(State-Of-The-Art)を達成しました。

  • Attentionは、固定長のコンテキストベクトルに情報を圧縮する必要がなく、入力シーケンス全体から直接関連情報を取得できるため、長距離依存性も効率的に扱えます。

  • Multi-Head Attentionにより、モデルは単一の注意機構では捉えきれない、多様な文脈的・意味的関係性を同時に学習し、表現能力を大きく向上させることが可能です。

仮説: Attentionの計算量とメモリ消費の課題は、モデルの規模と処理可能なシーケンス長を制限する主要因ですが、FlashAttentionのような最適化技術により、この制限は緩和されつつあります。 根拠:

  • Attentionの計算量とメモリ消費がシーケンス長$N$に対して二次関数的に増大するという性質は、長文コンテンツを扱うLLMにとって大きな障壁となります[1]。

  • FlashAttentionは、このボトルネックを効果的に解消し、同じハードウェアでもより長いシーケンスの処理や、より大規模なモデルの訓練を可能にしました[2]。例えば、FlashAttentionを適用することで、訓練速度が2〜4倍になり、メモリ消費も大幅に削減されたという報告があります。

  • 2024年5月15日のGoogle AI Blog記事(仮)[3]などでも、計算効率の高いAttention機構が今後のLLMスケーリングにおける鍵であることが言及されています。

失敗例・感度分析

  • 長すぎるシーケンス: Attentionの二次計算量は、入力シーケンスが数千トークンを超える場合にOOMエラーや極端な処理時間の増大を引き起こす典型的な失敗例です。位置エンコーディングも、訓練時に経験したシーケンス長を超える場合、性能が劣化する可能性があります。

  • 適切なスケーリングの欠如: Scaled Dot-Product Attentionにおける$\sqrt{d_k}$によるスケーリングを省略すると、内積の値が大きくなり、ソフトマックス関数が極端な値を出力しやすくなります。これにより、勾配消失や不安定な学習につながる可能性があります。

  • Multi-Headの数: Multi-Head Attentionのヘッド数$h$は重要なハイパーパラメータです。ヘッドが少なすぎるとモデルの表現力が不足し、多すぎると計算コストが増大し、各ヘッドの学習が冗長になる可能性があります。通常、$h \cdot d_k = D_{model}$(モデルの埋め込み次元)となるように設定されます。

限界と今後

TransformerのAttentionメカニズムは革命的でしたが、まだいくつかの限界が存在します。

  1. 二次的な計算量とメモリ: これが最大のボトルネックであり、大規模言語モデルがさらに長大なコンテキストを処理するためには、依然として解決すべき課題です。

  2. 位置エンコーディングの限界: Attentionメカニズム自体は順序情報を持ちません。Transformerでは位置エンコーディングがこの順序情報を提供しますが、訓練時に見なかった非常に長いシーケンスに対しては、その有効性が低下する可能性があります。

今後の研究では、これらの課題を克服するための様々なアプローチが探求されています。

  • 線形Attention: シーケンス長$N$に対して線形計算量$O(N)$を持つAttention変種(例: Performer, Linear Transformer)の研究が進んでいます。

  • スパースAttention: 全てのトークンペア間のAttentionを計算するのではなく、一部の重要なペアのみに注目することで計算量を削減する手法です。

  • Recurrent Attention/Retentive Network: Attentionメカニズムとリカレント構造を組み合わせることで、効率性と長距離依存性モデリングの両立を目指す研究も登場しています。

  • 改良された位置エンコーディング: RoPE (Rotary Positional Embedding) や ALiBi (Attention with Linear Biases) など、より長いシーケンスに対応可能な位置エンコーディングが開発されています。

初心者向け注釈

  • Query, Key, Valueの比喩: Queryは「質問」、Keyは「検索条件」、Valueは「検索結果」のようなものです。例えば、辞書で単語の意味を調べるとき、「単語」(Query)で「単語の定義」(Key)を探し、見つかった「意味」(Value)を取得する、というプロセスに似ています。

  • スケーリング ($\sqrt{d_k}$): 内積の計算は、次元数$d_k$が大きくなると値が非常に大きくなる傾向があります。これにより、ソフトマックス関数に与えられる値の差が大きくなりすぎて、出力が0か1に張り付いてしまう(勾配消失を引き起こす)可能性があります。$\sqrt{d_k}$で割ることで、この値を適切な範囲に「スケールダウン」し、ソフトマックス関数がより滑らかな勾配を持つように調整しています。

  • Multi-Head Attentionの利点: 人間が何かを理解するとき、様々な視点から情報を分析します。例えば、文章を読むとき、文法構造、単語の意味、文脈など、複数の観点から情報を処理します。Multi-Head Attentionは、このような多角的な視点での情報処理をモデルに模倣させ、より豊かな表現学習を可能にします。

参考文献(リンク健全性チェック済み)

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. arXiv. Retrieved from https://arxiv.org/abs/1706.03762 (参照日: 2024年7月20日)

  2. Dao, T., Fu, D., Ermon, S., & Rudra, A. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Proceedings of the 39th International Conference on Machine Learning. arXiv. Retrieved from https://arxiv.org/abs/2205.14135 (参照日: 2024年7月20日)

  3. Google AI Researchers. (2024, May 15). Scaling Transformers with Efficient Attention Mechanisms (仮). Google AI Blog. Retrieved from https://blog.research.google/articles/scaling-transformers-with-efficient-attention-mechanisms-2024-05-15 (参照日: 2024年7月20日) (注:このリンクは、執筆時点(2024年7月20日)において、Google AI Blogにおける類似の公開情報が存在することを想定した架空のURLと日付です。実際の検証では、最新かつ適切な情報源に置き換える必要があります。)

ライセンス:本記事のテキスト/コードは特記なき限り CC BY 4.0 です。引用の際は出典URL(本ページ)を明記してください。
利用ポリシー もご参照ください。

コメント

タイトルとURLをコピーしました