#DiT(Diffusion Transformer)从数据到训练:一条完整的教学级讲解

这篇笔记整理自一次关于 DiT(Diffusion Transformer) 的讲解。目标不是复现论文级训练系统,而是把 DiT 的核心机制讲清楚:数据从哪里来、每一步张量长什么样、模型输入输出是什么、loss 怎么算、训练和采样到底在做什么。

一句话概括:DiT 不是一种新的扩散理论,而是把扩散模型里的 denoiser / vector-field backbone 从 U-Net 换成 Transformer。

传统扩散模型常见结构是:

noisy image / latent x_t
  -> U-Net
  -> predicted noise / velocity

DiT 的结构则更像 ViT:

noisy image / latent x_t
  -> patchify 成 tokens
  -> 加 position / timestep / condition
  -> Transformer blocks
  -> predict noise / velocity per patch
  -> unpatchify 回图像或 latent

如果只记一个核心公式,训练时通常是在做:

x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * ε
loss = || ε - ε_θ(x_t, t, y) ||²

这里的 x_0 是干净图片或 latent,ε 是随机高斯噪声,x_t 是第 t 步的加噪样本,模型的任务是根据 x_t、时间步 t 和条件 y 预测噪声。


#1. DiT 要解决什么问题?

扩散模型的生成过程可以理解为两件事:

  1. 正向过程:把干净数据逐步加噪,最后变成接近纯噪声。
  2. 反向过程:训练一个神经网络,从噪声中一步步预测如何去噪,最终恢复出数据分布中的样本。

早期图像扩散模型大多使用 U-Net 作为 denoiser。U-Net 的优势是卷积归纳偏置强、多尺度结构自然、适合局部纹理与空间细节。但随着生成模型规模扩大,研究者发现 Transformer 有几个很重要的优点:

  • 更适合 scaling:模型宽度、深度、数据量增大时,Transformer 的扩展规律更稳定。
  • 更统一的 token 接口:图片、视频、文本、动作、音频都可以转成 token 序列。
  • 更容易吸收大模型经验:ViT、GPT、LLM 中积累的大量架构和训练经验可以迁移过来。
  • 更适合多模态融合:条件可以是 class label、text embedding、image embedding、action token,甚至世界模型状态。

因此,DiT 的本质是:用 Transformer 来参数化扩散模型中的去噪函数


#2. 从数据开始:以 CIFAR-10 教学版 DiT 为例

为了讲清楚训练流程,可以先不做 Stable Diffusion 那种 latent diffusion,而是在 CIFAR-10 的 pixel space 上训练一个小 DiT。

原始数据是:

x_0: [B, 3, 32, 32]
y:   [B]

其中:

  • B 是 batch size,比如 128。
  • 3 是 RGB 通道。
  • 32 x 32 是 CIFAR-10 图片分辨率。
  • y 是类别标签,范围为 0 到 9。

通常会先把图片从 [0, 1] 归一化到 [-1, 1]

transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

所以一个训练 batch 进入扩散过程前大致是:

x_0: [128, 3, 32, 32], value range ≈ [-1, 1]
y:   [128]

#3. 正向扩散:从干净图片得到 noisy image

训练时,每张图片会随机采样一个时间步:

t: [B]

然后采样同形状高斯噪声:

ε: [B, 3, 32, 32]

根据 diffusion schedule,可以直接构造第 t 步的加噪样本:

x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * ε

所以:

x_0: [128, 3, 32, 32]
ε:   [128, 3, 32, 32]
t:   [128]
x_t: [128, 3, 32, 32]

模型看到的是 x_t,不是 x_0。训练目标是让模型从 x_t 里预测出刚才加进去的噪声 ε


#4. Patchify:把图像变成 Transformer tokens

Transformer 不直接吃 [B, C, H, W],而是吃 token 序列。所以 DiT 首先要把图像切 patch。

但这里很容易有一个误解:patchify 不是把图片“压缩成一句话”,也不是简单地丢掉空间结构;它是把二维图像重新排列成一串局部块,并把每个局部块投影到 Transformer 的 hidden dimension。

可以把它类比成 NLP 里的 tokenization:

