让我们用 JAX 重建 NanoGPT!(第一部分)

原文:towardsdatascience.com/lets-reproduce-nanogpt-with-jax-part-1-95bec4630eb4?source=collection_archive---------2-----------------------#2024-07-21

第一部分:使用 JAX 构建 124M GPT2。

第二部分:在单 GPU 中优化训练速度。

第三部分:在 JAX 中进行多 GPU 训练。

https://lou1swang.medium.com/?source=post_page---byline--95bec4630eb4--------------------------------https://towardsdatascience.com/?source=post_page---byline--95bec4630eb4-------------------------------- Louis Wang

·发表于 Towards Data Science ·阅读时间:8 分钟·2024 年 7 月 21 日

受到 Andrej Karpathy 最近的 YouTube 视频让我们重建 GPT-2(124M)的启发,我想用 JAX 重建它,并进行大多数训练优化。JAX 专为高效计算速度而构建,非常有趣的是,可以将 Pytorch 与其最近的训练优化以及 JAX 与其相关库(如 Flax:JAX 的神经网络训练层 API 和 Optax:JAX 的梯度处理和优化库)进行对比。我们将迅速了解 JAX,并用 JAX 重建 GPT。最后,我们将比较 Pytorch 和 JAX 在多 GPU 训练中的 token/sec。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/327bdd2b1dfc0479960467df61f2d5da.png

AI 生成的 GPT

什么是 Jax?

根据其readthedoc,JAX 是一个面向加速器的数组计算和程序转换的 Python 库,旨在实现高性能的数值计算和大规模机器学习。我想用它的名字来介绍 JAX。虽然有人称它为 Just Another XLA(加速线性代数),我更愿意称其为 J(it) A(utograd) X(LA),以展示它的高效能力。

J — Just-in-time (JIT) 编译。当你运行 Python 函数时,Jax 将其转换为一组基本操作,称为 Jaxpr。然后,Jaxpr 表达式会被转换为 XLA 的输入,XLA 将其编译成底层脚本,从而为目标设备(CPU、GPU 或 TPU)生成优化后的可执行文件。

A — Autograd。计算梯度是现代机器学习方法中的一个关键部分,你只需要调用jax.grad()来获取梯度,从而优化模型。

X — XLA。这是一个开源的机器学习编译器,支持 CPU、GPU 和 ML 加速器。通常,XLA 会对StableHLO图进行几个内建的优化和分析传递,然后将 HLO 计算发送到后端进行进一步的 HLO 级别优化。后端再进行特定目标的代码生成。

这些只是 JAX 的一些关键特性,但它还有许多类似于 numpy 的用户友好 API,如jax.numpy,以及通过jax.vmap进行的自动向量化,和通过jax.pmap将代码并行化到多个设备上。我们将在以后的博客中介绍更多 Jax 的概念和应用,但现在让我们用 Jax 复现 NanoGPT!

从注意力机制到变换器(Transformer)

GPT 是一种仅解码的变换器模型,关键构建模块是注意力模块。我们可以首先定义一个模型配置数据类来保存模型的超参数,这样模型模块就能高效地使用它来初始化模型架构。类似于 124M GPT 模型,在这里我们初始化一个 12 层的变换器解码器,具有 12 个头和 50257 个词汇表大小,每个词汇表项有 768 维嵌入向量。注意力计算的块大小为 1024。

fromdataclassesimportdataclass@dataclassclassModelConfig:vocab_size:int=50257n_head:int=12n_embd:int=768block_size:int=1024n_layer:int=12dropout_rate:float=0.1

接下来是变换器模型的关键构建模块——注意力机制(Attention)。其思想是将输入处理成三个权重矩阵:Key、Query 和 Value。在这里,我们依赖于flax,这是一个 Jax 层和训练 API 库,用来初始化这三个权重矩阵,只需要调用[flax.linen.Dense](https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.Dense.html)。如前所述,Jax 有许多类似 numpy 的 API,因此我们使用[jax.numpy.reshape](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.reshape.html)将权重矩阵后的输出从[batch_size, sequence_length, embedding_dim]重塑为[batch_size, sequence_length, num_head, embedding_dim / num_head]。由于我们需要对 Key 和 Value 矩阵执行矩阵乘法,jax 还提供了[jax.numpy.matmul](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.matmul.html)[jax.numpy.transpose](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html)API(用于转置 Key 矩阵以进行乘法运算)。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8f2916bfb42338ef17e1526a677e4f85.png

