给定计算预算下的最佳LLM模型尺寸与预训练数据量分配

给定计算预算下的最佳LLM模型尺寸与预训练数据量分配
FesianXu 20250304 at Wechat Search Team

前言

如果给定了计算预算 C C C,如何分配LLM的模型尺寸 N N N和训练的数据量 D D D,才能使得模型的效果 L L L最好呢?笔者在此介绍一篇经典的文章讨论这个问题。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注明出处,谢谢

  • 关键字:最佳计算预算分配
  • 发表信息:NIPS 2022

∇ \nabla 联系方式:

e-mail: FesianXu@gmail.com

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号:机器学习杂货铺3号店


我们知道在大语言模型(Large Language Model, LLM)中,存在所谓的尺度扩展规律(Scaling Laws) [2],如Fig 1所示,即是:

LLM的性能会随着模型的参数量、模型的训练量、模型的训练数据量的增加而增加

fig1_llm_scaling_laws

Fig 1. 大模型中的尺度扩展规律,测试集损失随着模型训练量、训练集数据量、模型参数量的增加而递减(即是模型性能递增)。

我们也知道模型的参数量、模型的训练量和模型的训练数据量都会影响到最终的计算预算(可以用FLOPs计算),因此LLM的性能可以说和计算预算直接挂钩,这也是Fig 1 左图所表示的。我们不禁会有个疑问,给定了模型的计算预算 C C C,我们应该怎么均衡模型参数量 N N N和预训练的Token数量 D D D,才能使得模型的预训练损失 L L L最小化呢?我们期待得到最优的模型参数 N o p t N_{opt} Nopt和最优的预训练Token数量 D o p t D_{opt} Dopt,可以使得预训练损失最小,正如公式(1)所示。

N o p t ( C ) , D o p t ( C ) = arg ⁡ min ⁡ N , D s . t . F L O P s ( N , D ) = C L ( N , D ) (1) N_{opt}(C), D_{opt}(C) = \underset{N, D \ \mathrm{s.t.} \ \mathrm{FLOPs}(N, D) = C}{\arg\min} L(N,D) \tag{1} Nopt(C),Dopt(C)=N,D s.t. FLOPs(N,D)=CargminL(N,D)(1)

作者探索这个规律的方法论也很直接,作者步进遍历了一遍不同的模型尺寸(从70M到16B参数量),也步进遍历了一遍预训练数据Token数量(从5B到400B),最终跑了超过400个组合的数据点,不得不说有算力真的可以为所欲为。从直观上看越大尺寸的模型需要越多训练的Token,当然我们需要研究具体的比例,作者采用了三种不同的方法去找这个比例关系。

固定模型尺寸下的性能分析

这种方法是分别固定住模型尺寸(从70M到10B多个模型尺寸都需要实验),然后观察训练了不同数量的Tokens数量后,在每一个节点时哪一个模型尺寸能够达到最小的训练损失。如Fig 2 左图 所示, 这里有些地方需要解释。首先这里的横坐标是浮点计算量FLOPs,在不同模型尺寸下,相同的FLOPs能训练的Token数量是不同的,因此才会出现Fig 2左图中同一个FLOPs中,大尺寸模型损失比小尺寸模型还大的情况。从Fig 2 左图中,我们能发现在不同的FLOPs下,到达最小损失的模型尺寸是不一样的(不太容易看出来,在左图中是灰色点,它们形成了一个包络线),不同的FLOPs在对应尺寸模型下能够折算成训练过的Token数量,因此能够画出Fig 2 中图和右图,横坐标是FLOPs,纵坐标是达到最小损失(也就是左图的灰色点)时的模型尺寸和过了的Tokens数。换句话说,Fig 2中图和右图就是给定计算预算 C C C下的最佳模型尺寸 N o p t N_{opt} Nopt和训练数据量 D o p t D_{opt} Dopt,我们发现有 N o p t ∝ C a , D o p t ∝ C b N_{opt} \propto C^{a}, D_{opt} \propto C^{b} NoptCa,DoptCb,通过实验可以算出 a = 0.50 , b = 0.50 a = 0.50, b = 0.50 a=0.50,b=0.50

fig2_fix_model_size_vary_tokens_num

