Inside of LOVOT

GROOVE X 技術ブログ

Transformer を ONNX 形式に変換するのに苦労した話

この記事は、GROOVE Xアドベントカレンダー2024 の10日目の記事です。

はじめに

こんにちは!ふるまいチームのソフトウェアエンジニア、橋本です。
主に意思決定エンジンの開発を行っており、その中で機械学習モデルに触れる機会もあります。
今回はその機械学習まわりの話をご紹介します。

※ふるまいチームについては、少し前の記事になりますが「LOVOTのふるまいづくり - Inside of LOVOT」をご覧ください。

経緯

PyTorch で学習した Transformer 風のモデルを LOVOT 上で動かしたい!
→ LOVOT の SoC (Jetson Orin) に最適化するには別の形式に変換した方が良さそう
→ まずは ONNX という形式のファイルに変換する必要あり

というわけで、ONNX の読み方も知らなかったエンジニアが、PyTorch モデルの ONNX 変換に取り組むことになりました。
その過程で四苦八苦することになったわけですが、おかげで得られた知見もあったので共有できればと思います。

ONNX とは

「オニキス」と読みます。Open Neural Network Exchange の略です。
cf. https://onnx.ai/

機械学習モデルを表現するために使用されるオープンソースのフォーマットであり、乱立する機械学習フレームワークの橋渡し的な存在です。
たとえば私のケースでは PyTorch で学習したモデルを TensorRT という推論エンジンで動かしたかったのですが、そのためにまずは ONNX 形式に変換する必要がありました。

PyTorch → ONNX 変換の基本

基本的に PyTorch モデルを ONNX 変換するのは難しいことではありません。
公式ドキュメントにはこんな例が載っています。(一部改変しています)

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 128, 5)

    def forward(self, x):
        return torch.relu(self.conv1(x))

input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
    model,                  # model to export
    (input_tensor,),        # inputs of the model,
    "my_model.onnx",        # filename of the ONNX model
    input_names=["input"],  # Rename inputs for the ONNX model
)

本質的に重要なのは torch.onnx.export の所だけですね。とても簡単。

Transformer の場合

そのまま変換

具体的なコードを見ながら進めましょう。
Transformer と銘打ちましたが、かなり単純化するため (Causal) Self-Attention 周辺のみ、それも single head にしています。

class SimpleTransformer(nn.Module):
    def __init__(self, hidden_dim, max_length=128):
        super().__init__()
        self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.register_buffer(
            "causal_mask", torch.tril(torch.ones(max_length, max_length))
        )

    def forward(self, x):
        # (batch, length, hidden_dim) x 3
        q, k, v = self.linear_qkv(x).chunk(3, dim=-1)
        mask = self.causal_mask[: x.size(1), : x.size(1)]
        output = self.single_head_attention(q, k, v, mask)
        return output, k, v

    def single_head_attention(self, q, k, v, mask=None):
        # l: length (query)
        # m: length (key, value)
        # d: hidden_dim
        score = torch.einsum("... l d, ... m d -> ... l m", q, k)
        if mask is not None:
            score.masked_fill_(torch.logical_not(mask), float("-inf"))
        scale = q.size(-1) ** -0.5
        normalized_score = torch.softmax(score * scale, dim=-1)
        return torch.einsum("... l m, ... m d -> ... l d", normalized_score, v)
        

model = SimpleTransformer(hidden_dim=128)
# 学習は省略
torch.onnx.export(
    model,
    (torch.randn(1, 1, HIDDEN_DIM),),  # x (batch, length, hidden_dim)
    "simple_transformer.onnx",
    input_names=["x"],
    # x は系列長 (axis=1) が可変
    dynamic_axes={"x": {1: "length"}},
)

モデルは先ほどと違いますが、変換方法はほぼ同じですね。
可変長の系列を受け取れるようにするため dynamic_axes を指定していますが、難しいことはありません。
本題はここからです!

推論の効率化

上述のモデルは並列計算が可能な学習には適していますが、再帰的な推論に用いると計算効率が悪いです。
そのことを見るために、まずは Causal Self-Attention の計算内容をおさらいしましょう。

Causal Self-Attention

この計算を再帰的に行う場合、すべてのステップに対して毎回 query, key, value および出力を計算することになりますが、過去のステップに対して計算するのはムダです。
そこで、過去のステップの key, value を保存しておくことで効率化します。
このテクニックは KV Cache と呼ばれ、LLM の推論に広く用いられています。 (ちなみに Causal でない Self-Attention では KV Cache が使えませんが、少なくとも LLM では Causal なモデルが圧倒的な多数派です。)

Causal Self-Attention with KV Cache

先ほどの図と揃えるために Causal Mask を残していますが、すべて 1 のマスクなので削除しても同じです。

KV Cache を使った計算を ONNX に変換するために、先ほどとは別のモデルを用意します。
ただし、パラメーターは共有する必要があります。
こんな感じで実現できます。

