深層学習におけるattentionについて

機械学習技術 自然言語技術 人工知能技術 デジタルトランスフォーメーション技術 画像処理技術 強化学習技術 確率的生成モデル 深層学習技術 Python 本ブログのナビ
「Attention Is All You Need」について

Attention Is All You Need“という論文は、2017年にGoogleの研究者によって発表されたTransformerモデルの概要とアルゴリズム及び実装例について“でも述べているTransformerと呼ばれるニューラルネットワークモデルの提案を行った論文となる。この論文は、Transformerモデルの提案とその大幅な精度改善効果を示したことで、自然言語処理や深層学習を中心とした機械学習の領域にブレークスルーをもたらしたものとなっている。またこれらを用いたOSSは”Huggingfaceを使った文自動生成の概要“に述べているHuggingfaceに集められ、それらを用いたchatGPTもどきの開発が行われている。

この論文に関しては”【論文】”Attention is all you need”の解説“や、”[論文解説] Attention Is All You Needを解説する①“などで詳細な解説が論じられているが、ここではなるべく数式等を使わない範囲での解説を述べたいと思う。

以下にその要約を述べる。

「従来のニューラルネットワークモデルは、RNNの概要とアルゴリズム及び実装例について“で述べている再帰的な構造(RNN)やCNNの概要とアルゴリズム及び実装例について“で述べている畳み込み層(CNN)を使用していたが、長いシーケンスや大規模なデータに対して効率的でなかった。この論文では、Transformerという新しいネットワークアーキテクチャを提案し、従来のモデルに比べて高速で並列処理が可能な構造を実現しましたことについて述べている。

Transformerモデルは、Attentionメカニズムを中心としたアーキテクチャで構成されている。Attentionは、入力の重要な部分に適切に注目することで、情報を抽出するメカニズムとなる。このこの論文では、Self-Attentionと呼ばれる特殊なAttentionのバリエーションを提案している。Self-Attentionは、入力の各位置と他の全ての位置の関連性を計算することで、長い文や系列データの処理に対する重要な情報を抽出している。これにより、それらのデータの処理が効率的に行えるようになる。

Transformerモデルは、エンコーダとデコーダの2つの主要な部分から構成されている。エンコーダは、入力データの特徴表現を抽出する役割を担い、デコーダはエンコーダの表現と過去の出力を元に、次の出力を生成する。このモデルでは、入力と出力の間でAttentionメカニズムを使用して、重要な情報を適切に注目しながら処理を行っている。

Transformerモデルの特徴は、以下のようにまとめられる。

  • 長いシーケンスや大規模なデータの処理において効率的な手法となる
  • RNNやCNNと比べて並列処理が可能であり、高速な学習や推論ができる
  • Self-Attentionを活用して入力の重要な情報を抽出する
  • 転移学習の概要とアルゴリズムおよび実装例について“でも述べている転移学習や大規模なデータセットを利用した学習において優れた性能を示す」

以下にこれらの技術の詳細について述べる。まず、Attentionメカニズムとそのバリエーションの一つであるSelf-Attentionについて述べる。

Attentionメカニズムの概要

深層学習におけるAttentionは、ニューラルネットワークの一部として使用される重要な概念となる。このAttentionメカニズムは、入力の異なる部分に異なる重要度を割り当てることができるモデルの能力を指し、このメカニズムの適用が、自然言語処理や画像認識などのタスクにおいて特に有用であることが近年認識されている。

Attentionのメカニズムについて述べる前に、従来の深層学習の課題について考える。従来のAttentionが無い深層学習モデルでは、固定サイズの入力を受け取り、その入力に基づいて処理を行うことが一般的であり、入力の長さが可変である場合や、入力の一部が他の部分よりも重要である場合に困難が生じる。このような課題は、対話形式の自然文入出力のケースや、機械翻訳における単語の対応付けや、画像キャプショニングにおける画像の特定の領域への注目など、さまざまなタスクで現れる。

Attentionメカニズムは、通常、エンコーダとデコーダの2つの主要な部分から構成される「エンコーダ-デコーダ」モデルに組み込まれた構成となっており、このモデルで、エンコーダは、入力データを固定次元の表現に変換する役割を果たし、デコーダはその表現を使用して出力を生成する役割を担っている。Attentionメカニズムは、エンコーダの中間表現とデコーダの状態との関連性を計算し、エンコーダの各部分がデコーダの現在の状態にどれだけ関連しているかを示す(重みを得る)ために使用される。

