生成AI活用の介護システムを作ろうとして、”動くのに使えない”からLLM高速化アルゴリズムを作った話

LLM高速化アルゴリズム

重みを上下に割った──VRAMを半分にするHigh/Low分割とCUDA地獄


②の続き、別の壁

②では非同期プレフェッチをやった。
多少は改善した。でも前提が崩れた。

帯域に頼る設計は、その帯域が消えた瞬間に破綻する。
ここで方針を変えた。

重みそのものを分割する。


① 全部載せると死ぬ(従来構造)

フル精度の重みとKVキャッシュをすべてVRAMに載せると、トークンが増えるにつれてメモリが爆発する。

フル精度重み+KVキャッシュでVRAMが飽和する

なぜ上位ビットと下位ビットを分けるのか

FP16は16bitでできている。

FP16(16bit)
 ├ 符号(1bit)
 ├ 指数部(5bit)  ← 数値の大きさ
 └ 仮数部(10bit) ← 精度

この構造は非対称だ。

上位側には符号と指数部が入る。つまり「この重みがだいたいどのくらいの大きさか」は上位8bitだけで決まる。下位8bitは仮数部の細かい精度だ。「だいたいの計算」には必要ない。精度の補正項に過ぎない。

だから分けられる。

High(上位8bit)→ VRAM常駐
Low(下位8bit) → DRAM待機
uint16_t raw = *reinterpret_cast<uint16_t*>(&w);
uint8_t high = raw >> 8;
uint8_t low  = raw & 0xFF;

② High/Low分割の構造

Highだけで計算を回し、Lowは必要なときだけ補完する。

HighはVRAM常駐、LowはDRAMから必要時のみ補完
Highだけで計算を回す(常時)
Lowは非同期で転送
必要なときだけ合成
Δ = logit_top1 - logit_top2
Δ < threshold → Lowを待つ

なぜEmbeddingとHeadは分けないのか

EmbeddingとHeadは基準点なので分割しない。

  • Embeddingがズレる → 入力全体が歪む
  • Headがズレる → 出力順位が壊れる

なぜこれが生成AIの問題なのか

入力 → モデル → 出力(1回)
トークン生成 → 入力に戻す → 繰り返し

KVキャッシュが増え続け、IOが支配する。

compute < IO

CUDA地獄:1週間

設計崩壊

PYBIND11_MODULE(...) // 複数定義 → 死亡

対策:1箇所のみ

リンク崩壊

ImportError: undefined symbol
  • .cu未リンク
  • ABI不一致
  • リンク順

ビルド成功 ≠ 動く

stubで突破

HIGHONLY_KERNEL=stub

まず動かす。


③ 圧縮+直前デコード(完成形)

圧縮したまま保持し、GEMM直前だけデコードする。

圧縮状態で保持し、GEMM直前のみデコード
4bit blockwise
+差分辞書
+融合デコード

ただし未実装。

なぜなら安価なDRAMに逃がそうとしたが、DRAM暴騰で死ぬことが見えたからだ。


この設計ができた理由

生成AIで試行回数を回した。

仮説 → 実装 → エラー → 次

知識ではなく試行回数


まとめ

  • KVで詰まる
  • Prefetchで逃げる
  • DRAMで死ぬ
  • 重み分割
  • CUDAで死ぬ
  • 動く
  • 遅い
  • また死ぬ

問題はメモリじゃなかった。

IOが支配している。


次はIOが支配的だと気づいた話。


コメント

タイトルとURLをコピーしました