#论文信息

  • 标题:Stacking Your Transformers: A Closer Look at Model Growth for Efficient LLM Pre-Training
  • 链接:<https://arxiv.org/abs/2405.15319>
  • 会议/版本:NeurIPS 2024 / arXiv v2

#一句话结论

这篇论文的核心发现很直接:在 LLM 预训练里,与其设计很复杂的“长大”策略,不如先训练一个更小的模型,再沿着深度方向把它整块堆起来(stack),然后继续训练。 这种方法叫 Gstack。在他们的实验里,它比从头训练更快地达到相同 loss,而且这种优势在更大模型和更长训练 token 下仍然成立。

如果只记一句话,我会记这个:“先训小,再按层堆深,再续训” 是一种足够简单、但 surprisingly 有效的预训练加速路线。

#这篇论文在解决什么问题?

训练大语言模型太贵了。

通常我们会直接从随机初始化开始,按目标模型规模一路训练到底。但一个自然的问题是:

能不能先训练一个小模型,再把它“长大”成大模型,从而减少大模型从零开始学基础能力的成本?

这个方向叫 model growth。问题在于,以前这类工作虽然不少,但真正落到 decoder-only LLM 预训练里,还存在三类障碍:

  1. 缺少统一比较:不同 growth 方法各自做各自的实验,很难公平比较。
  2. 可扩展性不清楚:小规模实验有效,不代表到了 3B、7B 甚至更长 token 训练还有效。
  3. 缺少实用指南:即使方法有效,工程上也不知道该什么时候 grow、grow 多大。

这篇论文就是围绕这三个问题展开的。

#论文做了什么?

作者先把已有增长思路抽象成四类“原子操作”:

  • Gdirect:直接复制/堆叠已有参数
  • Glearn:通过可学习映射生成新参数
  • Gzero:新参数初始化为 0
  • Grandom:新参数随机初始化

然后又区分两种增长方向:

  • widthwise:在同一层里扩宽
  • depthwise:沿网络深度方向加层

他们用统一的 two-stage protocol 去比较:

  1. 先训练一个小模型,训练 token 数记为 d
  2. 用某个 growth operator 把它长成目标模型,增长倍数记为 g
  3. 再继续训练大模型,训练 token 数记为 D

最后比较 loss、benchmark 平均分、以及达到相同效果所需的计算量/训练步数。

#最关键发现:最有效的是 depthwise stacking,也就是 Gstack

实验结论很清楚:

  • 横向扩宽(widthwise)基本没什么优势
  • 纵向加深(depthwise)明显更有效
  • 在 depthwise 方案里,最简单的直接堆叠 Gdirect^up,也就是 Gstack,是最强的

这件事其实很有意思。很多人直觉上会觉得:更复杂的方法、更多可学习映射、更“函数保持”的设计,应该更好。但这篇论文给出的答案反而是:

简单直接地把一个已经学会基本表示的小模型按层复制、堆深,然后继续训,效果最好。

作者在 1.1B LLM 的对比里发现,Gstack 在速度、loss 和绝大多数下游评测上都优于 baseline。论文里举的数字是:相较于从头训练到 100B token,Gstack 带来了 49.1% 的 speedup

#Gstack 到底是什么?

可以把它理解成:

  • 先训练一个较浅的小 Transformer,记作 M
  • 然后把这个小模型在深度上重复堆叠 g 次,形成更深的大模型
  • 再继续预训练

作者给出的形式可以写成:

target model = M ◦ M ◦ ... ◦ M(共 g 次)

也就是说,不是给大模型完全随机初始化,也不是只复制某几层,而是把整个小模型作为一个模块整体堆起来

这个设计背后的直觉是:

  • 小模型已经学到一套可用的 token 表示与层间变换
  • 直接按深度复制,能保留这套“已有的计算路径”和层间连接模式
  • 大模型不是从“不会”开始学,而是从“先有一个能工作的粗糙版本”开始继续细化

论文后面的 ablation 也支持这个理解:越保留原模型的完整连接结构,stacking 效果越好。

#为什么它可能有效?我的理解

这部分是论文启发下的理解,不是论文原文逐句照搬。

我觉得 Gstack 有效,主要因为它抓住了 LLM 预训练里一个很现实的事实:

#1. 预训练前期主要是在学“基本语言电路”

小模型先训练一段时间后,已经学会了很多基础模式,比如:

  • token / subword 的统计结构
  • 短程依赖
  • 一些中低层语法、搭配和局部语义模式

如果大模型从零开始,这些东西要重新学一遍;而 Gstack 相当于把这些初级结构直接搬进更深的网络里。