Attentionメカニズムはこの注意するポイントを抽出するという機能により、モデルは、入力の一部に集中することができ、長い入力シーケンスや複雑な関係を持つデータに対しても効果的な処理が可能となり、既存の深層学習の課題に対応することができるようになる。

Attentionメカニズムの具体的なステップは以下のようになる。

  1. 関連性の計算:エンコーダの中間表現とデコーダの状態の間の関連性を計算する。関連性の計算には、主に内積や類似度関数(例:ドット積、コサイン類似度)が使用される。
  2. 重みの計算:関連性を正規化して重みを計算する。ここでは通常、”ソフトマックス関数の概要と関連アルゴリズム及び実装例について“で述べられているソフトマックス関数が使用される。ソフトマックス関数は、関連性の値を確率として解釈するために使用され、重みとして解釈できるようになる。
  3. 加重平均の計算:エンコーダの中間表現と重みの加重平均を計算する。重みの値が高いエンコーダの部分にはより注意を払いながら、適切な情報を抽出する。
  4. 出力の生成:加重平均を使用してデコーダは次の出力を生成する。これにより、重要な情報により注意を払いながら、適切な情報を出力に反映させることができるようになる。

以下それらの詳細について述べる。

Attentionメカニズムにおける関連性の計算

Attentionメカニズムにおける関連性の計算方法は、一般的には内積や類似度関数を使用して行われる。具体的な手順は以下のようになる。

  1. エンコーダの中間表現とデコーダの状態を取得する。エンコーダは通常、系列データや画像の特徴表現などを入力として受け取り、中間表現を生成する。デコーダの状態は、デコーダの現在の状態や出力履歴などの情報となる。
  2. 関連性を計算するために、エンコーダの中間表現とデコーダの状態の間で内積や類似度関数を計算する。内積は、中間表現と状態の要素ごとの積を取り、それらの総和を計算することで実現される。類似度関数は、中間表現と状態の間の類似性を計算するために使用される。代表的な類似度関数には、ドット積やコサイン類似度などがある。
  3. 内積や類似度関数の結果を正規化するために、ソフトマックス関数を適用する。ソフトマックス関数により、関連性の値が確率として解釈できるようになる。ソフトマックス関数は、関連性の値を正規化し、合計値が1になるようにしている。
  4. ソフトマックス関数を適用した後、各エンコーダの中間表現には重みが割り当てられる。重みは、デコーダが注目すべきエンコーダの部分を示す。重みが大きいほど、そのエンコーダの部分により注意が向けられることを意味している。

関連性の計算方法はタスクやモデルの設計によって異なる場合があり、例えば、Transformerモデルでは、エンコーダとデコーダの間で行列の積(内積)を計算し、ソフトマックス関数を適用して関連性を得ており、この関連性は、エンコーダの中間表現をデコーダの位置に対して重み付けするために使用されているが、その他にもAttentionメカニズムを用いて異なる範囲の関連性を計算するGlobal/Local Attention方式、文脈情報を考慮しながら関連性を計算するコンテキストAttention方式、キーと値のペアに基づいて関連性を計算するキーバリューアテンション方式などがある

Attentionメカニズムにおける重みの計算方法

Attentionメカニズムにおける重みの計算方法は、一般的にソフトマックス関数を使用して行われる。以下に、重みの計算手順について述べる。

  • 関連性の計算:エンコーダの中間表現とデコーダの状態の間の関連性を計算する。この関連性は、内積や類似度関数(例:ドット積、コサイン類似度)を使用して計算される。
  • ソフトマックス関数の適用:関連性を正規化するために、ソフトマックス関数を適用する。ソフトマックス関数は、関連性の値を確率として解釈するために使用され、以下の式で表される。

Softmax

ここで、softmaxはソフトマックス関数の出力であり、x_iは関連性の値となる。また、ソフトマックス関数は、関連性の値を正規化し、合計値が1になるようにする。これにより、各エンコーダの中間表現には重みが割り当てられる。

  • 重みの利用:ソフトマックス関数によって得られた重みは、各エンコーダの中間表現に対する重要度を示すものとなる。これは、重みが大きいエンコーダの部分にはより注意が向けられることを意味し、これにより、デコーダは注目すべきエンコーダの部分から適切な情報を抽出することができるようになる。