多头注意力(Multihead Attention)

请注意,我们需要在注意力矩阵上加上一个掩码,以避免信息泄漏(防止之前的 tokens 访问到后面的 tokens),[jax.numpy.tril](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tril.html)帮助构建一个下三角数组,而[jax.numpy.where](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html)可以为我们填充无限大的数值,以便在 softmax[jax.nn.softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html)后得到 0。多头注意力的完整代码如下所示。

fromflaximportlinenasnnimportjax.numpyasjnpclassCausalSelfAttention(nn.Module):config:ModelConfig@nn.compactdef__call__(self,x,deterministic=True):assertlen(x.shape)==3b,l,d=x.shape q=nn.Dense(self.config.n_embd)(x)k=nn.Dense(self.config.n_embd)(x)v=nn.Dense(self.config.n_embd)(x)# q*k / sqrt(dim) -> softmax -> @vq=jnp.reshape(q,(b,l,d//self.config.n_head,self.config.n_head))k=jnp.reshape(k,(b,l,d//self.config.n_head,self.config.n_head))v=jnp.reshape(v,(b,l,d//self.config.n_head,self.config.n_head))norm=jnp.sqrt(list(jnp.shape(k))[-1])attn=jnp.matmul(q,jnp.transpose(k,(0,1,3,2)))/norm mask=jnp.tril(attn)attn=jnp.where(mask[:,:,:l,:l],attn,float("-inf"))probs=jax.nn.softmax(attn,axis=-1)y=jnp.matmul(probs,v)y=jnp.reshape(y,(b,l,d))y=nn.Dense(self.config.n_embd)(y)returny

你可能会注意到,在 Pytorch 中常见的__init__forward方法在这里并不存在。这是 jax 的特点,在 jax 中你可以显式地通过setup方法定义层,或者通过在__call__方法上添加nn.compact来隐式定义它们。[参考]

接下来让我们构建 MLP 和 Block 层,包括 Dense 层、Gelu 激活函数、LayerNorm 和 Dropout。再次,flax.linen 提供了层的 API,帮助我们构建模块。请注意,我们会传递一个deterministic布尔变量来控制某些层(如 Dropout)在训练或评估期间的不同行为。

classMLP(nn.Module):config:ModelConfig@nn.compactdef__call__(self,x,deterministic=True):x=nn.Dense(self.config.n_embd*4)(x)x=nn.gelu(x,approximate=True)x=nn.Dropout(rate=self.config.dropout_rate)(x,deterministic=deterministic)x=nn.Dense(self.config.n_embd)(x)x=nn.Dropout(rate=self.config.dropout_rate)(x,deterministic=deterministic)returnxclassBlock(nn.Module):config:ModelConfig@nn.compactdef__call__(self,x):x=nn.LayerNorm()(x)x=x+CausalSelfAttention(self.config)(x)x=nn.LayerNorm()(x)x=x+MLP(self.config)(x)returnx

现在让我们使用上述模块来构建 NanoGPT:

给定一个序列的 token ids 输入,我们使用[flax.linen.Embed](https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.Embed.html)层来获取位置嵌入和 token 嵌入。然后,我们将它们传入 Block 模块 N 次,其中 N 是模型配置中定义的层数。最后,我们将来自最后一个 Block 的输出映射到每个词汇表 token 的概率,以预测下一个 token。除了前向__call__方法之外,我们还需要创建一个init方法来获取虚拟输入并获得模型的参数。

classGPT(nn.Module):config:ModelConfig@nn.compactdef__call__(self,x,deterministic=False):B,T=x.shapeassertT<=self.config.block_size pos=jnp.arange(0,T)[None]pos_emb=nn.Embed(self.config.block_size,self.config.n_embd)(pos)wte=nn.Embed(self.config.vocab_size,self.config.n_embd)tok_emb=wte(x)x=tok_emb+pos_embfor_inrange(self.config.n_layer):x=Block(self.config)(x)x=nn.LayerNorm()(x)logits=nn.Dense(config.n_embd,config.vocab_size)(x)# logits = wte.attend(x) # parameter sharingreturnlogitsdefinit(self,rng):tokens=jnp.zeros((1,self.config.block_size),dtype=jnp.uint16)params=jax.jit(super().init,static_argnums=(2,))(rng,tokens,True)returnparams