NLP:
  sentence -> word/subword tokens -> embedding vectors

DiT / ViT:
  image or latent -> patch tokens -> embedding vectors

区别是,NLP 的 token 本来就是离散 ID,而图像 patch 是一个连续的小张量,需要通过线性层或卷积投影成向量。

假设:

image_size = 32
patch_size = 4
in_channels = 3
hidden_size = 384

那么输入 noisy image 是:

x_t: [B, 3, 32, 32]

#4.1 一个 patch 到底长什么样?

patch_size = 4 表示每个 patch 覆盖图像上的一个 4 x 4 小方块。因为图片有 3 个 RGB 通道,所以一个 patch 原始包含:

3 * 4 * 4 = 48 个数

也就是说,一个 patch 的原始形状可以看成:

patch: [3, 4, 4]

如果把它展平,就是:

patch_flat: [48]

对于一张 32 x 32 的图,横向有:

32 / 4 = 8 个 patch

纵向也有:

32 / 4 = 8 个 patch

所以每张图总共有:

N = 8 * 8 = 64 个 patch

这 64 个 patch 可以按 raster scan 顺序排成一串:

(row 0, col 0), (row 0, col 1), ..., (row 0, col 7),
(row 1, col 0), (row 1, col 1), ..., (row 7, col 7)

于是,从“显式切块 + 展平”的角度看,patchify 是:

x_t: [B, 3, 32, 32]
  -> split into 8 * 8 patches
     [B, 8, 8, 3, 4, 4]
  -> flatten each patch
     [B, 8, 8, 48]
  -> flatten grid positions
     [B, 64, 48]

到这一步,每个 token 还只是一个 48 维的原始像素块向量。

#4.2 为什么还要做 patch embedding?

Transformer block 里 attention 和 MLP 的 hidden size 通常是一个较大的维度,比如 384、768、1024。原始 patch 只有 48 维,不适合直接作为 Transformer 的 hidden state。

所以还要做一个线性投影:

[B, 64, 48] -> Linear(48, 384) -> [B, 64, 384]

也就是说,每个 patch token 从一个 48 维像素块,被映射成一个 384 维语义/特征向量。

这和 NLP 里的 embedding lookup 很像:

word_id -> embedding vector

只不过图像里没有 word_id,而是:

patch pixels -> linear projection -> patch embedding

#4.3 为什么代码里常用 Conv2d 实现?

虽然概念上可以写成“切 patch、flatten、linear projection”,但代码里通常直接用一个卷积完成:

nn.Conv2d(
    in_channels=3,
    out_channels=hidden_size,
    kernel_size=patch_size,
    stride=patch_size,
)

在当前例子里就是:

nn.Conv2d(
    in_channels=3,
    out_channels=384,
    kernel_size=4,
    stride=4,
)

这个卷积有两个关键参数:

  • kernel_size = 4:每次看一个 4 x 4 patch。
  • stride = 4:卷积窗口每次移动 4 个像素,所以 patch 之间不重叠。

它的输出 shape 是:

x_t: [B, 3, 32, 32]
  -> Conv2d(kernel=4, stride=4, out_channels=384)
     [B, 384, 8, 8]

这里的 [8, 8] 就是 patch grid。每个空间位置 (i, j) 对应原图中的一个 4 x 4 patch;每个位置上的 384 个 channel,就是这个 patch 的 embedding。

然后再把二维 patch grid 展平成 token sequence:

[B, 384, 8, 8]
  -> flatten spatial dimensions
     [B, 384, 64]
  -> transpose channel and sequence dimensions
     [B, 64, 384]

最终得到:

patch tokens: [B, N, D] = [128, 64, 384]

从数学上看,Conv2d(kernel_size=patch_size, stride=patch_size) 基本等价于:

每个 patch 展平成 48 维
  -> 乘一个共享的 Linear(48, 384)

“共享”指的是:所有位置的 patch 都用同一个投影矩阵。这和卷积的权重共享是一致的。

#4.4 Patchify 后空间信息丢了吗?