重みの計算手順は一般的には上記のようになるが、実際のモデルやタスクによって微妙な変更が加えられることがしばしばある。例えば、一部のモデルでは、関連性に対して追加の非線形変換(例:多層パーセプトロン)が適用される場合があり、またそれら以外にも、Attentionメカニズムは、モデルの性能を向上させるためにさまざまな変種や改良が提案されている。

Attentionメカニズムにおける加重平均の計算

Attentionメカニズムにおける加重平均の計算は、重み付けされたエンコーダの中間表現の加重平均を計算することを指している。以下に、加重平均の計算手順について述べる。

  • エンコーダの中間表現と重みを取得する。エンコーダは通常、系列データや画像の特徴表現などを入力として受け取り、中間表現を生成する。重みは、Attentionメカニズムによって計算され、各エンコーダの中間表現に割り当てられる。
  • 各エンコーダの中間表現に対して、重みを乗算する。これにより、重み付けされた中間表現が得られる。
  • 重み付けされた中間表現の加重平均を計算する。通常、加重平均は次の式で計算される。

Weighted Average

ここで、weight_iはエンコーダの中間表現に割り当てられた重みであり、encoder%20representation_iは対応するエンコーダの中間表現となる。

この加重平均の計算により、重みが大きいエンコーダの部分により注意が向けられながら、適切な情報を抽出することができるようになり、。この加重平均を用いて、デコーダが次の出力を生成する。

なお、加重平均の上記の計算手順は一般的なものであり、実際のモデルやタスクによって微妙な変更が加えられることがある。また、Attentionメカニズムのバリエーションによっては、加重平均の代わりに異なる処理が行われることもある。

Attentionメカニズムにおける出力の生成

Attentionメカニズムにおける出力の生成は、加重平均されたエンコーダの中間表現を使用して行われる。以下に、出力の生成の手順について述べる。

  1. 加重平均されたエンコーダの中間表現を取得する。これは、Attentionメカニズムによって重み付けされたエンコーダの中間表現の加重平均となる。重みの大きいエンコーダの部分により注意が向けられており、適切な情報が集約されている。
  2. 取得した加重平均の中間表現を入力として、デコーダに渡す。デコーダは、加重平均の中間表現を使用して次の出力を生成する。
  3. デコーダは、加重平均の中間表現や自身の状態(過去の出力など)を元に、タスクに応じた適切な処理を行い、次の出力を生成します。デコーダは通常、再帰的な構造(RNNやTransformer)を持ち、前の出力や内部状態を参照しながら次の出力を予測する。
  4. デコーダは出力を生成するたびに、次の入力として自身の状態を更新し、次の出力を生成するための情報を蓄積する。これにより、系列データ(文章、音声、時系列データなど)の生成が可能となる。

出力の生成は、Attentionメカニズムを通じて注目すべき情報がデコーダに提供されるため、より適切な予測や生成が可能となる。また、Attentionメカニズムには、入力の重要な部分に適切に注目することで、長いシーケンスや複雑なデータの処理を効果的に行うことができるという特徴がある。

ただしこのメカニズムは、具体的な出力の生成手順はモデルやタスクによって異なる場合がある。たとえば、自然言語処理のタスクでは、デコーダが単語や文字の予測を行うことが一般的だが、画像生成のタスクでは、デコーダが画像のピクセル値を生成するなど、タスクに応じた適切な処理が行われている。

AttentionメカニズムのPythonでの実装

