#DiT(Diffusion Transformer)从数据到训练:一条完整的教学级讲解
这篇笔记整理自一次关于 DiT(Diffusion Transformer) 的讲解。目标不是复现论文级训练系统,而是把 DiT 的核心机制讲清楚:数据从哪里来、每一步张量长什么样、模型输入输出是什么、loss 怎么算、训练和采样到底在做什么。
一句话概括:DiT 不是一种新的扩散理论,而是把扩散模型里的 denoiser / vector-field backbone 从 U-Net 换成 Transformer。
传统扩散模型常见结构是:
noisy image / latent x_t
-> U-Net
-> predicted noise / velocity
DiT 的结构则更像 ViT:
noisy image / latent x_t
-> patchify 成 tokens
-> 加 position / timestep / condition
-> Transformer blocks
-> predict noise / velocity per patch
-> unpatchify 回图像或 latent
如果只记一个核心公式,训练时通常是在做:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * ε
loss = || ε - ε_θ(x_t, t, y) ||²
这里的 x_0 是干净图片或 latent,ε 是随机高斯噪声,x_t 是第 t 步的加噪样本,模型的任务是根据 x_t、时间步 t 和条件 y 预测噪声。
#1. DiT 要解决什么问题?
扩散模型的生成过程可以理解为两件事:
- 正向过程:把干净数据逐步加噪,最后变成接近纯噪声。
- 反向过程:训练一个神经网络,从噪声中一步步预测如何去噪,最终恢复出数据分布中的样本。
早期图像扩散模型大多使用 U-Net 作为 denoiser。U-Net 的优势是卷积归纳偏置强、多尺度结构自然、适合局部纹理与空间细节。但随着生成模型规模扩大,研究者发现 Transformer 有几个很重要的优点:
- 更适合 scaling:模型宽度、深度、数据量增大时,Transformer 的扩展规律更稳定。
- 更统一的 token 接口:图片、视频、文本、动作、音频都可以转成 token 序列。
- 更容易吸收大模型经验:ViT、GPT、LLM 中积累的大量架构和训练经验可以迁移过来。
- 更适合多模态融合:条件可以是 class label、text embedding、image embedding、action token,甚至世界模型状态。
因此,DiT 的本质是:用 Transformer 来参数化扩散模型中的去噪函数。
#2. 从数据开始:以 CIFAR-10 教学版 DiT 为例
为了讲清楚训练流程,可以先不做 Stable Diffusion 那种 latent diffusion,而是在 CIFAR-10 的 pixel space 上训练一个小 DiT。
原始数据是:
x_0: [B, 3, 32, 32]
y: [B]
其中:
B是 batch size,比如 128。3是 RGB 通道。32 x 32是 CIFAR-10 图片分辨率。y是类别标签,范围为 0 到 9。
通常会先把图片从 [0, 1] 归一化到 [-1, 1]:
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
所以一个训练 batch 进入扩散过程前大致是:
x_0: [128, 3, 32, 32], value range ≈ [-1, 1]
y: [128]
#3. 正向扩散:从干净图片得到 noisy image
训练时,每张图片会随机采样一个时间步:
t: [B]
然后采样同形状高斯噪声:
ε: [B, 3, 32, 32]
根据 diffusion schedule,可以直接构造第 t 步的加噪样本:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * ε
所以:
x_0: [128, 3, 32, 32]
ε: [128, 3, 32, 32]
t: [128]
x_t: [128, 3, 32, 32]
模型看到的是 x_t,不是 x_0。训练目标是让模型从 x_t 里预测出刚才加进去的噪声 ε。
#4. Patchify:把图像变成 Transformer tokens
Transformer 不直接吃 [B, C, H, W],而是吃 token 序列。所以 DiT 首先要把图像切 patch。
但这里很容易有一个误解:patchify 不是把图片“压缩成一句话”,也不是简单地丢掉空间结构;它是把二维图像重新排列成一串局部块,并把每个局部块投影到 Transformer 的 hidden dimension。
可以把它类比成 NLP 里的 tokenization:
NLP:
sentence -> word/subword tokens -> embedding vectors
DiT / ViT:
image or latent -> patch tokens -> embedding vectors
区别是,NLP 的 token 本来就是离散 ID,而图像 patch 是一个连续的小张量,需要通过线性层或卷积投影成向量。
假设:
image_size = 32
patch_size = 4
in_channels = 3
hidden_size = 384
那么输入 noisy image 是:
x_t: [B, 3, 32, 32]
#4.1 一个 patch 到底长什么样?
patch_size = 4 表示每个 patch 覆盖图像上的一个 4 x 4 小方块。因为图片有 3 个 RGB 通道,所以一个 patch 原始包含:
3 * 4 * 4 = 48 个数
也就是说,一个 patch 的原始形状可以看成:
patch: [3, 4, 4]
如果把它展平,就是:
patch_flat: [48]
对于一张 32 x 32 的图,横向有:
32 / 4 = 8 个 patch
纵向也有:
32 / 4 = 8 个 patch
所以每张图总共有:
N = 8 * 8 = 64 个 patch
这 64 个 patch 可以按 raster scan 顺序排成一串:
(row 0, col 0), (row 0, col 1), ..., (row 0, col 7),
(row 1, col 0), (row 1, col 1), ..., (row 7, col 7)
于是,从“显式切块 + 展平”的角度看,patchify 是:
x_t: [B, 3, 32, 32]
-> split into 8 * 8 patches
[B, 8, 8, 3, 4, 4]
-> flatten each patch
[B, 8, 8, 48]
-> flatten grid positions
[B, 64, 48]
到这一步,每个 token 还只是一个 48 维的原始像素块向量。
#4.2 为什么还要做 patch embedding?
Transformer block 里 attention 和 MLP 的 hidden size 通常是一个较大的维度,比如 384、768、1024。原始 patch 只有 48 维,不适合直接作为 Transformer 的 hidden state。
所以还要做一个线性投影:
[B, 64, 48] -> Linear(48, 384) -> [B, 64, 384]
也就是说,每个 patch token 从一个 48 维像素块,被映射成一个 384 维语义/特征向量。
这和 NLP 里的 embedding lookup 很像:
word_id -> embedding vector
只不过图像里没有 word_id,而是:
patch pixels -> linear projection -> patch embedding
#4.3 为什么代码里常用 Conv2d 实现?
虽然概念上可以写成“切 patch、flatten、linear projection”,但代码里通常直接用一个卷积完成:
nn.Conv2d(
in_channels=3,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
)
在当前例子里就是:
nn.Conv2d(
in_channels=3,
out_channels=384,
kernel_size=4,
stride=4,
)
这个卷积有两个关键参数:
kernel_size = 4:每次看一个4 x 4patch。stride = 4:卷积窗口每次移动 4 个像素,所以 patch 之间不重叠。
它的输出 shape 是:
x_t: [B, 3, 32, 32]
-> Conv2d(kernel=4, stride=4, out_channels=384)
[B, 384, 8, 8]
这里的 [8, 8] 就是 patch grid。每个空间位置 (i, j) 对应原图中的一个 4 x 4 patch;每个位置上的 384 个 channel,就是这个 patch 的 embedding。
然后再把二维 patch grid 展平成 token sequence:
[B, 384, 8, 8]
-> flatten spatial dimensions
[B, 384, 64]
-> transpose channel and sequence dimensions
[B, 64, 384]
最终得到:
patch tokens: [B, N, D] = [128, 64, 384]
从数学上看,Conv2d(kernel_size=patch_size, stride=patch_size) 基本等价于:
每个 patch 展平成 48 维
-> 乘一个共享的 Linear(48, 384)
“共享”指的是:所有位置的 patch 都用同一个投影矩阵。这和卷积的权重共享是一致的。
#4.4 Patchify 后空间信息丢了吗?
patchify 后,[B, 3, 32, 32] 变成 [B, 64, 384],二维结构确实被摊平成一维序列了。但这并不意味着空间信息完全丢失,因为有两层机制保留空间关系:
- token 顺序:第 0 个 token、第 1 个 token……本身对应固定的 patch grid 位置。
- position embedding:后面会给每个 patch token 加上位置编码,显式告诉模型它在图像中的二维位置。
如果没有 position embedding,Transformer 的 self-attention 对 token 顺序本身不敏感,模型很难知道某个 patch 来自左上角还是右下角。因此,patchify 后紧接着加 position embedding 是非常关键的。
#4.5 patch_size 的取舍:token 数、计算量与细节
patch size 会直接决定 token 数,而 attention 的计算量大致随 token 数平方增长。
对于 32 x 32 图像:
patch_size = 2 -> N = 16 * 16 = 256 tokens
patch_size = 4 -> N = 8 * 8 = 64 tokens
patch_size = 8 -> N = 4 * 4 = 16 tokens
更小的 patch size:
- 优点:保留更细粒度的局部信息。
- 缺点:token 数更多,attention 更贵。
更大的 patch size:
- 优点:token 数更少,训练更便宜。
- 缺点:每个 token 覆盖区域更大,细节建模更粗。
所以 patch size 本质上是在做一个 trade-off:
空间细节分辨率 vs Transformer 计算成本
在高分辨率 latent diffusion 里,因为 VAE 已经把 512 x 512 图片压到类似 [4, 64, 64] 的 latent,DiT 通常是在 latent grid 上 patchify,而不是直接在原始像素上 patchify。这样可以显著降低 token 数。
#4.6 Patchify 和 unpatchify 是一对逆过程吗?
概念上,它们是一对“形状上的逆过程”,但不是严格的信息逆变换。
训练时:
x_t image/latent
-> patchify + embedding
-> Transformer
-> predict patch outputs
-> unpatchify
-> ε_pred image/latent
其中 patchify 开头的投影是:
[48] -> [384]
它是把 patch 放到 hidden space 里,不要求可逆。真正用于输出的是 final layer:
[384] -> [48]
然后再把 64 个 [48] patch 拼回:
[B, 64, 48] -> [B, 3, 32, 32]
所以可以理解为:
patchify:把 noisy image/latent 切成 token,并升维成 hidden states。unpatchify:把每个 token 的预测结果还原成对应 patch 的噪声/速度场,再拼回原图形状。
最终 DiT 不是直接输出一串 token 给用户,而是输出和输入 x_t 同形状的预测噪声:
ε_pred: [B, 3, 32, 32]
这就是为什么 DiT 虽然中间是 Transformer token 序列,但训练目标仍然是图像/latent 空间里的 denoising loss。
#5. Position、timestep 与 class condition
DiT 需要三类信息:
- 图像内容 token:来自
x_t的 patch tokens。 - 位置编码:告诉模型每个 patch 在图像中的位置。
- 条件信息:包括 timestep
t和类别/文本条件y。
#5.1 位置编码
位置编码加到 patch tokens 上:
x: [B, 64, 384]
pos_embed: [1, 64, 384]
x + pos_embed: [B, 64, 384]
原始 DiT 常使用 fixed 2D sin-cos positional embedding;教学实现里也可以用 learnable positional embedding,方便理解。
#5.2 Timestep embedding
t 是扩散过程中的时间步,shape 是:
t: [B]
通常先用 sinusoidal embedding,再接一个 MLP:
t -> sinusoidal embedding -> MLP -> t_emb
得到:
t_emb: [B, 384]
#5.3 Class condition
对于 CIFAR-10 class-conditional generation,类别标签 y 通过 embedding table 得到:
y: [B]
y_emb: [B, 384]
然后把时间条件和类别条件相加:
c = t_emb + y_emb
c: [B, 384]
这个 c 会被送进每个 DiT block,用来调制 LayerNorm 和 residual branch。
#6. DiT Block:Transformer + AdaLN-Zero
DiT block 可以理解为一个被条件 c 调制的 Transformer block。
普通 Transformer block 大致是:
x -> LayerNorm -> Self-Attention -> residual
x -> LayerNorm -> MLP -> residual
DiT 的关键改造是 AdaLN-Zero。它不是简单地把 timestep/class embedding 加到 token 上,而是用条件向量 c 生成 LayerNorm 的 shift、scale 和 residual gate。
一个典型 DiT block 会从 c 中生成六组参数:
shift_msa, scale_msa, gate_msa,
shift_mlp, scale_mlp, gate_mlp = adaLN_modulation(c).chunk(6, dim=1)
含义是:
shift_msa, scale_msa:调制 attention 前的 LayerNorm 输出。gate_msa:控制 attention residual branch 的强度。shift_mlp, scale_mlp:调制 MLP 前的 LayerNorm 输出。gate_mlp:控制 MLP residual branch 的强度。
直觉上,条件 c 在告诉模型:
在当前 timestep 和当前类别条件下,attention/MLP 应该怎样处理这些 noisy patch tokens。
AdaLN-Zero 中的 “Zero” 指的是调制层和最终输出层常常被零初始化,使模型一开始接近恒等映射或零输出,从而稳定训练很深的 Transformer denoiser。
#7. Final layer 与 unpatchify:从 token 回到噪声图
经过若干 DiT blocks 后,token shape 仍然是:
x: [B, 64, 384]
最终需要把每个 token 映射回一个 patch 的像素噪声。
在 CIFAR-10 教学设定中:
patch_size = 4
channels = 3
patch_dim = 4 * 4 * 3 = 48
所以 final linear layer 输出:
[B, 64, 384] -> [B, 64, 48]
然后 unpatchify:
[B, 64, 48]
-> [B, 8, 8, 4, 4, 3]
-> [B, 3, 32, 32]
最终模型输出:
ε_pred: [B, 3, 32, 32]
和真实噪声 ε 形状完全一致。
#8. 一次训练 step 的完整 shape 流
把所有步骤串起来,一次训练 step 是:
1. 从 dataloader 取数据
x0: [128, 3, 32, 32]
y: [128]
2. 采样 timestep 和噪声
t: [128]
ε: [128, 3, 32, 32]
3. 构造 noisy image
x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * ε
x_t: [128, 3, 32, 32]
4. DiT patch embedding
x_t -> patch tokens
tokens: [128, 64, 384]
5. 加位置编码
tokens + pos_embed: [128, 64, 384]
6. 构造条件向量
t_emb: [128, 384]
y_emb: [128, 384]
c: [128, 384]
7. 经过 DiT blocks
[128, 64, 384] -> [128, 64, 384]
8. final projection + unpatchify
[128, 64, 384] -> [128, 64, 48] -> [128, 3, 32, 32]
9. 计算 loss
loss = MSE(ε_pred, ε)
loss: scalar
对应的核心训练代码就是:
t = torch.randint(0, diffusion_steps, size=(B,), device=device)
noise = torch.randn_like(x0)
x_t = diffusion.q_sample(x0, t, noise)
pred_noise = model(x_t, t, y)
loss = F.mse_loss(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
这就是 DiT 训练最核心的闭环。
#9. 教学版 PyTorch 实现骨架
下面是一个教学级 CIFAR-10 DiT 的模块拆分。完整代码可以按这个结构组织成 train_dit_cifar10.py。
@dataclass
class Config:
image_size: int = 32
in_channels: int = 3
num_classes: int = 10
patch_size: int = 4
hidden_size: int = 384
depth: int = 8
num_heads: int = 6
mlp_ratio: float = 4.0
diffusion_steps: int = 1000
batch_size: int = 128
lr: float = 1e-4
epochs: int = 100
data_dir: str = "./data"
out_dir: str = "./outputs"
核心模块包括:
CIFAR10 dataloader
GaussianDiffusion
PatchEmbed
TimestepEmbedder
ClassEmbedder
DiTBlock with AdaLN-Zero
FinalLayer
DiT
training loop
sampling loop
如果显存较小,可以把配置降到:
batch_size = 32
hidden_size = 256
depth = 4
num_heads = 4
需要强调:这个版本是为了理解 DiT 的教学实现,不是论文级实现。它有几个简化:
- 使用 pixel-space CIFAR-10,而不是 latent diffusion。
- 使用简单 DDPM 训练和采样,而不是 DDIM、DPM-Solver、Flow Matching 或 Rectified Flow。
- 位置编码可以用 learnable embedding,而原始 DiT 常用 fixed 2D sin-cos embedding。
- 没有加入 AMP、EMA、分布式训练、FlashAttention、大规模数据管线等工程组件。
#10. 采样:从纯噪声反向生成图片
训练时模型学习的是:
ε_θ(x_t, t, y)
采样时则从纯噪声开始:
x_T ~ N(0, I)
然后从 T-1 到 0 逐步反推:
x = torch.randn(shape, device=device)
for i in reversed(range(timesteps)):
t = torch.full((B,), i, device=device, dtype=torch.long)
x = p_sample(model, x, t, y)
每一步都用 DiT 预测当前噪声,再根据 DDPM 公式得到更干净一点的 x_{t-1}。
最后得到:
x_0_pred: [B, 3, 32, 32]
再把它从 [-1, 1] 转回 [0, 1],保存成图片。
#11. 如果换成 Stable Diffusion / 大模型版本,会发生什么?
真实大模型很少直接在 pixel space 做高分辨率扩散,因为成本太高。Stable Diffusion 类模型通常在 VAE latent space 里训练。
此时数据流从:
image: [B, 3, 512, 512]
变成:
VAE encoder -> latent z_0: [B, 4, 64, 64]
然后 DiT 不再处理 RGB 图像,而是处理 latent:
z_t: [B, 4, 64, 64]
-> patchify
-> Transformer / DiT / MMDiT
-> predict noise or velocity in latent space
-> VAE decoder -> image
所以大模型版 DiT 的本质没有变,只是:
- 输入从 image pixel 变成 VAE latent。
- 条件从 class label 变成 text embedding / multimodal embedding。
- 训练目标可能从 epsilon prediction 变成 v-prediction、flow matching 或 rectified flow。
- Transformer 结构可能从单流 DiT 变成双流/多流 MMDiT。
#12. DiT 与 ViT、GPT、MMDiT、Flux、Sora 的关系
可以这样理解 DiT 在模型谱系中的位置:
ViT:
image -> patches -> Transformer -> classification
DiT:
noisy image/latent -> patches -> Transformer -> denoising / vector field
MMDiT:
image/video latent tokens + text tokens -> multimodal Transformer -> denoising / flow
Video DiT / Sora-like world model:
noisy video latent -> spatiotemporal tokens -> Transformer -> video generation / prediction
DiT 把扩散模型和 Transformer scaling 结合起来,因此它自然成为图像、视频、多模态生成和世界模型中的重要底座。
从这个角度看,DiT 的意义不只是“图像生成模型换了一个 backbone”,而是让扩散模型进入了更统一的 token-based scaling 框架:
- 图像可以是 token。
- 视频可以是 token。
- 文本条件可以是 token。
- 动作和状态也可以是 token。
- 未来的世界模型、robotics policy、latent reasoning 也可能共享类似接口。
#13. 最重要的理解
最后可以把 DiT 的主线压缩成一句话:
图片/latent x0
-> 加噪得到 x_t
-> patchify 成 tokens
-> 加位置编码
-> 加 timestep/class/text condition
-> DiT blocks
-> 预测 patch-level noise / velocity
-> unpatchify 成整图噪声或 latent velocity
-> 和真实噪声/速度场做回归 loss
所以,理解 DiT 的关键不是背结构名,而是抓住三件事:
- 扩散训练目标:模型在每个时间步学习如何从 noisy sample 中预测噪声或速度场。
- Transformer tokenization:图像或 latent 被切成 patch tokens,进入 Transformer。
- 条件调制机制:timestep、class/text condition 通过 AdaLN-Zero 等方式控制每一层的去噪行为。
只要这三点打通,DiT、MMDiT、Flux、视频 diffusion transformer、甚至一部分 world model 的结构都会变得更容易理解。