Fig 2. 训练曲线包络。左侧展示了我们所有不同的运行情况。我们启动了一系列模型尺寸,从70M到10B,每个模型针对四个不同的余弦循环周期长度。从这些曲线中,我们提取了每 FLOP 最小损失的包络线,我们利用这些点来估计给定计算预算下的最佳模型尺寸(中间)和最佳训练 token 数量(右侧)。绿色显示了基于训练 Gopher(5.76 × 10²³ FLOP)所用 FLOP 数量的最佳模型尺寸和训练 token 数量的预测。

固定计算预算下的性能分析

第一种方法的计算量FLOPs没有固定,在此方法中我们固定计算量 C C C(也就是所谓的IsoFLOP),分析等量计算下的最佳模型参数量 N o p t N_{opt} Nopt。同时,在知道了每个实验固定的计算量,和在此之下的最佳模型参数量后,也就可以反推训练Token数量。实验如Fig 3 左图所示,可以发现在不同的固定计算量下(从 6 × 1 0 18 6 \times 10^{18} 6×1018 3 × 1 0 21 3 \times 10^{21} 3×1021 FLOPs),遍历不同尺寸的模型能够发现在某些尺寸处会存在明显的低谷,这个低谷就是在固定计算预算情况下的最佳模型参数量,由此也能绘制出Fig 3 中图和右图,绘制逻辑如第一种方法所述。不难发现同样有 N o p t ∝ C a , D o p t ∝ C b N_{opt} \propto C^{a}, D_{opt} \propto C^{b} NoptCa,DoptCb这个规律,算出 a = 0.49 , b = 0.51 a=0.49, b=0.51 a=0.49,b=0.51

fig3_isoFLOP_profiles

Fig 3. 等量浮点运算曲线(IsoFLOP Curves):针对不同模型规模,通过调整训练令牌(token)数量,使得最终总浮点运算量(FLOPs)保持恒定,并设置余弦周期长度以匹配目标FLOPs量。研究发现,损失函数会出现一个明显低谷(如左图),这表明在给定FLOPs计算预算下,存在一个最优的待训练模型。基于这些低谷位置,我们推算出更大模型的最优参数规模与令牌数量(中图和右图)。图中绿色部分展示了在Gopher模型计算预算下,最优模型的参数与令牌数量估计值。

对参数化损失函数进行拟合

在第1和2中方法中已经积累了很多最小损失 L L L下的 F L O P s ( N o p t , D o p t ) = C FLOPs(N_{opt}, D_{opt}) = C FLOPs(Nopt,Dopt)=C的数据点了,我们不妨把损失拆解为三大部分如公式(2)所示,其中第一项 E E E为不可约损失,也就是自然文本的熵,是不可继续减少的最基础的损失项。第二项为(不完美的)参数量为 N N N的Transformer模型训练过程中产生的损失(因为参数量 N N N总是有限,也就是不完美的,因此总是在理想损失 E E E的基础上有超额损失),第三项则是(不完美的)训练数据量 D D D下(因为训练数据量 D D D不可能是无限的)的产生的超额损失。
L ^ ( N , D ) ≜ E + A N α + B D β (2) \hat{L}(N, D) \triangleq E + \frac{A}{N^\alpha} + \frac{B}{D^\beta} \tag{2} L^(N,D)E+NαA+DβB(2)

作者采用L-BFGS算法去最小化所谓的Huber loss(因为数据点只有400多个,这个loss作者说对离群点比较稳健)去进行估计 ( A , B , E , α , β ) (A,B,E,\alpha,\beta) (A,B,E,α,β),笔者也没细究,读者有兴趣的可以翻阅 [3] 和 [4]。最终估计出来的参数为:
E = 1.69 , A = 406.4 , B = 410.7 , α = 0.34 , β = 0.28 (3) E=1.69, A=406.4, B=410.7, \alpha=0.34, \beta=0.28 \tag{3} E=1.69,A=406.4,B=410.7,α=0.34,β=0.28(3)
在LLM Scaling Law的论文 [2] 中提出了一个估算: F L O P s ( N , D ) ≈ 6 N D FLOPs(N, D) \approx 6ND FLOPs(N,D)6ND,借此可以将公式(2)进行变形,得到公式(4)