以下に具体的なPythonでのAttentionメカニズムの例を示す。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim):
        super(Attention, self).__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim

        self.query_layer = nn.Linear(query_dim, key_dim)
        self.key_layer = nn.Linear(key_dim, key_dim)
        self.value_layer = nn.Linear(key_dim, value_dim)

    def forward(self, query, keys, values):
        """
        query: shape (batch_size, query_dim)
        keys: shape (batch_size, seq_length, key_dim)
        values: shape (batch_size, seq_length, value_dim)
        """
        batch_size, seq_length, _ = keys.size()

        # Queryをkey_dim次元に変換
        query = self.query_layer(query)  # shape (batch_size, key_dim)

        # Keysをkey_dim次元に変換
        keys = self.key_layer(keys)  # shape (batch_size, seq_length, key_dim)

        # Valuesをvalue_dim次元に変換
        values = self.value_layer(values)  # shape (batch_size, seq_length, value_dim)

        # Attentionスコアを計算
        attention_scores = torch.matmul(query.unsqueeze(1), keys.transpose(1, 2))  # shape (batch_size, 1, seq_length)

        # Attentionスコアを正規化
        attention_scores = F.softmax(attention_scores, dim=2)  # shape (batch_size, 1, seq_length)

        # Weighted sumを計算
        attended_values = torch.matmul(attention_scores, values)  # shape (batch_size, 1, value_dim)

        # 結果をsqueezeしてreturn
        return attended_values.squeeze(1)  # shape (batch_size, value_dim)

上記の例では、Attentionという名前のクラスを定義しており、query_dimkey_dimvalue_dimは各次元のサイズを表している。forwardメソッドでは、与えられたquery(クエリ)、keys(キー)、およびvalues(値)に対してAttentionメカニズムを適用している。

forwardメソッドの入力と出力の形状はコメントで説明されており、入力として、queryは形状(バッチサイズ、query_dim)、keysは形状(バッチサイズ、シーケンスの長さ、key_dim)、valuesは形状(バッチサイズ、シーケンスの長さ、value_dim)を持つテンソルを想定している。

この実装では、querykeysの内積を計算してAttentionスコアを得ている。そして、Attentionスコアを正規化して重み付け和を計算し、最終的な出力を得るものとなる。

Attentionメカニズムのバリエーション

Attentionメカニズムは、さまざまなバリエーションが存在している。以下にいくつかの主要なバリエーションについて述べる。

  • 点積(Dot Product) Attention: エンコーダとデコーダの中間表現の内積を計算して関連性を求める。このバリエーションは、Transformerモデルで広く使用されている。
  • 加法(Additive)Attention: エンコーダとデコーダの中間表現を結合し、非線形変換(通常はニューラルネットワーク)を適用して関連性を計算している。このバリエーションは、構造化データや画像などのタスクに適している。
  • 乗法(Multiplicative)Attention: エンコーダとデコーダの中間表現を要素ごとに乗算して関連性を計算します。これは加法Attentionの代替手法として使用されることがある。
  • 自己注意(Self-Attention): 自己注意は、系列データ内の要素間の関連性を計算するために使用される。この方式では、エンコーダ内の要素同士の関連性を計算し、重要な要素を強調して情報を集約しており、Transformerモデルで広く使用され、自然言語処理タスクで高い成果を収めている。
  • マルチヘッド Attention: マルチヘッド Attention は、複数の Attention メカニズムを並列に使用し、異なる表現の重要な部分をキャプチャするものとなる。各ヘッドは異なる重み行列を持ち、異なる情報を抽出するために使用されている。

これらは一部の主要なバリエーションだが、Attentionメカニズムではさまざまな改良や応用が提案されている。例えば、自己注意を時間的にスケールすることで長い系列データの処理を効率化する方法や、位置エンコーディングを組み合わせる方法などがあり、また、タスクに応じてカスタマイズされたアーキテクチャや機構が提案されることもある。

次に「Attention Is All You Need」論文の主題であるSelf-Attentionについて詳しく述べる。

Self-Attention(自己注意)

Self-Attentionは、Attentionメカニズムの一種であり、入力の各位置と他の全ての位置の関連性を計算するメカニズムであり、自然言語処理タスク、特にTransformerモデルで広く使用されているものとなる。

従来のAttentionメカニズムでは、エンコーダとデコーダの中間表現の関連性を計算して重み付けを行っていた。一方、Self-Attentionでは、入力自体の中間表現の関連性を計算するものとなる。これにより、各位置が自身と他の位置の関連性を持つことが可能となる。

