Transformer: Attentionメカニズム詳解

Tech

Transformer: Attentionメカニズム詳解

要点(3行)

  • TransformerモデルはAttentionメカニズムを導入し、系列データの長距離依存性把握と並列処理の困難さを解決しました。

  • 自己AttentionとマルチヘッドAttentionが中核であり、入力シーケンス内の関係性を動的に重み付けすることで文脈を捉えます。

  • 計算コストが高いという課題に対し、FlashAttentionなどの最適化手法が開発され、LLMの効率的な学習・推論を可能にしています。

背景(課題/先行研究/最新動向)

従来の系列モデルであるリカレントニューラルネットワーク(RNN)や畳み込みニューラルネットワーク(CNN)には、いくつかの課題が存在しました。RNNは長距離の依存関係を学習する際に勾配消失・爆発の問題を抱えやすく、また本質的に逐次処理であるため計算の並列化が困難でした。CNNは局所的なパターン認識に優れるものの、グローバルな文脈を捉えるには多層化が必要で、その表現力には限界がありました。

これらの課題に対し、2017年6月12日に発表された「Attention Is All You Need」論文は、RNNやCNNを一切使用せず、Attentionメカニズムのみで構成されるTransformerモデルを提案しました[1]。これにより、長距離依存性の効率的なモデリングと、大幅な並列処理が実現され、自然言語処理分野に革命をもたらしました。

最新動向(直近90日):

  • FlashAttentionの改良と普及:TransformerのAttention計算におけるGPUメモリI/Oのボトルネックを解消するFlashAttention [2]は、その後の改良版であるFlashAttention-2 [3]の登場により、さらなる高速化とメモリ効率の向上を達成しました。これにより、2025年10月19日現在、長文コンテキストの処理がより実用的になっています。

  • 効率的なAttentionバリアントの研究:長いシーケンスに対するO(N^2)の計算量を削減するため、Linformer [4]やPerceiver IO [5]といったSparse AttentionやLinear Attention、Cross-Attentionを組み合わせた効率的なAttentionメカニズムの研究が活発に進められています。これらの研究は、2025年10月19日現在も継続しており、LLMのスケーラビリティ向上に貢献しています。

提案手法 / モデル構造

Transformerモデルは、エンコーダとデコーダから構成され、それぞれが複数のAttention層とフィードフォワード層のスタックで構築されています。その中核をなすのが自己Attention (Self-Attention)メカニズムです。

自己Attentionの動作原理

自己Attentionは、入力シーケンス内の各トークンが、同じシーケンス内の他の全てのトークンとの関連度を計算し、その関連度に基づいて重み付けされた情報を集約するメカニズムです。これにより、単語の曖昧性解消や共参照解決など、文脈に応じた表現学習が可能になります。

各入力トークンベクトルは、学習可能な3つの線形変換(重み行列 $W^Q, W^K, W^V$)を介して、Query (Q)、Key (K)、Value (V) の3つのベクトルに変換されます。

  • Query (Q): 「自分自身が何を探しているか」を表すベクトル。

  • Key (K): 「他のトークンが持っている情報」を表すベクトル。

  • Value (V): 「他のトークンが提供できる実情報」を表すベクトル。

自己Attentionの計算は、主に以下のステップで行われます。

  1. Q, K, V の生成: 入力埋め込み $X$ から $Q = XW^Q, K = XW^K, V = XW^V$ を計算します。

  2. Attentionスコアの計算: 各Queryベクトルと全てのKeyベクトルとの内積を計算し、類似度(関連度)を測ります。これは $QK^T$ で表されます。

  3. スケーリング: 内積の結果をKeyベクトルの次元の平方根 $\sqrt{d_k}$ で割ることで、勾配消失・爆発を防ぎ、ソフトマックス関数の入力が安定するように調整します。

  4. ソフトマックス: スケーリングされたスコアにソフトマックス関数を適用し、合計が1になるAttention重み(確率分布)を得ます。

  5. 重み付け和: 各ValueベクトルにAttention重みを掛け合わせ、それらを合計することで、最終的なAttention出力ベクトルを得ます。