N o p t ( C ) = G ( C 6 ) a , D o p t ( C ) = G − 1 ( C 6 ) b , 其中 G = ( α A β B ) 1 α + β , a = β α + β , b = α α + β (4) \begin{aligned} N_{\mathrm{opt}}(C) &= G \left( \frac{C}{6} \right)^a, \\ % 公式1,\mathrm{opt}正体下标 D_{\mathrm{opt}}(C) &= G^{-1} \left( \frac{C}{6} \right)^b, \\ % 公式2,G的逆 \text{其中}\quad % 用\text添加中文注释,\quad增加间距 G &= \left( \frac{\alpha A}{\beta B} \right)^{\frac{1}{\alpha + \beta}}, \\ % G的定义(注意分数指数) a &= \frac{\beta}{\alpha + \beta}, \\ % a的定义(β在分子) b &= \frac{\alpha}{\alpha + \beta} % b的定义(α在分子) \end{aligned} \tag{4} Nopt(C)Dopt(C)其中Gab=G(6C)a,=G1(6C)b,=(βBαA)α+β1,=α+ββ,=α+βα(4)

作者算得 a = 0.46 , b = 0.54 a=0.46, b=0.54 a=0.46,b=0.54,具体过程请自行参考原文。

给定计算量下的最优设计

Fig 4是将以上三种预测方法绘制成计算量——最佳模型尺寸估计曲线图,其中那贴上了一些之前工作的估计 [2] 和一些模型的对比,如Gopher(280B参数量)、GPT-3(175B参数量)和Megatron-NLG (530B)参数量。从图中能发现:

  1. 方法1和方法2估计出来的曲线基本上贴合,方法3估计出的模型尺寸在计算预算小的时候和前两者基本贴合,但在大计算预算下会偏小些,不过也不会差距特别大。
  2. 主流的大模型,如Gopher、GPT3等在对应的计算预算下,模型尺寸明显偏大,基本上是贴着 [2] 的曲线走的。

为了证明本文提出的估计方法更佳准确,作者在方法1和2中对齐Gopher的计算预算(大概是 5.76 × 1 0 23 5.76\times10^{23} 5.76×1023 FLOPs),找到了最佳模型尺寸,约是70B,作者将这个训练出来的模型称之为Chinchilla,需要将这个模型的性能和Gopher进行公平对比。注意到在方法1和2中,从Fig 2和Fig 3的右图中可以找出给定预算下的最佳训练Token数量,对于Chinchilla来说是1.4-1.5T左右,因此 D o p t / N o p t ≈ 20 D_{opt}/N_{opt} \approx 20 Dopt/Nopt20

fig4_optimal_size_tokens_prediction

Fig 4. 预测结果叠加对比:我们将三种不同方法的预测结果与Kaplan等人 [2] 的推算进行叠加对比。研究发现,所有方法均表明,当前大型模型的参数规模应显著缩小,并因此需要比现有实践更长的训练时间。

作者在相当多语言下游任务的基准上进行了测试,都发现Chinchilla对比Gopher存在普遍优势,在一些任务中甚至超过了Megatron-NLG 530B模型。这些实验过于冗长,笔者就不展示细节了。

笔者读后感

这篇论文的意义在于告诉我们,在给定了计算预算下,是存在一个最优的模型尺寸和训练数据量的,他们存在一个比例( D o p t ≈ 20 N o p t D_{opt} \approx 20 N_{opt} Dopt20Nopt),越大的模型就需要越多数据进行训练,才能发挥出模型最优的性能。这篇论文的发表时间比较早,是2022年,现在已经有很多工作证实了在推理中进行复杂策略可以有效提高模型性能 [5,6],并且这些推理策略同样也存在Scaling Law。这意味着计算预算不仅可以花在预训练上,而且可以花在推理时的Scaling,这也是这篇文章没有考虑到的点。当然,在 [6] 中作者也承认,推理时的Scaling并非是万能的,而是:

推理时计算与预训练计算并非一对一“可互换”。对于模型能力范围内的简单和中等难度问题,或者在推理(实时性)要求较低的情况下,测试时计算可以轻松弥补额外的预训练。然而,对于超出基础模型能力范围的具有挑战性的问题,或者在推理(实时性)要求较高的情况下,预训练可能更有效于提升性能。

也就是说预训练的地位并不是通过推理时的Scaling就可以替代的,预训练中分配一定量的计算预算对于全方面提高LLM的性能是必须的。结合了模型训练、模型推理的更为综合的最优配比,应该是值得去研究的。

Reference

[1]. Hoffmann, Jordan, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas et al. “Training compute-optimal large language models.” arXiv preprint arXiv:2203.15556 (2022).