Self-Attentionの基本的な手順は以下のようになる。

  1. 入力として与えられたシーケンスを、ベクトルの集合として表現する。各要素は位置ベクトルと呼ばれ、元のシーケンスの各位置に対応する。
  2. 位置ベクトルに対して、3つの異なる線形変換(Wq、Wk、Wv)を適用し、クエリベクトル(Query)、キーベクトル(Key)、バリューベクトル(Value)を生成する。
  3. クエリベクトルとキーベクトルの内積を計算し、関連性スコアを得る。関連性スコアは、クエリベクトルとキーベクトルの類似度を表す。
  4. 関連性スコアをソフトマックス関数によって正規化し、重みを計算する。これにより、各位置の重要度が表現される。
  5. 重み付けされたバリューベクトルと重みを掛け合わせ、重み付き平均を計算する。この平均は、Self-Attentionによって抽出されたコンテキストベクトルとなる。

Self-Attentionは、シーケンス内の各要素が自身と他の要素との関連性を考慮しながら情報を抽出することができ、これにより、重要な要素や依存関係を自動的に学習し、入力表現の適切な抽象化を可能にしている。また、並列計算が可能なため、長いシーケンスや大規模なデータの処理に効率的でもある。

Self-Attentionのpythonによる実装

以下に、PythonでのSelf-Attentionの実装例について述べる。

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.query_layer = nn.Linear(input_dim, hidden_dim)
        self.key_layer = nn.Linear(input_dim, hidden_dim)
        self.value_layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, inputs):
        """
        inputs: shape (batch_size, seq_length, input_dim)
        """
        batch_size, seq_length, _ = inputs.size()

        # Queryをhidden_dim次元に変換
        queries = self.query_layer(inputs)  # shape (batch_size, seq_length, hidden_dim)

        # Keysをhidden_dim次元に変換
        keys = self.key_layer(inputs)  # shape (batch_size, seq_length, hidden_dim)

        # Valuesをhidden_dim次元に変換
        values = self.value_layer(inputs)  # shape (batch_size, seq_length, hidden_dim)

        # Attentionスコアを計算
        attention_scores = torch.matmul(queries, keys.transpose(1, 2))  # shape (batch_size, seq_length, seq_length)

        # Attentionスコアを正規化
        attention_probs = torch.softmax(attention_scores, dim=-1)  # shape (batch_size, seq_length, seq_length)

        # Weighted sumを計算
        attended_values = torch.matmul(attention_probs, values)  # shape (batch_size, seq_length, hidden_dim)

        # 結果をreturn
        return attended_values

上記の例では、SelfAttentionという名前のクラスを定義している。input_dimは入力の次元数、hidden_dimはSelf-Attentionの内部表現の次元数を示しており、forwardメソッドでは、入力シーケンスinputsに対してSelf-Attentionを適用している。

forwardメソッドの入力と出力の形状はコメントで説明されている。入力として、inputsは形状(バッチサイズ、シーケンスの長さ、input_dim)を持つテンソルを想定している。

この実装では、入力シーケンスの各要素に対してQuery、Key、Valueを計算し、Attentionスコアを得る。そして、Attentionスコアを正規化して重み付け和を計算し、最終的な出力を得ている。

以下にSelf-Attention以外のAttentionメカニズムの実装についても述べる。

点積Attention

点積Attention(Dot Product Attention)では、QueryとKeyの間の内積を計算してAttentionスコアを得るものとなる。以下に、Pythonでの実装例について述べる。

import torch
import torch.nn as nn

class DotProductAttention(nn.Module):
    def __init__(self):
        super(DotProductAttention, self).__init__()

    def forward(self, query, keys, values):
        """
        query: shape (batch_size, query_length, hidden_dim)
        keys: shape (batch_size, key_length, hidden_dim)
        values: shape (batch_size, key_length, value_dim)
        """
        batch_size, query_length, _ = query.size()
        key_length = keys.size(1)

        # QueryとKeyの内積を計算
        attention_scores = torch.matmul(query, keys.transpose(1, 2))  # shape (batch_size, query_length, key_length)

        # Attentionスコアを正規化
        attention_probs = torch.softmax(attention_scores, dim=-1)  # shape (batch_size, query_length, key_length)

        # Weighted sumを計算
        attended_values = torch.matmul(attention_probs, values)  # shape (batch_size, query_length, value_dim)

        # 結果をreturn
        return attended_values, attention_probs

上記の例では、DotProductAttentionという名前のクラスを定義しており、forwardメソッドでは、与えられたQuery、Keys、Valuesに対して点積Attentionを適用している。