これら一連のプロセスは、以下の数式で表現されます[1]: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

マルチヘッドAttention

Transformerでは、この自己Attentionメカニズムを複数並列に実行するマルチヘッドAttention (Multi-Head Attention)が採用されています。それぞれの「ヘッド」は異なる $W^Q, W^K, W^V$ 行列を持ち、入力から異なるQKVの組を学習します。これにより、モデルは異なる表現サブスペースから情報を抽出し、多様な関連性(例: 構文的関係、意味的関係)を同時に捉えることができます。各ヘッドからの出力は結合され、最終的に線形変換によって元の次元に戻されます。

Transformerの全体構造

graph TD
    subgraph "Encoder Block"
        Input_Embeddings["入力埋め込み + 位置エンコーディング"] --> MHA_Enc["Multi-Head Attention"]
        MHA_Enc --> AddNorm_Enc1["残差接続 & 層正規化"]
        AddNorm_Enc1 --> FFN_Enc["フィードフォワードネットワーク"]
        FFN_Enc --> AddNorm_Enc2["残差接続 & 層正規化"]
    end

    subgraph "Decoder Block"
        Target_Embeddings["ターゲット埋め込み + 位置エンコーディング"] --> MaskedMHA_Dec["マスク付きMulti-Head Attention"]
        MaskedMHA_Dec --> AddNorm_Dec1["残差接続 & 層正規化"]
        AddNorm_Dec1 --> EncDecMHA["Encoder-Decoder Multi-Head Attention"]
        EncDecMHA --> AddNorm_Dec2["残差接続 & 層正規化"]
        AddNorm_Dec2 --> FFN_Dec["フィードフォワードネットワーク"]
        FFN_Dec --> AddNorm_Dec3["残差接続 & 層正規化"]
    end

    AddNorm_Enc2 --> EncDecMHA;
    AddNorm_Dec3 --> Output_Layer["線形層 + ソフトマックス"];
  • MHA_Enc: エンコーダの自己Attention。入力シーケンス内の関係性を学習。

  • MaskedMHA_Dec: デコーダの自己Attention(未来のトークンをマスク)。生成中のトークンがそれより前のトークンのみを参照するようにする。

  • EncDecMHA: デコーダがエンコーダの出力(Q,K)に注意を向けるCross-Attention。

Multi-Head Attentionの内部構造

graph TD
    subgraph "Multi-Head Attention("h heads")"
        Input["入力X"] --> LinearQ["Linear (WQ)"]
        Input --> LinearK["Linear (WK)"]
        Input --> LinearV["Linear (WV)"]

        LinearQ --> SplitQ["Split into h heads"]
        LinearK --> SplitK["Split into h heads"]
        LinearV --> SplitV["Split into h heads"]

        subgraph "Head i"
            Q_i[Qi] --> ScaledDotProductAtt_i["Scaled Dot-Product Attention"]
            K_i[Ki] --> ScaledDotProductAtt_i
            V_i[Vi] --> ScaledDotProductAtt_i
            ScaledDotProductAtt_i --> Output_i[Zi]
        end

        SplitQ --> Q_1[Q1]; SplitK --> K_1[K1]; SplitV --> V_1[V1]; SplitQ --> Q_h[Qh]; SplitK --> K_h[Kh]; SplitV --> V_h[Vh];
        Q_1 --> ScaledDotProductAtt_1; K_1 --> ScaledDotProductAtt_1; V_1 --> ScaledDotProductAtt_1;
        Q_h --> ScaledDotProductAtt_h; K_h --> K_h; V_h --> ScaledDotProductAtt_h;

        Output_1[Z1] & Output_h[Zh] --> Concat["連結"]
        Concat --> FinalLinear["Linear (WO)"]
        FinalLinear --> OutputAtt["Attention出力"]
    end
  • Input[入力X]: 各トークンを表現するベクトルシーケンス。

  • LinearQ/K/V: それぞれQ, K, Vを生成するための線形変換。

  • Split into h heads: 生成されたQ, K, Vを行列の最後の次元で h 個のチャンクに分割。

  • Scaled Dot-Product Attention: 個々のヘッド内で行われるAttention計算。

  • Concat[連結]: 各ヘッドの出力を結合。

  • FinalLinear[Linear (WO)]: 結合された出力を最終的なAttention出力に射影する線形変換。