[2]. Kaplan, Jared, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. “Scaling laws for neural language models.” arXiv preprint arXiv:2001.08361 (2020).

[3]. J. Nocedal. Updating Quasi-Newton Matrices with Limited Storage. Mathematics of Computation, 35(151):773–782, 1980. ISSN 0025-5718. doi: 10.2307/2006193. URL https://www.jstor.org/stable/2006193 aka L-BFGS

[4]. P. J. Huber. Robust Estimation of a Location Parameter. The Annals of Mathematical Statistics, 35 (1):73–101, Mar. 1964. ISSN 0003-4851, 2168-8990. doi: 10.1214/aoms/1177703732. URL
https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full. aka Huber loss

[5]. https://fesianxu.github.io/2025/03/02/test-time-scaling-laws-20250302/, 《大模型推理时的尺度扩展定律》

[6]. Snell, Charlie, Jaehoon Lee, Kelvin Xu, and Aviral Kumar. “Scaling llm test-time compute optimally can be more effective than scaling model parameters.” arXiv preprint arXiv:2408.03314 (2024).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/72525.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

青训营:简易分布式爬虫

一、项目介绍 该项目是一个简易分布式爬虫系统,以分布式思想为基础,通过多节点协作的方式,将大规模的网页抓取任务分解,从而高效、快速地获取网络数据 。 项目地址:https://github.com/yanchengsi/distributed_crawle…

任务9:交换机基础及配置

CSDN 原创主页:不羁https://blog.csdn.net/2303_76492156?typeblog 一、交换机基础 交换机的概念:交换机是一种网络设备,用于连接多台计算机或网络设备,实现数据包在局域网内的快速交换。交换机基于MAC地址来转发数据包&#x…

YOLOv8改进------------SPFF-LSKA

YOLOv8改进------------SPFF-LSKA 1、LSAK.py代码2、添加YAML文件yolov8_SPPF_LSKA.yaml3、添加SPPF_LSKA代码4、ultralytics/nn/modules/__init__.py注册模块5、ultralytics/nn/tasks.py注册模块6、导入yaml文件训练 1、LSAK.py代码 论文 代码 LSKA.py添加到ultralytics/nn/…

[Lc(2)滑动窗口_1] 长度最小的数组 | 无重复字符的最长子串 | 最大连续1的个数 III | 将 x 减到 0 的最小操作数

目录 1. 长度最小的字数组 题解 代码 ⭕2.无重复字符的最长子串 题解 代码 3.最大连续1的个数 III 题解 代码 4.将 x 减到 0 的最小操作数 题解 代码 1. 长度最小的字数组 题目链接:209.长度最小的字数组 题目分析: 给定一个含有 n 个 正整数 的数组…

安卓binder驱动内核日志调试打印开放及原理(第一节)

背景: 经常有学员朋友在做系统开发时候,有时候遇到binder相关的一些问题,这个时候可能就需要比较多的binder相关日志,但是正常情况下这些binder通讯的的内核日志都是没有的打印的,因为经常binder通讯太过于频繁&#…

docker 安装达梦数据库(离线)

docker安装达梦数据库,官网上已经下载不了docker版本的了,下面可通过百度网盘下载 通过网盘分享的文件:dm8_20240715_x86_rh6_rq_single.tar.zip 链接: https://pan.baidu.com/s/1_ejcs_bRLZpICf69mPdK2w?pwdszj9 提取码: szj9 上传到服务…

MWC 2025 | 紫光展锐联合移远通信推出全面支持R16特性的5G模组RG620UA-EU

2025年世界移动通信大会(MWC 2025)期间,紫光展锐联合移远通信,正式发布了全面支持5G R16特性的模组RG620UA-EU,以强大的灵活性和便捷性赋能产业。 展锐芯加持,关键性能优异 RG620UA-EU模组基于紫光展锐V62…

达梦适配记录-检查服务器

service DmServicedmdb status 查看是否开启,没有配置systemctl,查看《DM8_Linux 服务脚本使用手册》2.1.2.2 1 .拷贝服务模板文件( DmService )到目录( /opt/dmdbms/bin ),并将新文…

Pipeline模式详解:提升程序处理效率的设计模式