forwardメソッドの入力と出力の形状はコメントで説明されている。入力として、queryは形状(バッチサイズ、クエリの長さ、隠れ次元)、keysは形状(バッチサイズ、キーの長さ、隠れ次元)、valuesは形状(バッチサイズ、キーの長さ、値の次元)を持つテンソルを想定している。

この実装では、QueryとKeyの内積を計算してAttentionスコアを得ている。そして、Attentionスコアを正規化して重み付け和を計算し、最終的な出力を得る。また、Attentionスコアも出力している。

加法(Additive)Attentionについて

加法Attention(Additive Attention)では、QueryとKeyの間の非線形な関数を使用してAttentionスコアを計算する。以下に、Pythonによる実装例を示す。

import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, hidden_dim):
        super(AdditiveAttention, self).__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.hidden_dim = hidden_dim

        self.query_layer = nn.Linear(query_dim, hidden_dim)
        self.key_layer = nn.Linear(key_dim, hidden_dim)
        self.energy_layer = nn.Linear(hidden_dim, 1)

    def forward(self, query, keys, values):
        """
        query: shape (batch_size, query_length, query_dim)
        keys: shape (batch_size, key_length, key_dim)
        values: shape (batch_size, key_length, value_dim)
        """
        batch_size, query_length, _ = query.size()
        key_length = keys.size(1)

        # QueryとKeyをhidden_dim次元に変換
        processed_query = self.query_layer(query)  # shape (batch_size, query_length, hidden_dim)
        processed_keys = self.key_layer(keys)  # shape (batch_size, key_length, hidden_dim)

        # Attentionスコアを計算
        energy = torch.tanh(processed_query + processed_keys)  # shape (batch_size, query_length, hidden_dim)
        attention_scores = self.energy_layer(energy).squeeze(-1)  # shape (batch_size, query_length, key_length)

        # Attentionスコアを正規化
        attention_probs = torch.softmax(attention_scores, dim=-1)  # shape (batch_size, query_length, key_length)

        # Weighted sumを計算
        attended_values = torch.matmul(attention_probs.unsqueeze(1), values).squeeze(1)  # shape (batch_size, query_length, value_dim)

        # 結果をreturn
        return attended_values, attention_probs

上記の例では、AdditiveAttentionという名前のクラスを定義しており、query_dimkey_dimhidden_dimは次元のサイズを表している。forwardメソッドでは、与えられたQuery、Keys、Valuesに対して加法Attentionを適用している。

forwardメソッドの入力と出力の形状はコメントで説明されている。入力として、queryは形状(バッチサイズ、クエリの長さ、クエリの次元)、keysは形状(バッチサイズ、キーの長さ、キーの次元)、valuesは形状(バッチサイズ、キーの長さ、値の次元)を持つテンソルを想定している。

この実装では、QueryとKeyを非線形な関数で変換し、それらを足し合わせたエネルギーを計算する。ここではエネルギーを1次元に変換してAttentionスコアを得る。そして、Attentionスコアを正規化して重み付け和を計算し、最終的な出力を得ており、Attentionスコアも出力している。

乗法(Multiplicative)Attentionについて

乗法Attention(Multiplicative Attention)では、QueryとKeyの間の内積を計算してAttentionスコアを得ている。以下に、Pythonによる実装例を示す。

import torch
import torch.nn as nn

class MultiplicativeAttention(nn.Module):
    def __init__(self, query_dim, key_dim):
        super(MultiplicativeAttention, self).__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim

        self.query_layer = nn.Linear(query_dim, key_dim)

    def forward(self, query, keys, values):
        """
        query: shape (batch_size, query_length, query_dim)
        keys: shape (batch_size, key_length, key_dim)
        values: shape (batch_size, key_length, value_dim)
        """
        batch_size, query_length, _ = query.size()
        key_length = keys.size(1)

        # QueryをKeyの次元に変換
        processed_query = self.query_layer(query)  # shape (batch_size, query_length, key_dim)

        # QueryとKeyの内積を計算
        attention_scores = torch.matmul(processed_query, keys.transpose(1, 2))  # shape (batch_size, query_length, key_length)

        # Attentionスコアを正規化
        attention_probs = torch.softmax(attention_scores, dim=-1)  # shape (batch_size, query_length, key_length)

        # Weighted sumを計算
        attended_values = torch.matmul(attention_probs, values)  # shape (batch_size, query_length, value_dim)

        # 結果をreturn
        return attended_values, attention_probs

