この記事は、GROOVE Xアドベントカレンダー2024 の10日目の記事です。
はじめに
こんにちは!ふるまいチームのソフトウェアエンジニア、橋本です。
主に意思決定エンジンの開発を行っており、その中で機械学習モデルに触れる機会もあります。
今回はその機械学習まわりの話をご紹介します。
※ふるまいチームについては、少し前の記事になりますが「LOVOTのふるまいづくり - Inside of LOVOT 」をご覧ください。
経緯
PyTorch で学習した Transformer 風のモデルを LOVOT 上で動かしたい!
→ LOVOT の SoC (Jetson Orin) に最適化するには別の形式に変換した方が良さそう
→ まずは ONNX という形式のファイルに変換する必要あり
というわけで、ONNX の読み方も知らなかったエンジニアが、PyTorch モデルの ONNX 変換に取り組むことになりました。
その過程で四苦八苦することになったわけですが、おかげで得られた知見もあったので共有できればと思います。
ONNX とは
「オニキス」と読みます。Open Neural Network Exchange の略です。
cf. https://onnx.ai/
機械学習モデルを表現するために使用されるオープンソースのフォーマットであり、乱立する機械学習フレームワークの橋渡し的な存在です。
たとえば私のケースでは PyTorch で学習したモデルを TensorRT という推論エンジンで動かしたかったのですが、そのためにまずは ONNX 形式に変換する必要がありました。
PyTorch → ONNX 変換の基本
基本的に PyTorch モデルを ONNX 変換するのは難しいことではありません。
公式ドキュメント にはこんな例が載っています。(一部改変しています)
import torch
class MyModel (torch.nn.Module):
def __init__ (self):
super (MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1 , 128 , 5 )
def forward (self, x):
return torch.relu(self.conv1(x))
input_tensor = torch.rand((1 , 1 , 128 , 128 ), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
model,
(input_tensor,),
"my_model.onnx" ,
input_names=["input" ],
)
本質的に重要なのは torch.onnx.export
の所だけですね。とても簡単。
そのまま変換
具体的なコードを見ながら進めましょう。
Transformer と銘打ちましたが、かなり単純化するため (Causal) Self-Attention 周辺のみ、それも single head にしています。
class SimpleTransformer (nn.Module):
def __init__ (self, hidden_dim, max_length=128 ):
super ().__init__()
self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3 )
self.register_buffer(
"causal_mask" , torch.tril(torch.ones(max_length, max_length))
)
def forward (self, x):
q, k, v = self.linear_qkv(x).chunk(3 , dim=-1 )
mask = self.causal_mask[: x.size(1 ), : x.size(1 )]
output = self.single_head_attention(q, k, v, mask)
return output, k, v
def single_head_attention (self, q, k, v, mask=None ):
score = torch.einsum("... l d, ... m d -> ... l m" , q, k)
if mask is not None :
score.masked_fill_(torch.logical_not(mask), float ("-inf" ))
scale = q.size(-1 ) ** -0.5
normalized_score = torch.softmax(score * scale, dim=-1 )
return torch.einsum("... l m, ... m d -> ... l d" , normalized_score, v)
model = SimpleTransformer(hidden_dim=128 )
torch.onnx.export(
model,
(torch.randn(1 , 1 , HIDDEN_DIM),),
"simple_transformer.onnx" ,
input_names=["x" ],
dynamic_axes={"x" : {1 : "length" }},
)
モデルは先ほどと違いますが、変換方法はほぼ同じですね。
可変長の系列を受け取れるようにするため dynamic_axes
を指定していますが、難しいことはありません。
本題はここからです!
推論の効率化
上述のモデルは並列計算が可能な学習には適していますが、再帰的な推論に用いると計算効率が悪いです。
そのことを見るために、まずは Causal Self-Attention の計算内容をおさらいしましょう。
Causal Self-Attention
この計算を再帰的に行う場合、すべてのステップに対して毎回 query, key, value および出力を計算することになりますが、過去のステップに対して計算するのはムダです。
そこで、過去のステップの key, value を保存しておくことで効率化します。
このテクニックは KV Cache と呼ばれ、LLM の推論に広く用いられています。
(ちなみに Causal でない Self-Attention では KV Cache が使えませんが、少なくとも LLM では Causal なモデルが圧倒的な多数派です。)
Causal Self-Attention with KV Cache
先ほどの図と揃えるために Causal Mask を残していますが、すべて 1
のマスクなので削除しても同じです。
KV Cache を使った計算を ONNX に変換するために、先ほどとは別のモデルを用意します。
ただし、パラメーターは共有する必要があります。
こんな感じで実現できます。
class SimpleTransformerWithCache (nn.Module):
def __init__ (self, train_model, max_length=128 ):
super ().__init__()
self._train_model = train_model
def forward (self, x, key_cache, value_cache):
q, k, v = self._train_model.linear_qkv(x).chunk(3 , dim=-1 )
k_merged = torch.cat([key_cache, k], dim=1 )
v_merged = torch.cat([value_cache, v], dim=1 )
output = self._train_model.single_head_attention(q, k_merged, v_merged)
return output, k_merged, v_merged
ポイント
学習モデルのパラメーターを使って推論
KV Cache を利用
KV Cache を更新するために key, value も返す
ONNX 変換は以下のようにできます。
torch.onnx.export(
model_with_cache,
(
torch.randn(1 , 1 , HIDDEN_DIM),
torch.randn(1 , 1 , HIDDEN_DIM),
torch.randn(1 , 1 , HIDDEN_DIM),
),
"simple_transformer_with_cache.onnx" ,
input_names=["x" , "key_cache" , "value_cache" ],
dynamic_axes={
"key_cache" : {1 : "cache_length" },
"value_cache" : {1 : "cache_length" },
},
)
実際の推論
これで効率化もできました。
でもまだ終わりではありません!(もうちょっとお付き合いください)
LLM をイメージしてみましょう。
モデルはまずプロンプトを受け取り、そこから再帰的にトークンを生成していきます。
つまり推論には2つのフェーズがあるわけです。
入力系列を受け取って KV Cache を構築する(並列計算)
KV Cache を利用してトークンを生成(再帰的計算)
1つ目は SimpleTransformer
で実現できます。(厳密には key, value を返り値に追加するなど変更が必要です。詳細は最後のコードを参照してください。)
2つ目は SimpleTransformerWithCache
で実現できます。
つまり、両方のモデルを ONNX 変換し、推論時には使い分けることになります。
モデルをまとめる
ここで終わっても良いのですが、実は上で作成した2つのモデルは1つにまとめられます。
KV Cache を使った再帰的計算では、KV Cache の系列長が可変である一方、入力の系列長は1に固定していました。
しかし1に固定する必要はなく、可変長でも問題なく計算できます。
その上で KV Cache の長さが0の場合を考えると、この計算は KV Cache を使わない並列計算に対応します。
図にするとこんな感じです。
Causal Self-Attention with empty KV Cache
つまり、SimpleTransformerWithCache
について x
の系列長を可変にするだけで、SimpleTransformer
の計算を実現できるわけです。
そのため、SimpleTransformerWithCache
だけを(dynamic_axes
に x
の指定を加えて)ONNX 変換すれば良いということになります。
こちらも詳細は最後のコードを参照してください。
余談 - Whisper モデルのファイル名
Whisper という音声認識モデルがあります。
ここではその詳細は関係ないのですが、Optimum というライブラリで Whisper を ONNX 変換すると3つの decoder が作られます。
$ optimum-cli export onnx --model openai/whisper-tiny whisper_onnx
$ ls whisper_onnx/decoder_*.onnx
whisper_onnx/decoder_model.onnx whisper_onnx/decoder_model_merged.onnx whisper_onnx/decoder_with_past_model.onnx
この3つのモデルは、今回紹介した推論モデルに対応していると思われます。
(ちゃんと検証したわけではないですが...)
decoder_model.onnx: 可変長の入力を受け取って KV Cache を構築するモデル
decoder_with_past_model.onnx: KV Cache を使って1ステップの出力を計算するモデル
decoder_model_merged.onnx: 上記2つをまとめたモデル
前述の通り3つ目のモデルだけで事足りるので、最近は Optimum で最後のモデルだけ出力されることが多いようです。
GPT-2 などでも試してみましたが、ONNX ファイルは1つしか生成されませんでした。
さいごに
Transformer を ONNX 変換する過程で得られた知見をご紹介しました。
もちろん今回の変換方法がいつでも適用できるわけではなく、モデルが変わればやり方も変わってきます。
要するに、複雑なモデルの ONNX 変換は都度考える必要があって結構面倒ということですね。。
ただし、広く使われているモデルはこんな風に自前で変換する必要はありません。
前述の Optimum などでサクッと変換できます。
もし自作モデルを ONNX に変換したい時に、この記事が少しでも参考になれば幸いです。
最後に、モデルを変換して推論結果が変換前と一致することを確認するまでの一連のコードを貼っておきます。
コードまとめ
import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn
class SimpleTransformer (nn.Module):
def __init__ (self, hidden_dim, max_length=128 ):
super ().__init__()
self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3 )
self.register_buffer(
"causal_mask" , torch.tril(torch.ones(max_length, max_length))
)
def forward (self, x):
q, k, v = self.linear_qkv(x).chunk(3 , dim=-1 )
mask = self.causal_mask[: x.size(1 ), : x.size(1 )]
output = self.single_head_attention(q, k, v, mask)
return output, k, v
def single_head_attention (self, q, k, v, mask=None ):
score = torch.einsum("... l d, ... m d -> ... l m" , q, k)
if mask is not None :
score.masked_fill_(torch.logical_not(mask), float ("-inf" ))
scale = q.size(-1 ) ** -0.5
normalized_score = torch.softmax(score * scale, dim=-1 )
return torch.einsum("... l m, ... m d -> ... l d" , normalized_score, v)
class SimpleTransformerWithCache (nn.Module):
def __init__ (self, train_model, max_length=128 ):
super ().__init__()
self._train_model = train_model
self.register_buffer("full_mask" , torch.ones(max_length, max_length))
self.register_buffer(
"causal_mask" , torch.tril(torch.ones(max_length, max_length))
)
def forward (self, x, key_cache, value_cache):
q, k, v = self._train_model.linear_qkv(x).chunk(3 , dim=-1 )
k_merged = torch.cat([key_cache, k], dim=1 )
v_merged = torch.cat([value_cache, v], dim=1 )
mask = torch.concat(
[
self.full_mask[: q.size(1 ), : key_cache.size(1 )],
self.causal_mask[: q.size(1 ), : q.size(1 )],
],
dim=1 ,
)
output = self._train_model.single_head_attention(q, k_merged, v_merged, mask)
return output, k_merged, v_merged
@ torch.no_grad ()
def run_torch_model (model, x):
return model(torch.tensor(x))[0 ].numpy()
def run_ort_model (ort_session, x_prefill, x_decode):
outputs = []
output, key_cache, value_cache = ort_session.run(
None ,
dict (
x=x_prefill,
key_cache=np.zeros((1 , 0 , x_prefill.shape[-1 ]), dtype=np.float32),
value_cache=np.zeros((1 , 0 , x_prefill.shape[-1 ]), dtype=np.float32),
),
)
outputs.append(output)
for i in range (x_decode.shape[1 ]):
output, key_cache, value_cache = ort_session.run(
None ,
dict (
x=x_decode[:, i : i + 1 ],
key_cache=key_cache,
value_cache=value_cache,
),
)
outputs.append(output)
return np.concatenate(outputs, axis=1 )
if __name__ == "__main__" :
HIDDEN_DIM = 128
model = SimpleTransformer(HIDDEN_DIM)
model_with_cache = SimpleTransformerWithCache(model)
torch.onnx.export(
model_with_cache,
(
torch.randn(1 , 1 , HIDDEN_DIM),
torch.randn(1 , 1 , HIDDEN_DIM),
torch.randn(1 , 1 , HIDDEN_DIM),
),
"simple_transformer_merged.onnx" ,
input_names=["x" , "key_cache" , "value_cache" ],
dynamic_axes={
"x" : {1 : "length" },
"key_cache" : {1 : "cache_length" },
"value_cache" : {1 : "cache_length" },
},
)
SAMPLE_LENGTH = 10
x = np.random.randn(1 , SAMPLE_LENGTH, HIDDEN_DIM).astype(np.float32)
output_torch = run_torch_model(model, x)
output_ort = run_ort_model(
ort_session=ort.InferenceSession("simple_transformer_merged.onnx" ),
x_prefill=x[:, : SAMPLE_LENGTH // 2 ],
x_decode=x[:, SAMPLE_LENGTH // 2 :],
)
diff = np.abs(output_torch - output_ort)
print (f"Max diff: {diff.max()}" )
出力結果:Max diff: 1.1920928955078125e-07
一緒に働く仲間募集中
GROOVE Xでは、一緒に働く仲間を募集しています!
少しでも興味を持ってくださいましたら、下記のリンクをご覧ください。
recruit.jobcan.jp