#2. 深度方向复制,比宽度方向扩展更接近“能力复用”

宽度扩展常常意味着新增大量通道、头、MLP 维度,这些新单元一开始未必和旧表示空间对齐;但深度堆叠更像是在已有表示变换链条上“继续串联”。

这可能更容易让优化器接着往下走,而不是重新协调一大堆新特征维度。

#3. 它不强调严格 function preserving,但依然有效

很多 growth 文献会强调“增长后模型初始函数尽量不变”。这篇论文反而指出,Gstack 并不严格满足这种 function preserving 要求,但实证效果最好。

这很值得注意:

  • 对 LLM 预训练来说,最重要的未必是“增长瞬间输出完全不变”
  • 更关键的可能是:增长后的参数结构是否给了优化一个更好的起点

换句话说,优化友好性 可能比教科书式的函数保持更重要。

#可扩展性:不是小打小闹,3B / 7B 也有效

论文最有说服力的部分之一,是它不仅在 1.1B 做对比,还继续往上做了 3B 和 7B。

#在 3B 模型上

作者做法是:

  • 先训练一个层数为目标模型 1/4 的小模型
  • 先训 10B tokens(也就是 d = 10B
  • 再用 g = 4 的方式堆成目标模型
  • 然后继续训练到 300B tokens

结果是:

  • 在 180B、240B tokens 这些点上,Gstack 相比 scratch 分别有 48.6%54.5% 的 speedup
  • benchmark 平均分也更高

#在 7B 模型上

论文摘要里给了一个最醒目的结果:

  • 对比一个常规训练的 7B 模型在 300B tokens 的表现
  • Gstack 版本只用 194B tokens 就达到相同 loss
  • 相当于 54.6% speedup

这个量级已经不只是“略快一点”,而是相当可观的训练加速。

#更重要的一点:训得很久以后它也没掉队

很多“高效训练”方法的问题是:

  • 前期 loss 掉得快
  • 但训练足够久以后,优势消失,甚至被 scratch 反超

作者专门检查了这一点。

他们在 410M 模型上训练到 750B tokens,这是远超 Chinchilla 推荐 token 数的“过训练”场景。结果发现:

  • 400B tokens 时,Gstack 仍有 53.1% acceleration
  • 700B tokens 时,仍有 31.0% acceleration

论文甚至拟合认为,优势在更长训练区间里仍可能持续。

这点很重要,因为现实世界的很多 LLM 训练已经不完全遵守经典 Chinchilla 最优点,往往会继续 overtrain。也就是说,这篇论文不是只在“理想小实验”里成立,而是在更接近实际工业训练的设定里也成立。

#这篇论文最实用的部分:给了 growth timing 和 growth factor 的经验法则

如果只知道 “stacking 好用”,但不知道什么时候 grow、grow 几倍,工程上还是很难落地。

所以论文在第 4.2 节专门回答两个问题:

  1. growth timing d:小模型应该先训多久再长大?
  2. growth factor g:应该从多小的模型长到多大的模型?

#1. Growth timing:不是越早越好,也不是越晚越好

作者在 410M、1.1B、3B 三种目标模型上画了 IsoFLOP 曲线,比较不同 d 值:

  • 0B
  • 1B
  • 5B
  • 10B
  • 20B
  • 50B

他们发现曲线出现明显 valley,说明:

对于固定计算预算,存在一个最优的 growth timing。

这很好理解:

  • grow 太早:小模型自己还没学出稳定结构,堆起来也没什么可继承
  • grow 太晚:你在小模型上花了太多计算,反而浪费了本可用于训练大模型的预算

一个有意思的细节是:

  • d = 0B(相当于直接把随机初始化小模型拿来 stack)效果和从头训练差不多
  • d = 1B 就已经有明显加速

这说明 真正起作用的不是“stack”这个动作本身,而是“先让小模型学到点东西再 stack”。

作者进一步拟合了一个经验式:

log10(d) = a log10(N) + b / log10(C) + c

其中:

  • N 是目标模型参数量
  • C 是计算预算
  • d 是最优 growth timing

拟合参数约为:

  • a = 0.88
  • b = 163.27
  • c = -5.74

别太把这个式子当成普适定律,但它至少给了一个很重要的工程认知:

grow timing 是可以被系统调参和经验建模的,不是拍脑袋。

#2. Growth factor:最佳通常在 2 到 4 之间

作者也测试了不同 growth factor。

结果显示:

  • 即使比较激进的 grow 也可能有收益
  • 但如果 base model 太浅(例如只剩 1 层)就会明显变差
  • 综合 IsoFLOP 曲线来看,最优 g 往往落在 2 到 4 之间

由于算力限制,他们没有像 timing 那样给出特别可靠的通用公式,所以最后给出的实用建议是:

默认先用 g = 4

这也是他们大量实验采用的主设置。

#这篇论文还有哪些值得注意的点?

#1. 它强调“简单 baseline + 统一比较”

这是我很喜欢的一点。论文没有一上来就推一个花哨新机制,而是先把已有思路做成统一实验框架,再证明:最简单的深度堆叠就是最强 baseline。

这种工作很朴素,但很有价值,因为它给后续研究建立了一个更靠谱的比较基线。

#2. 它暗示“连接结构保留”可能很关键

论文后面的 stacking variants 比较显示:

  • 整体堆完整小模型最好
  • 插值式 stacking、partial stacking 通常更差

这背后很可能说明:不是“把层数变多”本身重要,而是尽量保留已经学好的层间协同结构。

#3. 它对 function preserving 提出了一个实证上的反例

很多网络增长理论偏好 function-preserving 初始化,但这篇工作说明在 LLM 预训练里,非严格 FP 的方案也可能更优。

我觉得这是个很值得继续深挖的点:大模型预训练也许更需要的是“优化轨迹继承”,而不只是“函数值继承”。

#这篇论文的局限性

论文自己也讲了几类限制,我补充整理一下:

#1. growth factor 的经验式还不够扎实

他们自己承认,关于 g 的实验点不够多,所以虽然观察到最佳区间大致在 2 到 4,但还谈不上一个特别稳的 scaling law。

#2. 主要研究的是简单 operator

这其实既是优点也是限制。

优点是结论清楚、可复现、容易落地;

限制是:

  • 多阶段 grow
  • 动态调学习率
  • 更复杂的 mapping 初始化
  • 与 MoE / state space / hybrid 架构结合

这些都还没被系统探索。

#3. 评价重点仍是 pretraining loss + 常见 NLP benchmark

虽然这已经很有说服力,但如果放到今天大家更关心的能力维度,还会想继续问:

  • 对 instruction tuning 后的能力迁移如何?
  • 对推理、代码、多轮对话、工具使用影响如何?
  • 对对齐阶段稳定性有没有副作用?

论文附录里做了一些扩展,但主线仍是“预训练效率”。

#4. 还没有直接回答最工业级的问题

比如:

  • 在超大规模集群上如何安排 grow 的 checkpoint / pipeline 切换?
  • grow 后 optimizer state 怎么迁移最合理?
  • 对现代训练 recipe(数据配比、LR schedule、weight decay、rope/attention 变体)是否同样稳健?

这些工程细节,决定了它离真正大规模生产还有多远。

#我对这篇论文的总体评价

我的评价是:这是一篇很扎实、很有工程价值的论文。

它未必在理论上给出特别深的新解释,但它做成了三件非常重要的事:

  1. 把 model growth 在 LLM 预训练里重新系统化了
  2. 证明最简单的 Gstack 是个很强、而且可扩展的 baseline
  3. 第一次把“什么时候 grow、grow 多大”这两个工程问题讲得比较可操作

如果你是做 LLM 预训练系统的人,这篇论文最值得带走的不是某个 fancy 公式,而是一个很实用的范式:

先用较小模型学出基础结构,再通过 depthwise stacking 扩成目标模型继续训练,这可能是比全程 scratch 更省算力的默认路线之一。

#对研究者/工程实践者的启发

#如果你是研究者

可以继续追这些方向:

  • 为什么 depthwise stacking 比 widthwise growth 更稳?
  • 什么样的“连接保留率”最影响后续优化?
  • stacking 为什么能在不严格 function preserving 的情况下更好?
  • 它与 scaling law、grokking、representation reuse 有什么关系?

#如果你是工程实践者

论文给出的相对稳妥起点大概是:

  • 先别上复杂增长法,优先试 Gstack
  • growth factor 先从 4 开始
  • 不要 0 token 就 grow,先让小模型学一段
  • 把 grow timing 当成一个真正要调的超参数,而不是拍脑袋决定

#最后总结

这篇论文最有价值的地方在于,它把“模型先小后大”这件事从一个听起来合理的想法,推进成了一个有统一基线、有规模验证、有经验法则的 LLM 预训练方案。

如果以后大家把 “直接从头训练目标模型” 当作唯一默认范式,这篇论文是在认真地说:未必。先训小,再 stack,可能更划算。

对现在越来越昂贵的 LLM 训练来说,这类工作我认为会越来越重要。