文章目录 Pipeline模式详解:提升程序处理效率的设计模式引言Pipeline的基本概念Pipeline的工作原理Pipeline的优势Pipeline的应用场景1. 数据处理2. DevOps中的CI/CD3. 机器学习4. 图像处理 常见的Pipeline实现方式1. 函数式编程中的Pipeline2. 基于消息队列的Pipel…

STM32单片机芯片与内部115 DSP-FIR IIR低通 高通 带通 带阻 中值 自适应 滤波器 逐个数据实时 样条插值拟合

目录 一、FIR 低通、高通、带通、带阻 1、FIR滤波器特点 2、滤波器结构 3、滤波器系数 4、滤波实现 5、FIR 滤波后的群延迟 二、IIR 低通、高通、带通、带阻 1、IIR滤波器特点 2、滤波器结构 3、滤波器系数 4、滤波实现 5、IIR滤波后的群延迟 三、中值滤波 1、中值…

C语言_图书管理系统_借阅系统管理

✨✨ 欢迎大家来到小伞的大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:数据结构与算法 小伞的主页:xiaosan_blog 本文所需对顺序表的理解: 注:由于顺序表实现图书…

表达式基础

文章目录 1、表达式组成1、运算符 2、表达式的分类1、算数运算符1、自增运算符和自减运算2、取余运算(%)3、除法运算(/)4、案例 2、关系运算符3、逻辑运算符4、条件运算符(三目运算符)1、案例 5、赋值运算()1、赋值类型转换2、复合赋值运算 6、逗号运算7、取地址运算(&)8、…

除了合并接口,还有哪些优化 Flask API 的方法?

除了合并接口,还有许多其他方法可以优化 Flask API,以下从性能优化、代码结构优化、安全性优化、错误处理优化等方面详细介绍: 性能优化 1. 使用缓存 内存缓存:可以使用 Flask-Caching 扩展来实现内存缓存,减少对数…

Web服务器配置

配置虚拟主机 通过虚拟主机,可以实现用自定义的域名来访问,并且可以为不同的域名指定不同的站点目录。 配置IP地址和域名的映射关系 申请真实的域名需要一定的费用,为了方便开发,可以通过修改hosts文件来实现将任意域名解析到本…

爬虫逆向实战小记——解决webpack实记

注意!!!!某XX网站实例仅作为学习案例,禁止其他个人以及团体做谋利用途!!! aHR0cHM6Ly9wbW9zLnhqLnNnY2MuY29tLmNuOjIwMDgwL3B4Zi1zZXR0bGVtZW50LW91dG5ldHB1Yi8jL3B4Zi1zZXR0bGVtZW5…

蓝桥杯 之 前缀和与查分

文章目录 题目求和棋盘挖矿 前缀和有利于快速求解 区间的和、异或值 、乘积等情况差分是前缀和的反操作 前缀和 一维前缀和: # 原始的数组num,下标从1到n n len(num) pre [0]*(n1) for i in range(n):pre[i1] pre[i] num[i] # 如果需要求解num[l] 到num[r] 的区…

Windows10下本地搭建Manim环境

文章目录 1. 简介2. Python环境3. uv工具4. Latex软件5. 安装Manim数学库6. 中文支持参考 1. 简介 manim是个一科普动画的库, 本文用到的是社区版本。 2. Python环境 这个不用多说,可以参考其他的文章。记得把pip也安上。 3. uv工具 上面的pip是老…

#UVM# 关于field automation机制中的 pack_bytes 和unpack_bytes 函数剖析

一 pack_bytes 函数 在 UVM 中,pack_bytes 函数用于将类中的所有字段打包成一个字节流(byte stream)。这是 UVM 提供的字段自动化(field automation)机制的一部分,用于简化数据打包和传输。 extern function int pack_bytes(ref byte unsigned bytestream[], input uv…

YOLOv8 自定义目标检测

一、引言 YOLOv8 不仅支持预训练模型的推理,还允许用户将其应用于自定义对象检测。本文将详细介绍如何使用 YOLOv8 训练一个新的模型,并在自定义数据集上进行对象检测。 二、数据集准备 1. 数据集格式 YOLOv8 支持多种数据集格式,包括 CO…

关于tresos Studio(EB)的MCAL配置之GPT

概念 GPT,全称General Purpose Timer,就是个通用定时器,取的名字奇怪了点。定时器是一定要的,要么提供给BSW去使用,要么提供给OS去使用。 配置 General GptDeinitApi控制接口Gpt_DeInit是否启用 GptEnableDisable…