#论文信息
- 标题:Stacking Your Transformers: A Closer Look at Model Growth for Efficient LLM Pre-Training
- 链接:<https://arxiv.org/abs/2405.15319>
- 会议/版本:NeurIPS 2024 / arXiv v2
#一句话结论
这篇论文的核心发现很直接:在 LLM 预训练里,与其设计很复杂的“长大”策略,不如先训练一个更小的模型,再沿着深度方向把它整块堆起来(stack),然后继续训练。 这种方法叫 Gstack。在他们的实验里,它比从头训练更快地达到相同 loss,而且这种优势在更大模型和更长训练 token 下仍然成立。
如果只记一句话,我会记这个:“先训小,再按层堆深,再续训” 是一种足够简单、但 surprisingly 有效的预训练加速路线。
#这篇论文在解决什么问题?
训练大语言模型太贵了。
通常我们会直接从随机初始化开始,按目标模型规模一路训练到底。但一个自然的问题是:
能不能先训练一个小模型,再把它“长大”成大模型,从而减少大模型从零开始学基础能力的成本?
这个方向叫 model growth。问题在于,以前这类工作虽然不少,但真正落到 decoder-only LLM 预训练里,还存在三类障碍:
- 缺少统一比较:不同 growth 方法各自做各自的实验,很难公平比较。
- 可扩展性不清楚:小规模实验有效,不代表到了 3B、7B 甚至更长 token 训练还有效。
- 缺少实用指南:即使方法有效,工程上也不知道该什么时候 grow、grow 多大。
这篇论文就是围绕这三个问题展开的。
#论文做了什么?
作者先把已有增长思路抽象成四类“原子操作”:
- Gdirect:直接复制/堆叠已有参数
- Glearn:通过可学习映射生成新参数
- Gzero:新参数初始化为 0
- Grandom:新参数随机初始化
然后又区分两种增长方向:
- widthwise:在同一层里扩宽
- depthwise:沿网络深度方向加层
他们用统一的 two-stage protocol 去比较:
- 先训练一个小模型,训练 token 数记为 d
- 用某个 growth operator 把它长成目标模型,增长倍数记为 g
- 再继续训练大模型,训练 token 数记为 D
最后比较 loss、benchmark 平均分、以及达到相同效果所需的计算量/训练步数。
#最关键发现:最有效的是 depthwise stacking,也就是 Gstack
实验结论很清楚:
- 横向扩宽(widthwise)基本没什么优势
- 纵向加深(depthwise)明显更有效
- 在 depthwise 方案里,最简单的直接堆叠 Gdirect^up,也就是 Gstack,是最强的
这件事其实很有意思。很多人直觉上会觉得:更复杂的方法、更多可学习映射、更“函数保持”的设计,应该更好。但这篇论文给出的答案反而是:
简单直接地把一个已经学会基本表示的小模型按层复制、堆深,然后继续训,效果最好。
作者在 1.1B LLM 的对比里发现,Gstack 在速度、loss 和绝大多数下游评测上都优于 baseline。论文里举的数字是:相较于从头训练到 100B token,Gstack 带来了 49.1% 的 speedup。
#Gstack 到底是什么?
可以把它理解成:
- 先训练一个较浅的小 Transformer,记作
M - 然后把这个小模型在深度上重复堆叠
g次,形成更深的大模型 - 再继续预训练
作者给出的形式可以写成:
target model = M ◦ M ◦ ... ◦ M(共 g 次)
也就是说,不是给大模型完全随机初始化,也不是只复制某几层,而是把整个小模型作为一个模块整体堆起来。
这个设计背后的直觉是:
- 小模型已经学到一套可用的 token 表示与层间变换
- 直接按深度复制,能保留这套“已有的计算路径”和层间连接模式
- 大模型不是从“不会”开始学,而是从“先有一个能工作的粗糙版本”开始继续细化
论文后面的 ablation 也支持这个理解:越保留原模型的完整连接结构,stacking 效果越好。
#为什么它可能有效?我的理解
这部分是论文启发下的理解,不是论文原文逐句照搬。
我觉得 Gstack 有效,主要因为它抓住了 LLM 预训练里一个很现实的事实:
#1. 预训练前期主要是在学“基本语言电路”
小模型先训练一段时间后,已经学会了很多基础模式,比如:
- token / subword 的统计结构
- 短程依赖
- 一些中低层语法、搭配和局部语义模式
如果大模型从零开始,这些东西要重新学一遍;而 Gstack 相当于把这些初级结构直接搬进更深的网络里。
#2. 深度方向复制,比宽度方向扩展更接近“能力复用”
宽度扩展常常意味着新增大量通道、头、MLP 维度,这些新单元一开始未必和旧表示空间对齐;但深度堆叠更像是在已有表示变换链条上“继续串联”。
这可能更容易让优化器接着往下走,而不是重新协调一大堆新特征维度。
#3. 它不强调严格 function preserving,但依然有效
很多 growth 文献会强调“增长后模型初始函数尽量不变”。这篇论文反而指出,Gstack 并不严格满足这种 function preserving 要求,但实证效果最好。
这很值得注意:
- 对 LLM 预训练来说,最重要的未必是“增长瞬间输出完全不变”
- 更关键的可能是:增长后的参数结构是否给了优化一个更好的起点
换句话说,优化友好性 可能比教科书式的函数保持更重要。
#可扩展性:不是小打小闹,3B / 7B 也有效
论文最有说服力的部分之一,是它不仅在 1.1B 做对比,还继续往上做了 3B 和 7B。
#在 3B 模型上
作者做法是:
- 先训练一个层数为目标模型 1/4 的小模型
- 先训 10B tokens(也就是
d = 10B) - 再用 g = 4 的方式堆成目标模型
- 然后继续训练到 300B tokens
结果是:
- 在 180B、240B tokens 这些点上,Gstack 相比 scratch 分别有 48.6%、54.5% 的 speedup
- benchmark 平均分也更高
#在 7B 模型上
论文摘要里给了一个最醒目的结果:
- 对比一个常规训练的 7B 模型在 300B tokens 的表现
- Gstack 版本只用 194B tokens 就达到相同 loss
- 相当于 54.6% speedup
这个量级已经不只是“略快一点”,而是相当可观的训练加速。
#更重要的一点:训得很久以后它也没掉队
很多“高效训练”方法的问题是:
- 前期 loss 掉得快
- 但训练足够久以后,优势消失,甚至被 scratch 反超
作者专门检查了这一点。
他们在 410M 模型上训练到 750B tokens,这是远超 Chinchilla 推荐 token 数的“过训练”场景。结果发现:
- 到 400B tokens 时,Gstack 仍有 53.1% acceleration
- 到 700B tokens 时,仍有 31.0% acceleration
论文甚至拟合认为,优势在更长训练区间里仍可能持续。
这点很重要,因为现实世界的很多 LLM 训练已经不完全遵守经典 Chinchilla 最优点,往往会继续 overtrain。也就是说,这篇论文不是只在“理想小实验”里成立,而是在更接近实际工业训练的设定里也成立。
#这篇论文最实用的部分:给了 growth timing 和 growth factor 的经验法则
如果只知道 “stacking 好用”,但不知道什么时候 grow、grow 几倍,工程上还是很难落地。
所以论文在第 4.2 节专门回答两个问题:
- growth timing d:小模型应该先训多久再长大?
- growth factor g:应该从多小的模型长到多大的模型?
#1. Growth timing:不是越早越好,也不是越晚越好
作者在 410M、1.1B、3B 三种目标模型上画了 IsoFLOP 曲线,比较不同 d 值:
- 0B
- 1B
- 5B
- 10B
- 20B
- 50B
他们发现曲线出现明显 valley,说明:
对于固定计算预算,存在一个最优的 growth timing。
这很好理解:
- grow 太早:小模型自己还没学出稳定结构,堆起来也没什么可继承
- grow 太晚:你在小模型上花了太多计算,反而浪费了本可用于训练大模型的预算
一个有意思的细节是:
d = 0B(相当于直接把随机初始化小模型拿来 stack)效果和从头训练差不多d = 1B就已经有明显加速
这说明 真正起作用的不是“stack”这个动作本身,而是“先让小模型学到点东西再 stack”。
作者进一步拟合了一个经验式:
log10(d) = a log10(N) + b / log10(C) + c
其中:
N是目标模型参数量C是计算预算d是最优 growth timing
拟合参数约为:
a = 0.88b = 163.27c = -5.74
别太把这个式子当成普适定律,但它至少给了一个很重要的工程认知:
grow timing 是可以被系统调参和经验建模的,不是拍脑袋。
#2. Growth factor:最佳通常在 2 到 4 之间
作者也测试了不同 growth factor。
结果显示:
- 即使比较激进的 grow 也可能有收益
- 但如果 base model 太浅(例如只剩 1 层)就会明显变差
- 综合 IsoFLOP 曲线来看,最优 g 往往落在 2 到 4 之间
由于算力限制,他们没有像 timing 那样给出特别可靠的通用公式,所以最后给出的实用建议是:
默认先用
g = 4。
这也是他们大量实验采用的主设置。
#这篇论文还有哪些值得注意的点?
#1. 它强调“简单 baseline + 统一比较”
这是我很喜欢的一点。论文没有一上来就推一个花哨新机制,而是先把已有思路做成统一实验框架,再证明:最简单的深度堆叠就是最强 baseline。
这种工作很朴素,但很有价值,因为它给后续研究建立了一个更靠谱的比较基线。
#2. 它暗示“连接结构保留”可能很关键
论文后面的 stacking variants 比较显示:
- 整体堆完整小模型最好
- 插值式 stacking、partial stacking 通常更差
这背后很可能说明:不是“把层数变多”本身重要,而是尽量保留已经学好的层间协同结构。
#3. 它对 function preserving 提出了一个实证上的反例
很多网络增长理论偏好 function-preserving 初始化,但这篇工作说明在 LLM 预训练里,非严格 FP 的方案也可能更优。
我觉得这是个很值得继续深挖的点:大模型预训练也许更需要的是“优化轨迹继承”,而不只是“函数值继承”。
#这篇论文的局限性
论文自己也讲了几类限制,我补充整理一下:
#1. growth factor 的经验式还不够扎实
他们自己承认,关于 g 的实验点不够多,所以虽然观察到最佳区间大致在 2 到 4,但还谈不上一个特别稳的 scaling law。
#2. 主要研究的是简单 operator
这其实既是优点也是限制。
优点是结论清楚、可复现、容易落地;
限制是:
- 多阶段 grow
- 动态调学习率
- 更复杂的 mapping 初始化
- 与 MoE / state space / hybrid 架构结合
这些都还没被系统探索。
#3. 评价重点仍是 pretraining loss + 常见 NLP benchmark
虽然这已经很有说服力,但如果放到今天大家更关心的能力维度,还会想继续问:
- 对 instruction tuning 后的能力迁移如何?
- 对推理、代码、多轮对话、工具使用影响如何?
- 对对齐阶段稳定性有没有副作用?
论文附录里做了一些扩展,但主线仍是“预训练效率”。
#4. 还没有直接回答最工业级的问题
比如:
- 在超大规模集群上如何安排 grow 的 checkpoint / pipeline 切换?
- grow 后 optimizer state 怎么迁移最合理?
- 对现代训练 recipe(数据配比、LR schedule、weight decay、rope/attention 变体)是否同样稳健?
这些工程细节,决定了它离真正大规模生产还有多远。
#我对这篇论文的总体评价
我的评价是:这是一篇很扎实、很有工程价值的论文。
它未必在理论上给出特别深的新解释,但它做成了三件非常重要的事:
- 把 model growth 在 LLM 预训练里重新系统化了
- 证明最简单的 Gstack 是个很强、而且可扩展的 baseline
- 第一次把“什么时候 grow、grow 多大”这两个工程问题讲得比较可操作
如果你是做 LLM 预训练系统的人,这篇论文最值得带走的不是某个 fancy 公式,而是一个很实用的范式:
先用较小模型学出基础结构,再通过 depthwise stacking 扩成目标模型继续训练,这可能是比全程 scratch 更省算力的默认路线之一。
#对研究者/工程实践者的启发
#如果你是研究者
可以继续追这些方向:
- 为什么 depthwise stacking 比 widthwise growth 更稳?
- 什么样的“连接保留率”最影响后续优化?
- stacking 为什么能在不严格 function preserving 的情况下更好?
- 它与 scaling law、grokking、representation reuse 有什么关系?
#如果你是工程实践者
论文给出的相对稳妥起点大概是:
- 先别上复杂增长法,优先试 Gstack
- growth factor 先从 4 开始
- 不要 0 token 就 grow,先让小模型学一段
- 把 grow timing 当成一个真正要调的超参数,而不是拍脑袋决定
#最后总结
这篇论文最有价值的地方在于,它把“模型先小后大”这件事从一个听起来合理的想法,推进成了一个有统一基线、有规模验证、有经验法则的 LLM 预训练方案。
如果以后大家把 “直接从头训练目标模型” 当作唯一默认范式,这篇论文是在认真地说:未必。先训小,再 stack,可能更划算。
对现在越来越昂贵的 LLM 训练来说,这类工作我认为会越来越重要。