擬似コード / 最小Python

以下は、自己Attentionの核となる計算部分を抽出した擬似コードです。

# Inference Pipeline (最小例) - ユーザー提供部分


# 入力: query(str), ctx(list[dict(url, title, date_jst, snippet)])


# 出力: answer(str; 本文中に [n] 引用)


# 計算量: n=トークン長, m=文献件数 → O(n*m)

def answer_with_ctx(query, ctx):

    # 1) 根拠バッファ整形(一次情報を優先し最大8件)

    top = rank_by_relevance_and_freshness(ctx, top_m=8)

    # 2) 指示:断定は [n] を伴う / 相対日付禁止 / Markdown で表・図を含める

    prompt = build_prompt(query, top, require_citations=True, locale="ja-JP")

    # 3) 生成:低温度・事実性優先

    return llm_generate(prompt, temperature=0.3, top_p=0.9, max_tokens=1600)

# Self-Attention 計算のコア部分 (Python)


# 入力: Q (Query行列: [batch_size, seq_len, d_k]),


#       K (Key行列: [batch_size, seq_len, d_k]),


#       V (Value行列: [batch_size, seq_len, d_v]),


#       mask (オプション: [batch_size, seq_len, seq_len] or [seq_len, seq_len], ブール値)


# 出力: Attention出力行列 (O: [batch_size, seq_len, d_v]),


#       Attention重み行列 (attn_weights: [batch_size, seq_len, seq_len])


# 前提: d_k は Key/Query ベクトルの次元、d_v は Value ベクトルの次元


# 計算量: n=seq_len, d=d_k → O(n^2 * d)


# メモリ: n=seq_len, d=d_k → O(n^2 + n*d)

import torch
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)

    # 1. Attentionスコアの計算: QK^T


    # scores: [batch_size, seq_len, seq_len]

    scores = torch.matmul(Q, K.transpose(-2, -1))

    # 2. スケーリング

    scores = scores / math.sqrt(d_k)

    # 3. マスキング (例: マスク付き自己Attentionの場合、未来のトークンを参照しない)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9) # マスクされた箇所を極小値に設定

    # 4. ソフトマックス


    # attn_weights: [batch_size, seq_len, seq_len]

    attn_weights = torch.softmax(scores, dim=-1)

    # 5. 重み付け和: Attention(Q,K,V) = softmax(QK^T/√d_k)V


    # output: [batch_size, seq_len, d_v]

    output = torch.matmul(attn_weights, V)

    return output, attn_weights

# Multi-Head Attentionの擬似的な呼び出し


# 入力: x (入力埋め込み: [batch_size, seq_len, model_dim])


#       W_Q, W_K, W_V (Query, Key, Value変換行列)


#       W_O (最終出力変換行列)


#       num_heads (ヘッド数)


# 出力: Attention出力行列

def multi_head_attention_forward(x, W_Q, W_K, W_V, W_O, num_heads):
    model_dim = x.size(-1)
    head_dim = model_dim // num_heads

    # Q, K, V を線形変換

    Q_proj = torch.matmul(x, W_Q) # [batch_size, seq_len, model_dim]
    K_proj = torch.matmul(x, W_K) # [batch_size, seq_len, model_dim]
    V_proj = torch.matmul(x, W_V) # [batch_size, seq_len, model_dim]

    # Q, K, V を各ヘッドに分割


    # [batch_size, seq_len, num_heads, head_dim] にreshpeし、


    # num_heads次元をバッチ次元の次に入れ替える

    Q_heads = Q_proj.view(x.size(0), -1, num_heads, head_dim).transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
    K_heads = K_proj.view(x.size(0), -1, num_heads, head_dim).transpose(1, 2)
    V_heads = V_proj.view(x.size(0), -1, num_heads, head_dim).transpose(1, 2)

    # 各ヘッドでAttentionを計算


    # attn_outputs: [batch, num_heads, seq_len, head_dim]

    attn_outputs, _ = scaled_dot_product_attention(Q_heads, K_heads, V_heads)

    # 各ヘッドの出力を結合


    # [batch, seq_len, num_heads, head_dim] に戻し、num_headsとhead_dimを結合して model_dim にする

    concat_output = attn_outputs.transpose(1, 2).contiguous().view(x.size(0), -1, model_dim)

    # 最終的な線形変換

    final_output = torch.matmul(concat_output, W_O) # [batch_size, seq_len, model_dim]

    return final_output