现在让我们验证一下参数的数量:我们首先初始化模型配置的数据类和随机密钥,然后创建一个虚拟输入并将其输入到 GPT 模型中。接着,我们利用jax.util.treemapAPI 创建一个计数参数函数。我们得到了124439808(124M)个参数,与 Huggingface 的 GPT2 相同,哇!

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/513bfdf90096dcb03e3cd4a76910d2d7.png

Colab 结果:参数数量

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/576fc03b5cb1912c0ba770b510de7b73.png

验证 Huggingface 的 GPT2 参数数量

数据加载器和训练循环

现在让我们在一个小数据集上进行过拟合。为了与 Andrej 的 Pytorch NanoGPT 视频中进行对比,我们使用他在视频中分享的玩具 dataset。我们使用tiktoken库的 GPT2 分词器对输入文件中的所有文本进行分词,并将这些 token 转换为jax.numpy.array以便 Jax 的模型训练。

classDataLoader:def__init__(self,B,T):self.current_position=0self.B=B self.T=Twithopen("input.txt","r")asf:text=f.read()enc=tiktoken.get_encoding("gpt2")self.tokens=jnp.array(enc.encode(text))print(f"loaded{len(self.tokens)}tokens in the datasets")print(f" 1 epoch ={len(self.tokens)//(B*T)}batches")defnext_batch(self):B,T=self.B,self.T buf=self.tokens[self.current_position:self.current_position+B*T+1]x,y=jnp.reshape(buf[:-1],(B,T)),jnp.reshape(buf[1:],(B,T))self.current_position+=B*Tifself.current_position+B*T+1>len(self.tokens):self.current_position=0returnx,y

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/da59795faf894d9f4bd6f4d1ed782ece.png

Colab 结果:简单的数据加载器,批量大小为 4,序列长度为 128

接下来,让我们暂时忽略分布式训练和优化,先创建一个简单的训练循环进行基本检查。初始化模型后的第一件事是创建一个TrainState,这是一个可以更新参数和梯度的模型状态。TrainState 接受三个重要输入:apply_fn(模型前向函数)、params(来自初始化方法的模型参数)和 tx(一个 Optax 梯度变换)。

然后我们使用 train_step 函数来更新模型状态(梯度和参数),以继续模型训练。Optax提供了用于下一个令牌预测任务的 softmax 交叉熵作为损失函数,jax.value_and_grad用于计算损失函数的梯度和损失值。最后,我们使用apply_gradientsAPI 更新模型的状态和新参数。[ref] 别忘了对 train_step 函数进行 JIT 编译,以减少计算开销!

definit_train_state(key,config)->TrainState:model=GPT(config)params=model.init(key)optimizer=optax.adamw(3e-4,b1=0.9,b2=0.98,eps=1e-9,weight_decay=1e-1)train_state=TrainState.create(apply_fn=model.apply,params=params,tx=optimizer)returntrain_state@jax.jitdeftrain_step(state:TrainState,x:jnp.ndarray,y:jnp.ndarray)->Tuple[jnp.ndarray,TrainState]:defloss_fn(params:FrozenDict)->jnp.ndarray:logits=state.apply_fn(params,x,False)loss=optax.softmax_cross_entropy_with_integer_labels(logits,y).mean()returnloss loss,grads=jax.value_and_grad(loss_fn,has_aux=False)(state.params)new_state=state.apply_gradients(grads=grads)returnloss,new_state

现在一切准备就绪,可以开始进行简单的训练循环了……让我们检查损失值。模型的预测应该优于随机猜测,因此损失值应该低于 -ln(1/50257)≈10.825。我们对单批次过拟合的预期是:一开始损失接近 10.825,然后下降到接近 0。让我们取一批(x,y)并运行训练循环 50 次。我还添加了类似的日志来计算训练速度。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/add7d951685406ff5cfce20cccb41414.png

如我们所见,损失值正是我们预期的,训练吞吐量大约是 400–500 k token/sec。这已经比 Andrej 视频中没有任何优化的 Pytorch 初始版本快了 40 倍。请注意,我们是在 1 个 A100 GPU 上运行 Jax 脚本,这应该消除了硬件差异对速度比较的影响。这里没有.to(device)的操作来将模型或数据从主机 CPU 移动到设备 GPU,这正是 Jax 的一个优势!