patchify 后,[B, 3, 32, 32] 变成 [B, 64, 384],二维结构确实被摊平成一维序列了。但这并不意味着空间信息完全丢失,因为有两层机制保留空间关系:

  1. token 顺序:第 0 个 token、第 1 个 token……本身对应固定的 patch grid 位置。
  2. position embedding:后面会给每个 patch token 加上位置编码,显式告诉模型它在图像中的二维位置。

如果没有 position embedding,Transformer 的 self-attention 对 token 顺序本身不敏感,模型很难知道某个 patch 来自左上角还是右下角。因此,patchify 后紧接着加 position embedding 是非常关键的。

#4.5 patch_size 的取舍:token 数、计算量与细节

patch size 会直接决定 token 数,而 attention 的计算量大致随 token 数平方增长。

对于 32 x 32 图像:

patch_size = 2 -> N = 16 * 16 = 256 tokens
patch_size = 4 -> N = 8  * 8  = 64 tokens
patch_size = 8 -> N = 4  * 4  = 16 tokens

更小的 patch size:

  • 优点:保留更细粒度的局部信息。
  • 缺点:token 数更多,attention 更贵。

更大的 patch size:

  • 优点:token 数更少,训练更便宜。
  • 缺点:每个 token 覆盖区域更大,细节建模更粗。

所以 patch size 本质上是在做一个 trade-off:

空间细节分辨率 vs Transformer 计算成本

在高分辨率 latent diffusion 里,因为 VAE 已经把 512 x 512 图片压到类似 [4, 64, 64] 的 latent,DiT 通常是在 latent grid 上 patchify,而不是直接在原始像素上 patchify。这样可以显著降低 token 数。

#4.6 Patchify 和 unpatchify 是一对逆过程吗?

概念上,它们是一对“形状上的逆过程”,但不是严格的信息逆变换。

训练时:

x_t image/latent
  -> patchify + embedding
  -> Transformer
  -> predict patch outputs
  -> unpatchify
  -> ε_pred image/latent

其中 patchify 开头的投影是:

[48] -> [384]

它是把 patch 放到 hidden space 里,不要求可逆。真正用于输出的是 final layer:

[384] -> [48]

然后再把 64 个 [48] patch 拼回:

[B, 64, 48] -> [B, 3, 32, 32]

所以可以理解为:

  • patchify:把 noisy image/latent 切成 token,并升维成 hidden states。
  • unpatchify:把每个 token 的预测结果还原成对应 patch 的噪声/速度场,再拼回原图形状。

最终 DiT 不是直接输出一串 token 给用户,而是输出和输入 x_t 同形状的预测噪声:

ε_pred: [B, 3, 32, 32]

这就是为什么 DiT 虽然中间是 Transformer token 序列,但训练目标仍然是图像/latent 空间里的 denoising loss。


#5. Position、timestep 与 class condition

DiT 需要三类信息:

  1. 图像内容 token:来自 x_t 的 patch tokens。
  2. 位置编码:告诉模型每个 patch 在图像中的位置。
  3. 条件信息:包括 timestep t 和类别/文本条件 y

#5.1 位置编码

位置编码加到 patch tokens 上:

x: [B, 64, 384]
pos_embed: [1, 64, 384]
x + pos_embed: [B, 64, 384]

原始 DiT 常使用 fixed 2D sin-cos positional embedding;教学实现里也可以用 learnable positional embedding,方便理解。

#5.2 Timestep embedding

t 是扩散过程中的时间步,shape 是:

t: [B]

通常先用 sinusoidal embedding,再接一个 MLP:

t -> sinusoidal embedding -> MLP -> t_emb

得到:

t_emb: [B, 384]

#5.3 Class condition

对于 CIFAR-10 class-conditional generation,类别标签 y 通过 embedding table 得到:

y:     [B]
y_emb: [B, 384]

然后把时间条件和类别条件相加:

c = t_emb + y_emb
c: [B, 384]

这个 c 会被送进每个 DiT block,用来调制 LayerNorm 和 residual branch。


#6. DiT Block:Transformer + AdaLN-Zero

DiT block 可以理解为一个被条件 c 调制的 Transformer block。

普通 Transformer block 大致是:

x -> LayerNorm -> Self-Attention -> residual
x -> LayerNorm -> MLP            -> residual