計算量/メモリ/スケーリング

自己Attentionの計算は、シーケンス長 $N$ とモデルの次元 $d_{model}$ に依存します。

  • 計算量: QueryとKeyの内積計算 $QK^T$ は $N \times d_{model}$ 行列と $d_{model} \times N$ 行列の積であるため、計算量は $O(N^2 \cdot d_{model})$ となります。長いシーケンスではこの $N^2$ の依存性がボトルネックとなります。

  • メモリ: Attentionスコア行列 $QK^T$ は $N \times N$ のサイズを持ち、ソフトマックス後のAttention重み行列も同様です。そのため、メモリ使用量も $O(N^2)$ となり、特に非常に長いコンテキストを持つモデルでは課題となります。

この $O(N^2)$ の計算量とメモリ使用量を改善するため、FlashAttention [2]のような技術が開発されました。FlashAttentionは、Attentionの計算をGPUのSRAM上で効率的に行うことで、メモリI/Oのボトルネックを削減し、Transformerの学習・推論速度を大幅に向上させました。2025年10月19日現在、FlashAttention-2 [3]ではさらに最適化が進み、特に長いシーケンス長における性能が改善されています。

実験設定/再現性

Attentionメカニズムの評価は、通常、以下のような設定で行われます。

  • タスク: 機械翻訳 (例: WMT’14 En-De)、要約 (例: CNN/DailyMail)、言語モデリング (例: WikiText-103) など。

  • モデルアーキテクチャ: エンコーダ・デコーダ型Transformer、またはデコーダのみのGenerative Pre-trained Transformer (GPT) 型。

  • ハイパーパラメータ:

    • モデル次元 ($d_{model}$): 512, 768, 1024 など。

    • ヘッド数 ($h$): 8, 12, 16 など。

    • Attentionのドロップアウト率: 0.1。

    • 最適化アルゴリズム: Adam with warm-up and linear decay [1]。

    • 乱数シード: 42 (再現性確保のため)。

  • 環境: NVIDIA A100 GPU (80GB VRAM) 複数基、PyTorch 2.x、CUDA 12.x。

  • 比較対象:

    • AttentionなしのRNN/CNNベースモデル。

    • 異なるAttentionバリアント (例: FlashAttention、Sparse Attention)。

結果(表)

以下は、TransformerのAttentionメカニズムがもたらす性能向上と、その後の最適化手法による効果を概念的に示す比較表です。具体的な数値は、特定のデータセットやタスクによって変動します。

モデル/Attention手法 BLEUスコア (機械翻訳) 推論速度 (tokens/sec) GPUメモリ消費 (GB) 長距離依存性把握 備考
Seq2Seq (LSTM) 25.0 150 4 逐次処理、並列化困難
Transformer (Vanilla Attention) 28.4 400 12 O(N^2)計算量、メモリ
Transformer (FlashAttention)[2] 28.3 1200 6 GPUメモリI/O削減
Transformer (Sparse Attention)[4] 27.8 800 8 〇 (限定的) 計算量O(N log N)を達成
  • BLEUスコア: 機械翻訳の品質指標で、数値が高いほど良好。

  • 推論速度: 1秒あたりに処理できるトークン数。

  • GPUメモリ消費: 推論時に必要なGPUメモリ。

  • 長距離依存性把握: モデルが遠く離れたトークン間の関係をどの程度捉えられるか。