class SimpleTransformerWithCache(nn.Module):
    def __init__(self, train_model, max_length=128):
        super().__init__()
        self._train_model = train_model

    def forward(self, x, key_cache, value_cache):
        # (batch, length=1, hidden_dim) x 3
        q, k, v = self._train_model.linear_qkv(x).chunk(3, dim=-1)
        # key, value はキャッシュと連結
        # (batch, cache_length + length, hidden_dim)
        k_merged = torch.cat([key_cache, k], dim=1)
        v_merged = torch.cat([value_cache, v], dim=1)
        # (batch, length, hidden_dim)
        output = self._train_model.single_head_attention(q, k_merged, v_merged)
        return output, k_merged, v_merged

ポイント

  • 学習モデルのパラメーターを使って推論
  • KV Cache を利用
  • KV Cache を更新するために key, value も返す

ONNX 変換は以下のようにできます。

torch.onnx.export(
    model_with_cache,
    (
        # x (batch, length=1, hidden_dim)
        torch.randn(1, 1, HIDDEN_DIM),
        # key_cache (batch, cache_length, hidden_dim)
        torch.randn(1, 1, HIDDEN_DIM),
        # value_cache (batch, cache_length, hidden_dim)
        torch.randn(1, 1, HIDDEN_DIM),
    ),
    "simple_transformer_with_cache.onnx",
    input_names=["x", "key_cache", "value_cache"],
    dynamic_axes={
        "key_cache": {1: "cache_length"},
        "value_cache": {1: "cache_length"},
    },
)

実際の推論

これで効率化もできました。
でもまだ終わりではありません!(もうちょっとお付き合いください)

LLM をイメージしてみましょう。
モデルはまずプロンプトを受け取り、そこから再帰的にトークンを生成していきます。
つまり推論には2つのフェーズがあるわけです。

  1. 入力系列を受け取って KV Cache を構築する(並列計算)
  2. KV Cache を利用してトークンを生成(再帰的計算)

1つ目は SimpleTransformer で実現できます。(厳密には key, value を返り値に追加するなど変更が必要です。詳細は最後のコードを参照してください。)
2つ目は SimpleTransformerWithCache で実現できます。

つまり、両方のモデルを ONNX 変換し、推論時には使い分けることになります。

モデルをまとめる

ここで終わっても良いのですが、実は上で作成した2つのモデルは1つにまとめられます。

KV Cache を使った再帰的計算では、KV Cache の系列長が可変である一方、入力の系列長は1に固定していました。
しかし1に固定する必要はなく、可変長でも問題なく計算できます。
その上で KV Cache の長さが0の場合を考えると、この計算は KV Cache を使わない並列計算に対応します。
図にするとこんな感じです。

Causal Self-Attention with empty KV Cache

つまり、SimpleTransformerWithCache について x の系列長を可変にするだけで、SimpleTransformer の計算を実現できるわけです。
そのため、SimpleTransformerWithCache だけを(dynamic_axesx の指定を加えて)ONNX 変換すれば良いということになります。
こちらも詳細は最後のコードを参照してください。

余談 - Whisper モデルのファイル名

Whisper という音声認識モデルがあります。
ここではその詳細は関係ないのですが、Optimum というライブラリで Whisper を ONNX 変換すると3つの decoder が作られます。

$ optimum-cli export onnx --model openai/whisper-tiny whisper_onnx
$ ls whisper_onnx/decoder_*.onnx
whisper_onnx/decoder_model.onnx  whisper_onnx/decoder_model_merged.onnx  whisper_onnx/decoder_with_past_model.onnx

この3つのモデルは、今回紹介した推論モデルに対応していると思われます。
(ちゃんと検証したわけではないですが...)

  • decoder_model.onnx: 可変長の入力を受け取って KV Cache を構築するモデル
  • decoder_with_past_model.onnx: KV Cache を使って1ステップの出力を計算するモデル
  • decoder_model_merged.onnx: 上記2つをまとめたモデル

前述の通り3つ目のモデルだけで事足りるので、最近は Optimum で最後のモデルだけ出力されることが多いようです。
GPT-2 などでも試してみましたが、ONNX ファイルは1つしか生成されませんでした。

さいごに

Transformer を ONNX 変換する過程で得られた知見をご紹介しました。
もちろん今回の変換方法がいつでも適用できるわけではなく、モデルが変わればやり方も変わってきます。
要するに、複雑なモデルの ONNX 変換は都度考える必要があって結構面倒ということですね。。

ただし、広く使われているモデルはこんな風に自前で変換する必要はありません。
前述の Optimum などでサクッと変換できます。
もし自作モデルを ONNX に変換したい時に、この記事が少しでも参考になれば幸いです。

最後に、モデルを変換して推論結果が変換前と一致することを確認するまでの一連のコードを貼っておきます。

コードまとめ

import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn


class SimpleTransformer(nn.Module):
    def __init__(self, hidden_dim, max_length=128):
        super().__init__()
        self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.register_buffer(
            "causal_mask", torch.tril(torch.ones(max_length, max_length))
        )

    def forward(self, x):
        # (batch, length, hidden_dim) x 3
        q, k, v = self.linear_qkv(x).chunk(3, dim=-1)
        mask = self.causal_mask[: x.size(1), : x.size(1)]
        output = self.single_head_attention(q, k, v, mask)
        return output, k, v

    def single_head_attention(self, q, k, v, mask=None):
        # l: length (query)
        # m: length (key, value)
        # d: hidden_dim
        score = torch.einsum("... l d, ... m d -> ... l m", q, k)
        if mask is not None:
            score.masked_fill_(torch.logical_not(mask), float("-inf"))
        scale = q.size(-1) ** -0.5
        normalized_score = torch.softmax(score * scale, dim=-1)
        return torch.einsum("... l m, ... m d -> ... l d", normalized_score, v)


class SimpleTransformerWithCache(nn.Module):
    def __init__(self, train_model, max_length=128):
        super().__init__()
        self._train_model = train_model
        self.register_buffer("full_mask", torch.ones(max_length, max_length))
        self.register_buffer(
            "causal_mask", torch.tril(torch.ones(max_length, max_length))
        )

    def forward(self, x, key_cache, value_cache):
        # (batch, length, hidden_dim) x 3
        q, k, v = self._train_model.linear_qkv(x).chunk(3, dim=-1)
        # key, value はキャッシュと連結
        # (batch, cache_length + length, hidden_dim)
        k_merged = torch.cat([key_cache, k], dim=1)
        v_merged = torch.cat([value_cache, v], dim=1)
        # (batch, length, cache_length + length)
        mask = torch.concat(
            [
                self.full_mask[: q.size(1), : key_cache.size(1)],
                self.causal_mask[: q.size(1), : q.size(1)],
            ],
            dim=1,
        )
        # (batch, length, hidden_dim)
        output = self._train_model.single_head_attention(q, k_merged, v_merged, mask)
        return output, k_merged, v_merged


@torch.no_grad()
def run_torch_model(model, x):
    return model(torch.tensor(x))[0].numpy()


def run_ort_model(ort_session, x_prefill, x_decode):
    outputs = []

    # 並列計算
    output, key_cache, value_cache = ort_session.run(
        None,
        dict(
            x=x_prefill,
            # KV Cache は空
            key_cache=np.zeros((1, 0, x_prefill.shape[-1]), dtype=np.float32),
            value_cache=np.zeros((1, 0, x_prefill.shape[-1]), dtype=np.float32),
        ),
    )
    outputs.append(output)

    # 再帰的計算
    for i in range(x_decode.shape[1]):
        output, key_cache, value_cache = ort_session.run(
            None,
            dict(
                x=x_decode[:, i : i + 1],
                key_cache=key_cache,
                value_cache=value_cache,
            ),
        )
        outputs.append(output)

    return np.concatenate(outputs, axis=1)


if __name__ == "__main__":
    HIDDEN_DIM = 128
    model = SimpleTransformer(HIDDEN_DIM)
    model_with_cache = SimpleTransformerWithCache(model)

    torch.onnx.export(
        model_with_cache,
        (
            # x (batch, length, hidden_dim)
            torch.randn(1, 1, HIDDEN_DIM),
            # key_cache (batch, cache_length, hidden_dim)
            torch.randn(1, 1, HIDDEN_DIM),
            # value_cache (batch, cache_length, hidden_dim)
            torch.randn(1, 1, HIDDEN_DIM),
        ),
        "simple_transformer_merged.onnx",
        input_names=["x", "key_cache", "value_cache"],
        dynamic_axes={
            "x": {1: "length"},
            "key_cache": {1: "cache_length"},
            "value_cache": {1: "cache_length"},
        },
    )

    SAMPLE_LENGTH = 10
    x = np.random.randn(1, SAMPLE_LENGTH, HIDDEN_DIM).astype(np.float32)

    output_torch = run_torch_model(model, x)

    output_ort = run_ort_model(
        ort_session=ort.InferenceSession("simple_transformer_merged.onnx"),
        # ここでは便宜上 x の前半を並列計算、後半を再帰的計算に使っている
        # 実際には x_decode は逐次的に得られるものであり、このように事前に得られるものではない
        x_prefill=x[:, : SAMPLE_LENGTH // 2],
        x_decode=x[:, SAMPLE_LENGTH // 2 :],
    )

    diff = np.abs(output_torch - output_ort)
    print(f"Max diff: {diff.max()}")

出力結果:Max diff: 1.1920928955078125e-07

一緒に働く仲間募集中

GROOVE Xでは、一緒に働く仲間を募集しています!
少しでも興味を持ってくださいましたら、下記のリンクをご覧ください。

recruit.jobcan.jp