就这样,我们做到了。我们将在第二部分通过更多优化将训练速度提升至原来的 10 倍…

第二部分:训练优化之旅,如何在单个 GPU 上达到 1350k tokens/sec!

“除非另有说明,所有图片均为作者所提供”

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1120547.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Dis++查看磁盘SMART状态预防硬件故障

使用Dis查看磁盘SMART状态预防硬件故障 在AI模型训练日益常态化的今天&#xff0c;一个看似不起眼的硬盘故障&#xff0c;可能让数天的训练成果付诸东流。某次深夜&#xff0c;一位研究员正进行Qwen3-VL多模态模型的GRPO强化学习训练&#xff0c;任务已持续72小时。突然&#x…

软考高项公认的高含金量、高实用性、高性价比证书

软考高项&#xff0c;即信息系统项目管理师&#xff0c;属于计算机技术与软件&#xff08;高级&#xff09;专业技术资格。简称为“高级项目经理、管理师”&#xff0c;相当于高级职称。可以以考代评&#xff0c;积分落户或办理居住证&#xff0c;企业信息系统集成资质申请&…

让我们重新审视包括新玩家 Pandas 在内的不同库中的 Case-When:

原文&#xff1a;towardsdatascience.com/lets-revisit-case-when-in-different-libraries-including-the-new-player-pandas-8c4febb979ba 无论您是在进行数据分析、数据清洗&#xff0c;甚至特征工程&#xff0c;创建基于其他列值的新列都是一个经常进行的操作。 我用于数据…

BlindWaterMark盲水印终极指南:5分钟学会图像版权保护

BlindWaterMark盲水印终极指南&#xff1a;5分钟学会图像版权保护 【免费下载链接】BlindWaterMark 盲水印 by python 项目地址: https://gitcode.com/gh_mirrors/bli/BlindWaterMark 在数字时代&#xff0c;图像版权保护变得前所未有的重要。BlindWaterMark作为一款基于…

HunyuanVideo-Foley:革命性AI音效生成技术重塑视频创作生态

HunyuanVideo-Foley&#xff1a;革命性AI音效生成技术重塑视频创作生态 【免费下载链接】HunyuanVideo-Foley 项目地址: https://ai.gitcode.com/tencent_hunyuan/HunyuanVideo-Foley 在视频内容创作成为主流的今天&#xff0c;AI音效生成技术正在彻底改变传统音效制作…

vivado安装包组件选择策略:入门级完整示例参考

Vivado安装组件怎么选&#xff1f;新手避坑指南&#xff1a;从零构建轻量高效FPGA开发环境你是不是也经历过这样的场景——花两三个小时下载Vivado安装包&#xff0c;勾选“全部安装”&#xff0c;结果磁盘直接爆满、系统卡顿、启动缓慢……最后发现&#xff0c;90%的功能根本用…

2026专科生必备!8个降AI率工具测评榜单

2026专科生必备&#xff01;8个降AI率工具测评榜单 为什么专科生需要一份靠谱的降AI率工具榜单&#xff1f; 随着人工智能技术在学术领域的广泛应用&#xff0c;论文、报告甚至作业的AI检测标准也在不断提升。对于专科生而言&#xff0c;如何在保证内容质量的同时降低AI率&…

使用 Python 多线程提升你的编码技能

原文&#xff1a;towardsdatascience.com/level-up-your-coding-skills-with-python-threading-8f1bd06b9476 https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9cbfec975450d8357e227d828448ea09.png 由Sonika Agarwal在Unsplash上的照片 …

ESP32连接阿里云MQTT:网络协议栈配置实战案例

ESP32连接阿里云MQTT实战&#xff1a;从协议栈配置到稳定上线的完整路径 你有没有遇到过这样的场景&#xff1f; ESP32明明连上了Wi-Fi&#xff0c;IP也拿到了&#xff0c;可就是连不上阿里云&#xff1b;日志里反复打印“TLS handshake failed”或“Connection timeout”&am…

[特殊字符]_网络IO性能优化:从TCP到HTTP的层层优化[20260106161818]

作为一名专注于网络性能优化的工程师&#xff0c;我在过去的项目中积累了丰富的网络IO优化经验。最近&#xff0c;我参与了一个对网络性能要求极高的项目——实时视频流平台。这个项目让我重新审视了Web框架在网络IO方面的表现。今天我要分享的是基于真实项目经验的网络IO性能优…