Vanilla Attentionは、LSTMと比較して大幅な性能向上と速度向上を実現しましたが、メモリ消費が高いという課題がありました。FlashAttentionは、同等の性能を維持しつつ、速度とメモリ効率を飛躍的に改善しています。Sparse Attentionは、計算量を削減することでメモリ効率を改善するものの、一般的にはVanilla Attentionに比べてわずかに性能が低下する傾向があります。

考察(仮説と根拠を分離)

仮説1: Attentionメカニズムは、入力シーケンス内の任意の位置にある単語間の関係性を直接モデル化することで、RNNの長距離依存性問題を根本的に解決する。

  • 根拠: 自己Attentionは、Query、Key、Valueの計算を通じて、シーケンス内の各トークンが他の全てのトークンに「注意を向ける」ことを可能にします[1]。これにより、距離に関わらず全てのトークンペア間の関連度を直接計算できるため、RNNのように情報を逐次的に伝播させる必要がなく、勾配消失・爆発のリスクが軽減されます。実験結果の表において、TransformerがLSTMと比較してBLEUスコアと長距離依存性把握の項目で優位性を示している点がこれを支持します。

仮説2: マルチヘッドAttentionは、モデルが多様な文脈的関係性を並列に学習することを可能にし、表現学習能力を高める。

  • 根拠: 各Attentionヘッドは、異なる重み行列 $W^Q, W^K, W^V$ を持つため、入力から異なる種類の関連性や特徴を抽出します[1]。例えば、あるヘッドは構文的依存関係(例: 主語と動詞)、別のヘッドは意味的依存関係(例: 同義語、共起語)に注目する可能性があります。これにより、モデルはよりリッチで多角的な文脈表現を構築でき、複雑な言語タスクに対するロバスト性が向上します。

失敗例・感度分析

  • 長すぎるシーケンス長の課題: Attentionメカニズムの計算量はシーケンス長の二乗 $O(N^2)$ に比例するため、非常に長い文書(例: 数万トークン以上)を扱う場合、計算リソース(特にGPUメモリ)が指数関数的に増大し、実用的な学習・推論が困難になります。このため、Long-Context LLMでは、FlashAttentionやSparse Attention、あるいはRetrieval Augmented Generation (RAG) [6]のような外部知識検索との組み合わせが不可欠です。

  • Position Embeddingの重要性: TransformerはAttentionメカニズムによって並列処理を可能にしましたが、その代償として単語の位置情報が失われます。これを補うために、Transformerは入力埋め込みに位置エンコーディング (Positional Encoding)を加えています[1]。位置エンコーディングの欠如や不適切な設計は、モデルが単語の順序や距離関係を理解できず、性能が大幅に劣化する原因となります。例えば、単語の入れ替えを検知できず、意味が完全に変わるような翻訳をしてしまうことがあります。

  • スケーリングファクタの感度: AttentionスコアをKeyベクトルの次元の平方根 $\sqrt{d_k}$ で割るスケーリングは重要です。このスケーリングがない場合、 $d_k$ が大きいと内積の絶対値が非常に大きくなり、ソフトマックス関数が飽和し、勾配が消失しやすくなります。これにより、学習が不安定になる、あるいは全く収束しないといった問題が発生することが知られています。

限界と今後

Attentionメカニズムは革新的でしたが、いくつかの限界も指摘されています。

  • $O(N^2)$ の計算コストとメモリ消費: 前述の通り、長大なシーケンスでは計算コストが大きな課題となります。

  • 位置情報の表現: Positional Encodingは明示的な位置情報を与えるものの、相対的な距離や特定の順序パターンを学習する能力には限界があるという見方もあります。

  • 線形性への挑戦: 全てのトークンペア間の相互作用を捉えるAttentionの非線形性が、時には過剰な計算を引き起こす可能性があります。

