5月頃にStable Diffusion3の論文が発表されていたので、眺めてみようと思います。
- Rectified Flowモデルの新規性
- Rectified Flowモデルのアーキテクチャ
- MM-DiTの処理
- Rectified Flowモデル(Stable Diffusion3)の学習方法
- Rectified Flowモデル(Stable Diffusion3)の性能評価
- 3種類のモデル
- 最後に
Rectified Flowモデルの新規性
見た感じ新規性は↓あたりかと思います。
- Text Encoderの改良
- ノイズスケジューラーを改良したこと
- 従来のLinearなスケジュールではなく、Logit-NormalやModeサンプリングなどの手法を使用
- 新しいトランスフォーマーアーキテクチャ(Diffusion-Transformer, DiT)を導入
- txt2imgにおいて、既存の最先端モデルを上回る性能を実証
Rectified Flowモデルのアーキテクチャ
簡単に整理すると、大きく変化した部分はUNetが廃止になってMM-DiTに置き換わった部分で、他についてはモデルやパラメータを増やして、性能向上を図った形と思います。
Text Encoder
従来のStable Diffusion
CLIPテキストエンコーダー(大規模な画像とテキストのペアで事前学習されたエンコーダー)を使用し、テキストプロンプトをエンコードして画像生成モデルに入力します。
Stable Diffusion3
CLIPとT5テキストエンコーダーの組み合わせを使用。CLIPのL/14モデルとOpenCLIPのbigG/14モデルに加え、T5-v1.1-XXLモデルを使用し、それぞれのエンコーダーから得られた出力を結合して使用します。
UNet
従来のStable Diffusion
生成モデルの中核としてUNetアーキテクチャを使用。多層の畳み込みニューラルネットワークで、画像のエンコーディングとデコーディングを行います。
Stable Diffusion3
UNetではなく、DiT(Diffusion Transformer)アーキテクチャを採用。DiTは、テキストと画像のトークンを一緒に扱うことができるトランスフォーマーベースのモデルです。特に、テキストと画像のトークンを別々の重みで処理しながら、両者間で双方向の情報の流れを実現するMM-DiT(Multimodal-DiT)ブロックを使用。
VAE
従来のStable Diffusion
VAE(変分オートエンコーダー)を使用して、画像を低次元の潜在空間に圧縮し、その潜在表現を使って生成プロセスを行います。これにより、効率的な画像生成と高解像度の画像出力が可能になります。
Stable Diffusion3
Stable Diffusionと同様に、事前訓練されたオートエンコーダーを使用しますが、異なるポイントは、潜在空間の次元数が増加していることです。表現能力が上がったので性能が向上しています。
MM-DiTの処理
MM-DiTのアーキテクチャは、テキストと画像の両方に対応するように設計されています。以下に処理の流れを簡単に示します。
前処理
1. 画像エンコーディング(image embedding)
元のRGB画像 を事前に学習されたVAEによって低次元の潜在表現に変換します。得られた潜在表現は、空間的な位置エンコーディングとともに、2x2のパッチに分割します。
2. テキストエンコーディング(text embedding)
テキストを事前学習されたモデル(CLIPとT5)を用いて埋め込みベクトルに変換します。CLIPの出力およびT5の最終的なhidden stateを使用します。
MM-DiTブロック
3. Modulationと線形変換
マルチモダルなモデルでは、モダリティのどれかに判断が偏ってしまうと良くないので、値の調整が行われることがほとんどです。ここでは標準化などの処理を通して値の調整を行っています。
- Modulation
- embeddingのスケーリングやバイアスの調整を行います。この時、ノイズレベルトークンの情報を利用して処理を調整します
- ※noise level tokenについて論文内に具体的な言及は見つかりませんでしたが、おそらく学習時のノイズスケジューリングに対応して、現在のステップがどの程度のノイズレベルかを示すものです。
- Linear
- 埋め込みベクトルを線形変換し、次元を次の入力に適した形に圧縮します。
4. 潜在表現の連結
1と2と連結し、一つのシーケンスにします。
5. Joint Attention
4を受け取り、QKV Attentionに入力します。ここでは両方の埋め込み間で相互の情報を共有し、統合された表現を生成します。
6. Text StreamとImage Streamに分離
画像トークンとテキストトークンの位置が明確に決まっているため、注意機構の出力からそれぞれのトークンを元の位置に基づいて分離します。
7. 再度の線形変換とModulation
6で得られた2つのStreamを加工し、次のMLPに適した形状に変換します。これにより、2つのストリームがそれぞれの特徴を強調しながら、必要に応じて調整されます。
ここでもNoise Level Tokenにより処理の調整が行われます。
8. MLP
線形変換とモジュレーションの後のデータを、非線形変換を含む複数の層に通じてさらに処理します。これにより、より複雑な特徴を学習し、次のブロックに渡します。
9. 繰り返し
ブロックの処理全体が、MM-DiTブロックを積み重ねた回数だけ繰り返されます。
後処理
9. 潜在表現から画像生成
最終的にMM-DiTブロックから得られた潜在表現をVAEでコードして画像を得ます。
Rectified Flowモデル(Stable Diffusion3)の学習方法
ノイズスケジュールとサンプリング手法
従来のStable Diffusion
Stable Diffusionでは、学習時に固定されたノイズスケジュール(例:線形または余弦スケジュール)を使用します。データからノイズへの拡散プロセスを通じて、逆にノイズからデータを生成する方法を学習します。
Stable Diffusion3
Rectified Flowモデルでは、新しいノイズスケジュールを導入し、特定のタイムステップにおけるノイズサンプリングをより効率的に行うことを目指しています。Logit-NormalやModeサンプリングなどの手法を使用して、ノイズのスケールを調整し、中間ステップでのトレーニング効果を高めています。
Rectified Flowモデル(Stable Diffusion3)の性能評価
FIDやCLIPといった指標で既存のモデルより良い数値を出しています。正直数値を見てもへーとしかならないので他の評価内容は端折ります。
論文に載ってる生成画像も引用しておきます。プロンプトを文章にして生成することが可能になっているようです。さらに文字も破綻なく生成できています。
3種類のモデル
Stable Diffusion3は3種類のモデルがあり、モデル間で異なっている点はおそらくMM-DiTブロックが何層積みあがっているかだと思われます。以下の表にパラメータ数を示しておきます。論文に38層の場合は8Bと記載があるので、Ultraは38層と記載しています。
モデル | ブロック数 | パラメータ数 |
---|---|---|
Small | - | 800M |
Medium | - | 2B |
Large | - | 4B |
Ultra | 38 | 8B |
最後に
使う側の視点で覚えとく点は↓あたりです。
- TextEncoderの理解力が向上
- さらに文章を解釈可能に
- 文字が破綻なく生成できそう
- デフォルトでマルチモダル(テキストと画像)生成をサポート
- 生成精度が向上
- パラメータは増加、必要なモデルも増加
- 少し重め、遅めにはなりそう
MM-DiTブロック周りはまだつかみ切れていない部分がありますが、外観は把握できた気になったので、いったん読むのはここまで。
モデルもつい先日公開されたようです。以下の記事でStable Diffusion3を使用する手順を紹介しているので読んでみてください。
その他、Stable Diffusion周りで読んだ論文は以下の記事でまとめているので、興味ある方はぜひご活用ください。