利用 KeyBERT、HDBSCAN 和 Zephyr-7B-Beta 构建知识图谱

原文&#xff1a;towardsdatascience.com/leverage-keybert-hdbscan-and-zephyr-7b-beta-to-build-a-knowledge-graph-33d7534ee01b?sourcecollection_archive---------0-----------------------#2024-01-07 增强型大语言模型自然语言处理与传统机器学习技术结合&#xff0c;用…

SAPlink终极指南:5个技巧掌握ABAP对象高效管理

SAPlink终极指南&#xff1a;5个技巧掌握ABAP对象高效管理 【免费下载链接】SAPlink SAPlink 项目地址: https://gitcode.com/gh_mirrors/sa/SAPlink SAPlink是一款专为SAP Netweaver系统设计的ABAP对象导入导出工具&#xff0c;通过独特的Nugget文件格式实现了代码的便…

ms-swift支持训练任务超时自动终止释放资源

ms-swift支持训练任务超时自动终止释放资源 在大模型时代&#xff0c;一个看似微不足道的“卡住”任务&#xff0c;可能意味着数小时GPU算力的浪费、数千元云成本的流失&#xff0c;甚至影响整个团队的迭代节奏。你有没有经历过这样的场景&#xff1a;提交了一个LoRA微调任务&…

得意黑 Smiley Sans 字体安装与应用全攻略:从下载到专业设计的完美指南

得意黑 Smiley Sans 字体安装与应用全攻略&#xff1a;从下载到专业设计的完美指南 【免费下载链接】smiley-sans 得意黑 Smiley Sans&#xff1a;一款在人文观感和几何特征中寻找平衡的中文黑体 项目地址: https://gitcode.com/gh_mirrors/smi/smiley-sans 还在为字体安…

STNodeEditor实战指南:构建高效可视化编程工作流

STNodeEditor实战指南&#xff1a;构建高效可视化编程工作流 【免费下载链接】STNodeEditor 一款基于.Net WinForm的节点编辑器 纯GDI绘制 使用方式非常简洁 提供了丰富的属性以及事件 可以非常方便的完成节点之间数据的交互及通知 大量的虚函数供开发者重写具有很高的自由性 …

盲水印终极使用指南:保护图像版权的完整解决方案

盲水印终极使用指南&#xff1a;保护图像版权的完整解决方案 【免费下载链接】BlindWaterMark 盲水印 by python 项目地址: https://gitcode.com/gh_mirrors/bli/BlindWaterMark 盲水印技术是现代数字版权保护的重要工具&#xff0c;它能在不改变图像视觉质量的前提下&a…

常见网络安全威胁和防御措施

网络安全威胁是一种技术风险&#xff0c;会削弱企业网络的防御能力&#xff0c;危及专有数据、关键应用程序和整个 IT 基础设施。由于企业面临广泛的威胁&#xff0c;因此他们应该仔细监控和缓解最关键的威胁和漏洞。网络安全问题有七大类&#xff0c;它们都包括多种威胁&#…

ncmdumpGUI终极指南:网易云音乐NCM格式转换完整解决方案

ncmdumpGUI终极指南&#xff1a;网易云音乐NCM格式转换完整解决方案 【免费下载链接】ncmdumpGUI C#版本网易云音乐ncm文件格式转换&#xff0c;Windows图形界面版本 项目地址: https://gitcode.com/gh_mirrors/nc/ncmdumpGUI 在音乐数字化时代&#xff0c;网易云音乐的…

终极SAP开发利器:SAPlink高效代码迁移完全指南

终极SAP开发利器&#xff1a;SAPlink高效代码迁移完全指南 【免费下载链接】SAPlink SAPlink 项目地址: https://gitcode.com/gh_mirrors/sa/SAPlink 在传统的SAP Netweaver开发环境中&#xff0c;ABAP程序员常常面临一个痛点&#xff1a;如何在不同系统间安全、高效地迁…

视频字幕制作效率革命:AI智能助手如何10倍提升创作生产力

视频字幕制作效率革命&#xff1a;AI智能助手如何10倍提升创作生产力 【免费下载链接】VideoCaptioner &#x1f3ac; 卡卡字幕助手 | VideoCaptioner - 基于 LLM 的智能字幕助手&#xff0c;无需GPU一键高质量字幕视频合成&#xff01;视频字幕生成、断句、校正、字幕翻译全流…