上記の例では、MultiplicativeAttentionという名前のクラスを定義しており、query_dimkey_dimは次元のサイズを表している。forwardメソッドでは、与えられたQuery、Keys、Valuesに対して乗法Attentionを適用している。

forwardメソッドの入力と出力の形状はコメントで説明されている。入力として、queryは形状(バッチサイズ、クエリの長さ、クエリの次元)、keysは形状(バッチサイズ、キーの長さ、キーの次元)、valuesは形状(バッチサイズ、キーの長さ、値の次元)を持つテンソルを想定している。

この実装では、QueryをKeyの次元に変換し、QueryとKeyの内積を計算してAttentionスコアを得ている。そして、Attentionスコアを正規化して重み付け和を計算し、最終的な出力を得て、Attentionスコアも出力している。

マルチヘッド Attentionについて

マルチヘッドAttention(Multi-head Attention)は、複数の注意ヘッド(ヘッド数)を持つことで、より豊かな表現と柔軟性を提供するものとなる。以下に、Pythonによる実装例を示す。

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # ヘッドごとに異なる重み行列を定義
        self.query_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(num_heads)])
        self.key_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(num_heads)])
        self.value_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(num_heads)])

        # ヘッドの結果を結合するための線形変換層
        self.linear = nn.Linear(hidden_dim * num_heads, hidden_dim)

    def forward(self, inputs):
        """
        inputs: shape (batch_size, seq_length, input_dim)
        """
        batch_size, seq_length, _ = inputs.size()

        # ヘッドごとにQuery, Key, Valueを計算
        queries = [query_layer(inputs) for query_layer in self.query_layers]  # List of tensors with shape (batch_size, seq_length, hidden_dim)
        keys = [key_layer(inputs) for key_layer in self.key_layers]  # List of tensors with shape (batch_size, seq_length, hidden_dim)
        values = [value_layer(inputs) for value_layer in self.value_layers]  # List of tensors with shape (batch_size, seq_length, hidden_dim)

        # ヘッドごとにAttentionスコアを計算
        attention_scores = [torch.matmul(queries[i], keys[i].transpose(1, 2)) for i in range(self.num_heads)]  # List of tensors with shape (batch_size, seq_length, seq_length)

        # ヘッドごとにAttentionスコアを正規化
        attention_probs = [torch.softmax(score, dim=-1) for score in attention_scores]  # List of tensors with shape (batch_size, seq_length, seq_length)

        # ヘッドごとにWeighted sumを計算
        attended_values = [torch.matmul(attention_probs[i], values[i]) for i in range(self.num_heads)]  # List of tensors with shape (batch_size, seq_length, hidden_dim)

        # ヘッドの結果を結合し、線形変換を適用
        concatenated_values = torch.cat(attended_values, dim=-1)  # shape (batch_size, seq_length, hidden_dim * num_heads)
        output = self.linear(concatenated_values)  # shape (batch_size, seq_length, hidden_dim)

        # 結果をreturn
        return output

上記の例では、MultiHeadAttentionという名前のクラスを定義しており、input_dimは入力の次元数、hidden_dimは各ヘッドの隠れ層の次元数、num_headsはヘッドの数を表している。

forwardメソッドでは、与えられた入力に対してマルチヘッドAttentionを適用している。入力inputsは形状(バッチサイズ、シーケンスの長さ、入力の次元)のテンソルを想定している。

マルチヘッドAttentionでは、各ヘッドごとにQuery、Key、Valueを計算し、Attentionスコアを得ている。それぞれのヘッドに対して内積や正規化、Weighted sumを行い、結果を結合して最終的な出力を得る。上記の実装では位置エンコーディングや残差接続などの要素は省略しているが、これらを追加することでTransformerモデルとして完全に機能するようにすることも可能となる。

Softmax Attentionについて