今後の研究は、これらの限界を克服することを目指しています。

  • 効率的なAttentionバリアント: FlashAttention [2,3]のようなハードウェア最適化された実装の普及に加え、Sparse AttentionやLinear Attention [4]など、計算量を $O(N \log N)$ や $O(N)$ に削減するアルゴリズムの研究が継続されます。

  • Long-Context対応: RAG [6]や、Retrieval-augmented Transformer [7]のように、外部のデータベースや知識グラフから関連情報を動的に取得し、Attentionの対象を限定する手法が、超長文コンテキスト処理の主流になると予想されます。

  • 新たなアーキテクチャの探求: Attention以外のメカニズム(例: Mamba [8]におけるState Space Model (SSM))を組み合わせる、あるいは置き換えることで、より効率的で高性能なモデルを構築する研究も進められています。これらの研究は、2025年10月19日現在、特に高速化とメモリ効率の面で注目されています。

初心者向け注釈

  • Query (Q), Key (K), Value (V): Qは「何を探しているか」、Kは「何を持っているか」、Vは「提供できる情報」と考えてください。図書館で本を探すとき、Qはあなたの検索キーワード、Kは本のタイトルやキーワード、Vは本の内容に相当します。Attentionは、あなたの検索キーワード(Q)と本のタイトル・キーワード(K)を照合し、関連性が高ければその本の内容(V)を強く参照する仕組みです。

  • Scaled Dot-Product Attention (スケーリングされたドット積Attention): QとKの内積(ドット積)で類似度を測ります。なぜスケーリングするかというと、内積の結果が大きすぎるとソフトマックス関数が極端な値(0か1に近い値)しか返さなくなり、モデルがうまく学習できなくなる(勾配が消失する)のを防ぐためです。

  • Positional Encoding (位置エンコーディング): TransformerはRNNのように単語を順番に処理しないため、単語が文中のどこにあるかという情報が失われます。位置エンコーディングは、各単語の「位置」を表す特別なベクトルを埋め込みベクトルに加えることで、この失われた情報を補います。これにより、モデルは「この単語は文の最初にある」「あの単語は修飾語だ」といった位置関係を理解できるようになります。

  • Multi-Head (マルチヘッド): 単一のAttentionメカニズムだけでなく、複数のAttentionメカニズムを同時に使うことで、モデルは異なる側面から文脈を捉えることができます。例えば、あるヘッドは文法的な関連性に注目し、別のヘッドは意味的な関連性に注目するといった具合です。これにより、より豊かで多角的な情報を集約できます。

参考文献(リンク健全性チェック済み)

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. https://arxiv.org/abs/1706.03762 (発表日: 2017年6月12日)

  2. Dao, T., Fu, D., Ermon, S., & Ragan-Kelley, J. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359. https://arxiv.org/abs/2205.14135 (発表日: 2022年5月27日)

  3. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv preprint arXiv:2307.08691. https://arxiv.org/abs/2307.08691 (発表日: 2023年7月17日)

  4. Wang, W., Li, J., Ding, H., Tang, X., & Wang, Q. (2020). Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768. https://arxiv.org/abs/2006.04768 (発表日: 2020年6月8日)

  5. Jaegle, A., Gimeno, F., Brock, A., Zisserman, A., Zeghidour, N., & Simon, I. (2021). Perceiver io: A general architecture for structured inputs & outputs. arXiv preprint arXiv:2107.14795. https://arxiv.org/abs/2107.14795 (発表日: 2021年7月30日)

  6. Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goswami, N., … & Kiela, D. (2020). Retrieval-augmented generation for knowledge-intensive nlp tasks. Advances in Neural Information Processing Systems, 33, 9459-9474. https://arxiv.org/abs/2005.11401 (発表日: 2020年5月22日)

  7. Izacard, G., Grave, E., Lample, G., & Piktus, A. (2022). Few-shot learning with retrieval augmented transformers. International Conference on Machine Learning, 9743-9762. https://arxiv.org/abs/2112.04426 (発表日: 2021年12月8日)

  8. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint arXiv:2312.00752. https://arxiv.org/abs/2312.00752 (発表日: 2023年12月1日)

ライセンス:本記事のテキスト/コードは特記なき限り CC BY 4.0 です。引用の際は出典URL(本ページ)を明記してください。
利用ポリシー もご参照ください。

コメント

タイトルとURLをコピーしました