#Megatron reported TFLOP/s 的真实含义:small SWA、GQA、FP8、CP 与 fused kernel 的分子/分母分析

我查了 Liangguang 实际跑的 Megatron 源码。先说最关键的一点:

Megatron 日志里的 throughput per GPU (TFLOP/s/GPU)估算 FLOPs / 实际 step time / GPU 数,不是 GPU 硬件真实计数器测出来的 FLOPs。

源码位置:

/mnt/project_modelware/liangguang/llm100/megatron/Megatron-LM/megatron/training/training.py:2214-2216

核心公式就是:

throughput = num_floating_point_operations(args, batch_size) / (
    elapsed_time_per_iteration * 10**12 * args.world_size
)

人话:

reported TFLOP/s/GPU =
Megatron 估算这一 training step 有多少 FLOPs
/ 这一 step 花了多少秒
/ GPU 数量

所以你问的这些:

small SWA window 127
GQA
FP8
context parallel
fused kernel

本质都要分两层看:

1. 它们有没有改变 Megatron 估算的 FLOPs 分子?
2. 它们有没有改变真实 step time 分母?

#1. Megatron reported throughput 底层怎么算

FLOPs 估算入口在:

megatron/training/training.py:281
def num_floating_point_operations(args, batch_size):

Transformer 的主公式在:

training.py:648-700

大概结构是:

total_flops =
batch_size
× seq_length
× (
    MLP FLOPs
  + self-attention FLOPs
  + logits FLOPs
  + 其他项
)

其中 attention 的公式在:

training.py:531-555

核心是:

query_projection_size = kv_channels * num_attention_heads
key_projection_size = kv_channels * num_query_groups
value_projection_size = kv_channels * num_query_groups
standard_self_attn_term =
    3 * 2 * (
        hidden_size * (query_projection_size + key_projection_size + value_projection_size)
        + query_projection_size * seq_length / 2 * 2
        + query_projection_size * hidden_size
    )

解释一下:

3 = forward + backward wgrad + backward dgrad
2 = FMA,一个 multiply-add 算 2 FLOPs

里面三块分别是:

hidden_size * (Q + K + V)
= QKV projection FLOPs

query_projection_size * seq_length / 2 * 2
= attention score QK^T + attention value PV
= full causal attention 的 O(seq_len^2) 项

query_projection_size * hidden_size
= output projection FLOPs

外面还会再乘:

batch_size × seq_length

所以 full attention 的核心 FLOPs 近似是:

batch × seq_len × seq_len × hidden

也就是大家常说的:

O(B × S^2 × H)

#2. small SWA window 127 底层怎么算

你现在看到的 high throughput 配置是:

--window-size 127,0
--window-attn-skip-freq 6

#2.1 真实计算上发生了什么

full causal attention 里,每个 token 可以 attend 到前面所有 token。

如果:

seq_len = 16384

full causal attention 的 token-pair 数大约是:

S^2 / 2
= 16384^2 / 2
≈ 134M pairs

small SWA 的 window 是 127,意思是大部分 local attention 层里,每个 token 只看附近约 127 个 token。

那么 attention pair 数大概变成:

S × W
= 16384 × 127
≈ 2.08M pairs

比例是:

(S × W) / (S^2 / 2)
= 2W / S
= 254 / 16384
≈ 1.55%

所以在 window attention 层里,attention score/value 这部分真实计算量从 full attention 的 100% 变成约 1.55%。

这就是 small SWA 为什么快。

#2.2 skip freq = 6 是什么意思

源码里判断哪些层是 window attention 的逻辑在:

megatron/core/transformer/utils.py:453-465

FLOPs 统计里也有类似逻辑:

training.py:458-474

逻辑大概是:

if layer_idx % window_attn_skip_freq != 0:
    use window attention
else:
    use global/full attention

所以:

window_attn_skip_freq = 6

意思是:

第 6, 12, 18, 24, 30 ... 层是 global attention
其他层是 window attention

对于 7B 32 层:

global layers: 6, 12, 18, 24, 30 = 5 层
window layers: 27 层

对于 3B 16 层:

global layers: 6, 12 = 2 层
window layers: 14 层

也就是说 7B 里大约:

27 / 32 = 84.4%

的层都在用 tiny local attention。

#2.3 但是 reported FLOPs 里怎么算?

这里有一个非常关键的发现:

Megatron 的 reported FLOPs 公式没有把 window_size=127 代入 core attention FLOPs。