Softmax Attentionは、エンコーダの中間表現とデコーダの状態の関連性を計算し、重みを決定するためにソフトマックス関数を使用するものとなる。以下に、Softmax Attentionのアルゴリズムのステップを示す。

  1. エンコーダの中間表現を入力として受け取り、デコーダの現在の状態との関連性を計算する。関連性は、通常、内積や類似度関数(例えば、ドット積やコサイン類似度)を使用して計算される。
  2. 得られた関連性を正規化するために、ソフトマックス関数を適用する。ソフトマックス関数は、関連性の値を確率として解釈するために使用される。これにより、関連性の合計が1になり、重みとして解釈できるようになる。
  3. ソフトマックス関数を適用した後、各エンコーダの中間表現に対して重みが割り当てられる。重みは、デコーダが注目すべきエンコーダの部分を示す。
  4. デコーダは、エンコーダの中間表現と重みの加重平均を計算し、この加重平均を使用して次の出力を生成する。重みが大きいエンコーダの部分にはより注意を払いながら、適切な情報を抽出して出力に反映させることができる。

Softmax Attentionは、入力の異なる部分への重要度を連続値の確率として表現するため、注目する部分を柔軟に選択できるメカニズムとして広く使用されている。ただし、長いシーケンスや大規模な入力データの場合、計算コストが増大する場合があるため、注意が必要となる。

参考情報と参考図書

機械学習による自動生成に関しては”機械学習による自動生成“に詳細を述べている。そちらも参照のこと。

参考図書としては”機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発”

Transformerによる自然言語処理”

Vision Transformer入門 Computer Vision Library”等がある。

コメント

  1. […] 一つの側面として強化学習がある。chatGPTのベースとなっているGPTの肝は”深層学習におけるattentionについて“で述べたattentionをベースとしたtransformerと、強化学習による深層学習モ […]

  2. […] GNNにおいてノード間の重要度を学習するために使用されるものとなる。”深層学習におけるattentionについて“でも述べているアテンションメカニズムは、近年最もちゅうもくされて […]

  3. […] べているResNet、”EfficientNetについて“で述べているEfficientNet、また”深層学習におけるattentionについて“で述べているAttention Mechanismなどの新しいアイデアも導入されている。 […]

  4. […] “深層学習におけるattentionについて“で述べているAttention Mechanismは、特定の時間ステップで重要な情報に焦点を当てるために使用され、これは、長いシーケンスの中での重要な情 […]

  5. […]  BERTは”Transformerモデルの概要とアルゴリズム及び実装例について“でも述べているTransformerと呼ばれるアーキテクチャを基にしている。Transformerは”深層学習におけるattentionについて“でも述べている自己注意機構(Self-Attention)を使用して文脈情報を効率的に捉えることができ、BERTはその恩恵を受けている。 […]

  6. […] アイデア: GATは、”深層学習におけるattentionについて“でも述べている注意機構を用いてノードの表現を学習し、ノード間の関係性に重みをつけ、より重要なノードに注目するア […]

  7. […] 採用することで、ノードの表現力を向上させることができる。例えば、”深層学習におけるattentionについて“で述べている注意機構(Attention Mechanism)を用いて、重要な隣接ノード […]

  8. […] Transformerの中核となる要素は、”深層学習におけるattentionについて“でも述べている自己注意機構(Self-Attention)となる。この機構は、入力シーケンス内の要素間の関連性を計算するために使用され、それにより、シーケンス内の異なる要素への重要度を学習し、文脈を理解するのに役立つものとなっている。 […]

  9. […] 2. 注意機構(Attention Mechanism): OpenNMTは、”深層学習におけるattentionについて“でも述べている注意機構をサポートしている。注意機構は、デコーダが各入力トークンに対して適切な重みを割り当て、エンコーダの出力に基づいて出力トークンを生成するための仕組みとなる。これにより、長い文や複雑な構造を持つ文の翻訳が改善される。 […]

  10. […] chatGPTで有名なOpenAIのもう一つの側面として強化学習がある。chatGPTのベースとなっている“GPTの概要とアルゴリズム及び実装例について“で述べているGPTの肝は”深層学習におけるattentionについて“で述べたattentionをベースとした”Transformerモデルの概要とアルゴリズム及び実装例について“でも述べているtransformerと、強化学習による深層学習モデルの改善にあると言われている。深層学習と聞くと、AlphaGoに代表されるゲームへの適用か、車の自動運転への適用がすぐイメージされるが、今回は強化学習に対してもう少し深掘りした検討を行う。 […]

タイトルとURLをコピーしました