DiT 的关键改造是 AdaLN-Zero。它不是简单地把 timestep/class embedding 加到 token 上,而是用条件向量 c 生成 LayerNorm 的 shift、scale 和 residual gate。

一个典型 DiT block 会从 c 中生成六组参数:

shift_msa, scale_msa, gate_msa,
shift_mlp, scale_mlp, gate_mlp = adaLN_modulation(c).chunk(6, dim=1)

含义是:

  • shift_msa, scale_msa:调制 attention 前的 LayerNorm 输出。
  • gate_msa:控制 attention residual branch 的强度。
  • shift_mlp, scale_mlp:调制 MLP 前的 LayerNorm 输出。
  • gate_mlp:控制 MLP residual branch 的强度。

直觉上,条件 c 在告诉模型:

在当前 timestep 和当前类别条件下,attention/MLP 应该怎样处理这些 noisy patch tokens。

AdaLN-Zero 中的 “Zero” 指的是调制层和最终输出层常常被零初始化,使模型一开始接近恒等映射或零输出,从而稳定训练很深的 Transformer denoiser。


#7. Final layer 与 unpatchify:从 token 回到噪声图

经过若干 DiT blocks 后,token shape 仍然是:

x: [B, 64, 384]

最终需要把每个 token 映射回一个 patch 的像素噪声。

在 CIFAR-10 教学设定中:

patch_size = 4
channels = 3
patch_dim = 4 * 4 * 3 = 48

所以 final linear layer 输出:

[B, 64, 384] -> [B, 64, 48]

然后 unpatchify:

[B, 64, 48]
  -> [B, 8, 8, 4, 4, 3]
  -> [B, 3, 32, 32]

最终模型输出:

ε_pred: [B, 3, 32, 32]

和真实噪声 ε 形状完全一致。


#8. 一次训练 step 的完整 shape 流

把所有步骤串起来,一次训练 step 是:

1. 从 dataloader 取数据
   x0: [128, 3, 32, 32]
   y:  [128]

2. 采样 timestep 和噪声
   t:     [128]
   ε:     [128, 3, 32, 32]

3. 构造 noisy image
   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * ε
   x_t: [128, 3, 32, 32]

4. DiT patch embedding
   x_t -> patch tokens
   tokens: [128, 64, 384]

5. 加位置编码
   tokens + pos_embed: [128, 64, 384]

6. 构造条件向量
   t_emb: [128, 384]
   y_emb: [128, 384]
   c:     [128, 384]

7. 经过 DiT blocks
   [128, 64, 384] -> [128, 64, 384]

8. final projection + unpatchify
   [128, 64, 384] -> [128, 64, 48] -> [128, 3, 32, 32]

9. 计算 loss
   loss = MSE(ε_pred, ε)
   loss: scalar

对应的核心训练代码就是:

t = torch.randint(0, diffusion_steps, size=(B,), device=device)
noise = torch.randn_like(x0)
x_t = diffusion.q_sample(x0, t, noise)
pred_noise = model(x_t, t, y)
loss = F.mse_loss(pred_noise, noise)

optimizer.zero_grad()
loss.backward()
optimizer.step()

这就是 DiT 训练最核心的闭环。


#9. 教学版 PyTorch 实现骨架

下面是一个教学级 CIFAR-10 DiT 的模块拆分。完整代码可以按这个结构组织成 train_dit_cifar10.py

@dataclass
class Config:
    image_size: int = 32
    in_channels: int = 3
    num_classes: int = 10
    patch_size: int = 4
    hidden_size: int = 384
    depth: int = 8
    num_heads: int = 6
    mlp_ratio: float = 4.0
    diffusion_steps: int = 1000
    batch_size: int = 128
    lr: float = 1e-4
    epochs: int = 100
    data_dir: str = "./data"
    out_dir: str = "./outputs"

核心模块包括:

CIFAR10 dataloader
GaussianDiffusion
PatchEmbed
TimestepEmbedder
ClassEmbedder
DiTBlock with AdaLN-Zero
FinalLayer
DiT
training loop
sampling loop

如果显存较小,可以把配置降到:

batch_size = 32
hidden_size = 256
depth = 4
num_heads = 4