我查了 Liangguang 这份源码:

training.py

里面 window_size 只出现在:

training.py:458-474

用于统计哪些层是 window attention。

但是核心 attention FLOPs 仍然写死用了:

query_projection_size * args.seq_length / 2 * 2

位置:

training.py:546-550

也就是它仍然按 full causal attention 的 seq_length 算,而不是按 window_size=127 算。

所以 small SWA 的效果是:

真实 step time 变短了
但 reported FLOPs 分子没有等比例减少
所以 reported TFLOP/s/GPU 和 MFU 会变高

这就是为什么 small SWA 的 MFU 要小心解释。

它不是单纯“GPU 利用率变高”,而是:

真实 attention 工作变少
step time 下降
reported FLOPs 仍接近 full-attention 口径
=> reported throughput 上升

#3. GQA 底层怎么算

你的 7B 配置是:

--group-query-attention
--num-query-groups 8
--num-attention-heads 32
--kv-channels 128

#3.1 普通 MHA

普通 MHA 里:

Q heads = 32
K heads = 32
V heads = 32

每个 head dim 是:

kv_channels = 128

所以:

Q projection size = 32 × 128 = 4096
K projection size = 32 × 128 = 4096
V projection size = 32 × 128 = 4096

QKV projection 总输出大小:

4096 + 4096 + 4096 = 12288

#3.2 GQA

GQA 以后:

Q heads = 32
K/V groups = 8

所以:

Q projection size = 32 × 128 = 4096
K projection size = 8 × 128 = 1024
V projection size = 8 × 128 = 1024

QKV projection 总输出大小:

4096 + 1024 + 1024 = 6144

也就是 QKV projection 这块大约少了一半。

源码里 FLOPs 公式正是这么算的:

query_projection_size = args.kv_channels * args.num_attention_heads
key_projection_size = args.kv_channels * args.num_query_groups
value_projection_size = args.kv_channels * args.num_query_groups

位置:

training.py:531-534

#3.3 GQA 对 core attention 有什么影响?

注意:GQA 减少的是 K/V head 数,但 Q heads 还是 32 个。

attention score 本质是:

每个 query head 去和对应 K group 做 attention

所以 core attention 的 query 侧规模仍然和 Q heads 相关。

Megatron FLOPs 公式里 core attention 项用的是:

query_projection_size * seq_length / 2 * 2

位置:

training.py:546-550

也就是说 reported FLOPs 里:

GQA 会减少 K/V projection FLOPs
但不会减少 core attention 的 S^2 项

这也符合直觉:GQA 主要减少 K/V projection、KV activation、KV cache、KV 通信/带宽;不一定大幅减少 QK/PV 的 query-side 计算。

#3.4 为什么实验里 GQA 没让 reported throughput 上升?

你的实验里:

3B:
3b_06_tp2_mb4_gb64_sp: 608.850
3b_07_gqa: 601.767

7B:
7b_06_tp4_mb4_gb256_sp_no_gqa: 521.150
7b_07_gqa: 496.433

原因可能有两层:

第一,reported FLOPs 分子变小了。

因为 GQA 降低 K/V projection FLOPs,所以:

reported TFLOP/s = estimated FLOPs / time

即使真实 step time 稍微变快,分子也变小,reported TFLOP/s 不一定上升。

第二,GQA kernel 路径不一定比 MHA kernel 更高效。

TE 虽然支持 GQA:

megatron/core/extensions/transformer_engine.py:1401-1409

但不同 shape 下 kernel occupancy、memory layout、通信模式都可能变差。

所以 GQA 更像是:

降低 KV projection / KV memory 的结构优化

不一定是:

提高 reported MFU 的性能优化

#4. FP8 delayed hybrid 底层怎么算

配置是:

--transformer-impl transformer_engine
--fp8-format hybrid
--fp8-recipe delayed
--fp8-amax-compute-algo max
--fp8-amax-history-len 1024

#4.1 FP8 改了什么?

FP8 不改变矩阵形状。

比如原来 MLP GEMM 是:

[B×S, H] × [H, FFN]

FP8 以后 shape 还是这个 shape。

变化是:

输入/权重用 FP8 表示
H100 用 FP8 Tensor Core 计算
中间 accumulation / output 用更高精度处理

所以数学 FLOPs 数量基本不变。

但是硬件执行速度更快、显存带宽压力更小。

#4.2 hybrid 是啥?

hybrid 一般是 E4M3 + E5M2 混用:

E4M3: 精度更高,动态范围小
E5M2: 动态范围更大,精度稍低

常见逻辑是:

forward 更偏 E4M3
backward 更偏 E5M2

用来平衡精度和动态范围。

#4.3 delayed 是啥?

FP8 需要 scale。

因为 FP8 表示范围很小,需要把 tensor 缩放到合适范围:

fp8_value = original_value × scale

Transformer Engine 会记录 amax:

amax = tensor 绝对值最大值

delayed recipe 的意思是:

当前 step 使用历史 amax 算出来的 scale
当前 step 只收集新的 amax
新的 amax 以后再用于更新 scale

源码相关位置:

megatron/core/fp8_utils.py:545-571
megatron/core/fp8_utils.py:628-630
megatron/core/transformer/transformer_block.py:759-770

#4.4 FP8 对 reported throughput 怎么影响?

Megatron reported FLOPs 公式不看 dtype。

也就是说:

BF16 GEMM 和 FP8 GEMM
在 reported FLOPs 分子里基本一样

但是 FP8 让 step time 下降,所以:

reported TFLOP/s/GPU 上升

你的实验里很明显:

3B:
TE+flash BF16: 494.533
FP8: 660.500

7B:
TE+flash BF16: 433.650
FP8: 528.250

这和 small SWA 不一样:

FP8 = 同样形状的矩阵乘,用更快的数据类型/硬件路径做
small SWA = attention 实际工作量变少,但 reported FLOPs 公式没有完全扣掉

所以 FP8 是更“正统”的 runtime/kernel 加速。

不过注意:

我们 MFU 用的是:

989 TFLOP/s/GPU

这是 H100 BF16 口径。

如果按 H100 FP8 峰值算,分母会更大,MFU 百分比会低很多。


#5. Context Parallel 底层怎么算

你的 16k 复现配置是:

--seq-length 16384
--tensor-model-parallel-size 2
--context-parallel-size 2
--cp-comm-type a2a

总共 8 张卡:

world_size = 8
TP = 2
CP = 2
PP = 1
DP = world_size / (TP × CP × PP)
   = 8 / (2 × 2 × 1)
   = 2

也就是:

2-way data parallel
2-way tensor parallel
2-way context parallel

#5.1 CP 做了什么?

Context Parallel 是把 sequence 维度切开。

如果:

seq_len = 16384
CP = 2

那么每个 CP rank 本地大概只持有:

16384 / 2 = 8192 tokens

这能显著降低每张卡上的 activation / attention memory。

这就是为什么 16k 可以 fit 进 8×H100。

#5.2 但是 attention 需要跨 context 通信

虽然每个 rank 只持有一段 sequence,但 attention 需要看到别的 sequence chunk 的 K/V。

所以 CP 需要通信。

你这里是:

--cp-comm-type a2a

也就是用 all-to-all 风格通信交换 attention 需要的信息。

源码里 CP comm type 进入 TE attention 的位置:

megatron/core/extensions/transformer_engine.py:1451-1463

#5.3 CP 对 reported FLOPs 怎么影响?

Megatron FLOPs 分子仍然按全局算:

seq_len = 16384
global batch size = 256

它不会因为 CP=2 就把 FLOPs 除以 2。

公式里 denominator 会除以:

world_size = 8

也就是所有 GPU。

所以 CP 的作用不是改变 reported FLOPs 分子,而是:

让每张 GPU 的显存压力下降
让 16k 能跑起来
改变 step time:多了通信,但本地 sequence 变短

在你的 16k run 里:

iter 3: step time ≈ 41.6s, throughput 656.6
iter 4: step time ≈ 41.8s, throughput 652.9
iter 5: step time ≈ 41.5s, throughput 657.8

每步 token 数是:

global_batch × seq_len
= 256 × 16384
= 4,194,304 tokens

所以真实 token throughput 大概是:

4,194,304 / 41.5 ≈ 101k tokens/sec

reported TFLOP/s/GPU 则是:

Megatron estimated FLOPs / 41.5s / 8 GPUs
≈ 656 TFLOP/s/GPU

#6. fused kernel 底层怎么算

fused kernel 的核心思想是:

数学结果一样,但把多个小操作合成一个 kernel,减少读写显存和 kernel launch 开销。

#6.1 cross entropy loss fusion

配置:

--cross-entropy-loss-fusion

源码路径:

megatron/core/models/common/language_module/language_module.py:172-199
megatron/core/fusions/fused_cross_entropy.py:136-148