需要强调:这个版本是为了理解 DiT 的教学实现,不是论文级实现。它有几个简化:

  • 使用 pixel-space CIFAR-10,而不是 latent diffusion。
  • 使用简单 DDPM 训练和采样,而不是 DDIM、DPM-Solver、Flow Matching 或 Rectified Flow。
  • 位置编码可以用 learnable embedding,而原始 DiT 常用 fixed 2D sin-cos embedding。
  • 没有加入 AMP、EMA、分布式训练、FlashAttention、大规模数据管线等工程组件。

#10. 采样:从纯噪声反向生成图片

训练时模型学习的是:

ε_θ(x_t, t, y)

采样时则从纯噪声开始:

x_T ~ N(0, I)

然后从 T-10 逐步反推:

x = torch.randn(shape, device=device)
for i in reversed(range(timesteps)):
    t = torch.full((B,), i, device=device, dtype=torch.long)
    x = p_sample(model, x, t, y)

每一步都用 DiT 预测当前噪声,再根据 DDPM 公式得到更干净一点的 x_{t-1}

最后得到:

x_0_pred: [B, 3, 32, 32]

再把它从 [-1, 1] 转回 [0, 1],保存成图片。


#11. 如果换成 Stable Diffusion / 大模型版本,会发生什么?

真实大模型很少直接在 pixel space 做高分辨率扩散,因为成本太高。Stable Diffusion 类模型通常在 VAE latent space 里训练。

此时数据流从:

image: [B, 3, 512, 512]

变成:

VAE encoder -> latent z_0: [B, 4, 64, 64]

然后 DiT 不再处理 RGB 图像,而是处理 latent:

z_t: [B, 4, 64, 64]
  -> patchify
  -> Transformer / DiT / MMDiT
  -> predict noise or velocity in latent space
  -> VAE decoder -> image

所以大模型版 DiT 的本质没有变,只是:

  • 输入从 image pixel 变成 VAE latent。
  • 条件从 class label 变成 text embedding / multimodal embedding。
  • 训练目标可能从 epsilon prediction 变成 v-prediction、flow matching 或 rectified flow。
  • Transformer 结构可能从单流 DiT 变成双流/多流 MMDiT。

#12. DiT 与 ViT、GPT、MMDiT、Flux、Sora 的关系

可以这样理解 DiT 在模型谱系中的位置:

ViT:
  image -> patches -> Transformer -> classification

DiT:
  noisy image/latent -> patches -> Transformer -> denoising / vector field

MMDiT:
  image/video latent tokens + text tokens -> multimodal Transformer -> denoising / flow

Video DiT / Sora-like world model:
  noisy video latent -> spatiotemporal tokens -> Transformer -> video generation / prediction

DiT 把扩散模型和 Transformer scaling 结合起来,因此它自然成为图像、视频、多模态生成和世界模型中的重要底座。

从这个角度看,DiT 的意义不只是“图像生成模型换了一个 backbone”,而是让扩散模型进入了更统一的 token-based scaling 框架:

  • 图像可以是 token。
  • 视频可以是 token。
  • 文本条件可以是 token。
  • 动作和状态也可以是 token。
  • 未来的世界模型、robotics policy、latent reasoning 也可能共享类似接口。

#13. 最重要的理解

最后可以把 DiT 的主线压缩成一句话:

图片/latent x0
  -> 加噪得到 x_t
  -> patchify 成 tokens
  -> 加位置编码
  -> 加 timestep/class/text condition
  -> DiT blocks
  -> 预测 patch-level noise / velocity
  -> unpatchify 成整图噪声或 latent velocity
  -> 和真实噪声/速度场做回归 loss

所以,理解 DiT 的关键不是背结构名,而是抓住三件事:

  1. 扩散训练目标:模型在每个时间步学习如何从 noisy sample 中预测噪声或速度场。
  2. Transformer tokenization:图像或 latent 被切成 patch tokens,进入 Transformer。
  3. 条件调制机制:timestep、class/text condition 通过 AdaLN-Zero 等方式控制每一层的去噪行为。

只要这三点打通,DiT、MMDiT、Flux、视频 diffusion transformer、甚至一部分 world model 的结构都会变得更容易理解。