普通 cross entropy 大概会做:

logits
max
subtract max
exp
sum
log
gather target token
loss
backward gradient

如果不 fuse,中间可能会有很多 tensor 写回显存再读出来。

fused cross entropy 会把这些合在一起:

少写中间结果
少读中间结果
少启动 kernel

数学 FLOPs 没有本质变化,但 step time 可以下降。

所以:

reported FLOPs 分子基本不变
step time 下降
reported throughput 上升

这就是为什么你的实验里 CE fusion 有时能带来一点提升。

#6.2 flash attention / fused attention

FlashAttention 也是 fused kernel 思路。

普通 attention 可能逻辑上是:

QK^T
mask
softmax
dropout
PV

naive 实现会 materialize 巨大的 attention matrix。

FlashAttention 把它 tile 化、融合化:

不完整落盘整个 attention matrix
在 SRAM / shared memory 里分块计算
减少 HBM 读写

数学上还是 attention,但内存复杂度低很多。

再加上 small SWA 时,TE/FA 还能只算 window 内的块,所以真实计算和访存进一步减少。

#6.3 其他 fused kernel

类似还有:

fused RMSNorm / LayerNorm
bias/dropout/add fusion
gradient accumulation fusion
fused optimizer pieces

这些通常不是改变模型数学,而是:

减少 kernel launch
减少 HBM traffic
提高 Tensor Core / SM occupancy
减少 Python/CUDA 调度开销

reported FLOPs 分子通常不变,只是 step time 变短。


#7. 把这些东西放进一张表

选项真实计算变化Megatron reported FLOPs 分子变化对 step time 的影响怎么理解
small SWA window 127attention 从近似 S^2/2 变成 S×127,大部分层 attention 工作大幅减少这份源码里 core attention 仍按 full seq_len 算,基本没扣 window明显下降reported MFU 会被抬高,是 dense-equivalent 口径
GQAK/V projection 和 KV memory 减少,Q heads 不变K/V projection FLOPs 下降,core attention 基本不变不一定下降,取决于 kernel/shape不一定提高 reported TFLOP/s
FP8 delayed hybrid同样 GEMM shape,换 FP8 Tensor Core 路径基本不变,因为公式不看 dtype明显下降这是比较纯粹的 runtime/kernel 加速
Context Parallelsequence 分片,本地 seq 变短,但需要跨 rank 通信仍按全局 seq_len/global batch 算显存下降,通信上升,能跑长上下文主要是让 16k fit,并改变通信/计算平衡
fused kernels数学不变,多个操作融合通常不变下降一点到明显下降减少 HBM traffic 和 kernel launch

#8. 最重要的结论

你现在看到的 high throughput 来源可以拆成两类。

#A. 真正让同等数学计算更快的

FP8
fused kernels
FA3 / TE 高效 kernel
gradient accumulation fusion
cross entropy fusion

这些主要是:

同样/类似 FLOPs
更短 step time

所以 reported throughput 上升比较合理。

#B. 改变实际计算量,但 reported FLOPs 没完全跟着变的

small SWA window 127

它真实减少了 attention 工作量,但 Liangguang 这份 Megatron 的 reported FLOPs 公式仍然基本按 full seq_len attention 算。

所以 small SWA 的 reported MFU 更像:

dense-equivalent MFU

不是严格的真实硬件利用率。


#9. 用你 16k run 举个具体数字

你的 16k 复现里:

global_batch = 256
seq_len = 16384
每步 token = 256 × 16384 = 4,194,304 tokens
稳定 step time ≈ 41.5s

真实 token 吞吐:

4,194,304 / 41.5 ≈ 101,000 tokens/sec

reported throughput:

约 656 TFLOP/s/GPU

这个 656 是这么来的:

Megatron 估算该 step 的 FLOPs
/ 41.5 秒
/ 8 张 GPU

但因为 small SWA 的 window=127 没有被完整扣进 reported FLOPs 分子,所以这个 656 TFLOP/s/GPU 不是“GPU 真的执行了 full attention 那么多 FP ops”,而是:

按 Megatron full-ish FLOPs 估算口径折算出来的 dense-equivalent throughput

所以以后比较时建议同时看三列:

1. step time
2. tokens/sec
3. reported TFLOP/s/GPU / MFU

如果模型结构没变,看 reported MFU 很有意义。

如果像 small SWA 这样改了 attention 工作量,就一定要同时看 step time 和 tokens/sec。