DeepSpeed简介及加速模型训练

DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。

官网链接:deepspeed
训练代码下载:git代码

一、DeepSpeed的核心作用

  1. 显存优化与高效内存管理

    • ZeRO(Zero Redundancy Optimizer)技术:通过分片存储模型状态(参数、梯度、优化器状态)至不同GPU或CPU,显著减少单卡显存占用。例如,ZeRO-2可将显存占用降低8倍,支持单卡训练130亿参数模型。
      在这里插入图片描述

    • Offload技术:将优化器状态卸载到CPU或NVMe硬盘,扩展至TB级内存,支持万亿参数模型训练。

    • 激活值重计算(Activation Checkpointing):牺牲计算时间换取显存节省,适用于长序列输入。

  2. 灵活的并行策略

    • 3D并行:融合数据并行(DP)、模型并行(张量并行TP、流水线并行PP),支持跨节点与节点内并行组合,适应不同硬件架构。

    • 动态批处理与梯度累积:减少通信频率,支持超大Batch Size训练。

  3. 训练加速与混合精度支持

    • 混合精度训练:支持FP16/BF16,结合动态损失缩放平衡效率与数值稳定性。

    • 稀疏注意力机制:针对长序列任务优化,执行效率提升6倍。

    • 通信优化:支持MPI、NCCL等协议,降低分布式训练通信开销。

  4. 推理优化与模型压缩

    • 低精度推理:通过INT8/FP16量化减少模型体积,提升推理速度。

    • 模型剪枝与蒸馏:压缩模型参数,降低部署成本。


二、与pytorch 对比分析

1. 优势

  • 显存效率:相比PyTorch DDP,单卡80GB GPU可训练130亿参数模型(传统方法仅支持约10亿)。

  • 并行灵活性:支持3D并行组合,优于Horovod(侧重数据并行)和Megatron(侧重模型并行)。

  • 生态集成:与Hugging Face Transformers、PyTorch无缝兼容,简化现有项目迁移。

  • 全流程覆盖:同时优化训练与推理,而vLLM仅专注推理优化。

2. 局限性

  • 配置复杂度:分布式训练需手动调整通信策略和分片参数,学习曲线陡峭(需编写JSON配置文件)。

  • 硬件依赖:部分高级功能(如ZeRO-Infinity)依赖NVMe硬盘或特定GPU架构。

  • 推理效率:纯推理场景下,vLLM的吞吐量更高(连续批处理优化更专精)。


三、训练用例

1、ds_config.json(deepspeed执行训练时,使用的配置文件)
  • deepspeed训练模型时,不需要在代码中定义优化器,只需要在 json 文件中进行配置即可, json文件内容如下:
{"train_batch_size": 128, //所有GPU上的 单个训练批次大小 之和"gradient_accumulation_steps": 1, //梯度累积 步数"optimizer": {"type": "Adam", //选择的 优化器"params": {"lr": 0.00015 //相关学习率大小}},"zero_optimization": { //加速策略"stage":2}
}

2、训练函数

  • 将模型包装成 deepspeed 形式
#将模型 包装成 deepspeed 形式
model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())
  • 使用 deepspeed 包装后的模型 进行 反向传播和梯度更新
#使用 deepspeed 进行 反向传播和梯度更新
#反向传播
model_engine.backward(loss)#梯度更新
model_engine.step()
  • 完整训练代码如下:
'''
使用命令行进行启动启动命令如下:
deepspeed ds_train.py --epochs 10 --deepspeed --deepspeed_config ds_config.json
'''import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModelif __name__ == '__main__':#读取命令行 传递的参数parser = argparse.ArgumentParser()parser.add_argument("--local_rank", help = "local device id on current node", type = int, default=0)parser.add_argument("--epochs", type = int, default=1)parser = deepspeed.add_config_arguments(parser)args = parser.parse_args()#获取数据集train_loader, test_loader = load_data() #数据集加载器中的 batch_size的大小 = (ds_config.json中 train_batch_size/gpu数量)#获取原始模型model = CustomModel().cuda()#将模型 包装成 deepspeed 形式model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())loss_fn = torch.nn.CrossEntropyLoss().cuda() # 损失函数(分类任务常用)for i in range(args.epochs):for inputs, labels in train_loader:#前向传播inputs = inputs.cuda()labels = labels.cuda()outputs = model_engine(inputs)loss = loss_fn(outputs, labels)#使用 deepspeed 进行 反向传播和梯度更新#反向传播model_engine.backward(loss)#梯度更新model_engine.step()model_engine.save_checkpoint('./ds_models', i)#模型保存torch.save(model_engine.module.state_dict(),'deepspeed_train_model.pth')
3、模型评估

import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 1. 定义数据转换(预处理)
transform = transforms.Compose([transforms.ToTensor(),          # 转为Tensor格式(自动归一化到0-1)transforms.Normalize((0.1307,), (0.3081,))  # 标准化(MNIST的均值和标准差)
])test_data = datasets.MNIST(root='./data',train=False,          # 测试集transform=transform)#获取数据集
train_loader, test_loader = load_data()model = CustomModel()
model.load_state_dict(torch.load('deepspeed_train_model.pth'))#评估
model.eval()  # 设置为评估模式
correct = 0
total = 0with torch.no_grad():  # 不计算梯度(节省内存)for images, labels in test_loader:images, labels = images, labelsoutputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 取概率最大的类别total += labels.size(0)correct += (predicted == labels).sum().item()print(f"测试集准确率: {100 * correct / total:.2f}%")# 随机选择一张测试图片
index = np.random.randint(0,1000)  # 可以修改这个数字试不同图片
test_image, true_label = test_data[index]
test_image = test_image.unsqueeze(0)  # 增加批次维度# 预测
with torch.no_grad():output = model(test_image)
predicted_label = torch.argmax(output).item()print(f"预测: {predicted_label}, 真实: {true_label}")# 显示结果
plt.imshow(test_image.cpu().squeeze(), cmap='gray')
plt.title(f"预测: {predicted_label}, 真实: {true_label}")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/83903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

集星獭 | 重塑集成体验:新版编排重构仿真电商订单数据入库

概要介绍 新版服务编排以可视化模式驱动电商订单入库流程升级,实现订单、客户、库存、发票、发货等环节的自动化处理。流程中通过循环节点、判断逻辑与数据查询的编排,完成了低代码构建业务逻辑,极大提升订单处理效率与业务响应速度。 背景…

AMO——下层RL与上层模仿相结合的自适应运动优化:让人形行走操作(loco-manipulation)兼顾可行性和动力学约束

前言 自从去年24年Q4,我司「七月在线」侧重具身智能的场景落地与定制开发之后 去年Q4,每个月都会进来新的具身需求今年Q1,则每周都会进来新的具身需求Q2的本月起,一周不止一个需求 特别是本周,几乎每天都有国企、央企…

MATLAB中进行语音信号分析

在MATLAB中进行语音信号分析是一个涉及多个步骤的过程,包括时域和频域分析、加窗、降噪滤波、端点检测以及特征提取等。 1. 加载和预览语音信号 首先,你需要加载一个语音信号文件。MATLAB支持多种音频文件格式,如.wav。 [y, fs] audiorea…

JWT令牌验证

一、JWT 验证方式详解 JWT(JSON Web Token)的验证核心是确保令牌未被篡改且符合业务规则,主要分为以下步骤: 1. 令牌解析与基础校验 收到客户端传递的 JWT 后,首先按 . 分割为三部分:Header、Payload、S…

一文讲清python、anaconda的安装以及pycharm创建工程

软件下载 Pycharm下载地址: Other Versions - PyCharm anaconda下载地址: https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Windows-x86_64.exe 安装步骤 一、 Python 解释器的安装步骤 安装目录介绍: 二、 Anaconda 安装 2.1 安装步…

Mac如何允许安装任何来源软件?

打开系统偏好设置-安全性与隐私,点击右下角的解锁按钮,选择允许从任何来源。 如果没有这一选项,请到打开终端,输入命令行:sudo spctl --master-disable, 输入命令后回车,输入电脑的开机密码后回车。 返回“…

React Flow 中 Minimap 与 Controls 组件使用指南:交互式小地图与视口控制定制(含代码示例)

本文为《React Agent:从零开始构建 AI 智能体》专栏系列文章。 专栏地址:https://blog.csdn.net/suiyingy/category_12933485.html。项目地址:https://gitee.com/fgai/react-agent(含完整代码示​例与实战源)。完整介绍…

Windows Ubuntu 目录映射关系

情况一:你是通过 WSL (Windows Subsystem for Linux) 安装 Ubuntu 这是最常见的情况。如果你在 Microsoft Store 安装了 “Ubuntu”,默认就是 WSL。 📁 目录映射关系如下: 从 Ubuntu(WSL)访问 Windows&…

双指针法高效解决「移除元素」问题

双指针法高效解决「移除元素」问题 双指针法高效解决「移除元素」问题一、问题描述二、解法解析:双指针法1. 核心思想2. 算法步骤3. 执行过程示例 三、关键点分析四、复杂度分析五、与其他解法的比较1. 快慢指针法2. 本解法的优势 六、实际应用场景七、总结 双指针法…

知识图谱构架

目录 知识图谱构架 一、StanfordNLP 和 spaCy 工具介绍 (一)StanfordNLP 主要功能 使用示例 (二)spaCy 主要功能 使用示例 二、CRF 和 BERT 的基本原理和入门 (一)CRF(条件随机场&…

激光三角测量标定与应用

文章目录 1,介绍。2,技术原理3,类型。3.1,直射式3.2,斜射式3.3,两种三角位移传感器特性的比较 4,什么是光片?5,主要的算子。1,create_sheet_of_light_model2&…

高可用消息队列实战:AWS SQS 在分布式系统中的核心解决方案

引言:消息队列的“不可替代性” 在微服务架构和分布式系统盛行的今天,消息队列(Message Queue) 已成为解决系统解耦、流量削峰、异步处理等难题的核心组件。然而,传统的自建消息队列(如RabbitMQ、Kafka&am…

人工智能核心知识:AI Agent 的四种关键设计模式

人工智能核心知识:AI Agent 的四种关键设计模式 一、引言 在人工智能领域,AI Agent(人工智能代理)是实现智能行为和决策的核心实体。它能够感知环境、做出决策并采取行动以完成特定任务。为了设计高效、灵活且适应性强的 AI Age…

平替BioLegend品牌-Elabscience PE Anti-Mouse Foxp3抗体:流式细胞术中的高效工具,助力免疫细胞分析!”

概述 调节性T细胞(Treg)在维持免疫耐受和抑制过度免疫反应中发挥关键作用,其标志性转录因子Foxp3(Forkhead box P3)是Treg功能研究的重要靶点。Elabscience 推出的抗小鼠Foxp3抗体(3G3-E)&…

编程日志5.13

邻接表的基础代码 #include<iostream> using namespace std; //邻接表的类声明 class Graph {private: //结构体EdgeNode表示图中的边结点,包含顶点vertex、权重weight和指向下一个边结点的指针next struct EdgeNode { int vertex; int weight; …

PowerBI 矩阵实现动态行内容(如前后销售数据)统计数据,以及过滤同时为0的数据

我们有一张活动表 和 一张销售表 我们想实现如下的效果&#xff0c;当选择某个活动时&#xff0c;显示活动前后3天的销售对比图&#xff0c;如下&#xff1a; 实现方法&#xff1a; 1.新建一个表&#xff0c;用于显示列&#xff1a; 2.新建一个度量值&#xff0c;用SELECTEDVA…

Prompt Tuning:高效微调大模型的新利器

Prompt Tuning(提示调优)是什么 Prompt Tuning(提示调优) 是大模型参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)的重要技术之一,其核心思想是通过优化 连续的提示向量(而非整个模型参数)来适配特定任务。以下是关于 Prompt Tuning 的详细解析: 一、核心概念…

杰发科技AC7840——如何把结构体数据写到Dflash中

1. 结构体数据被存放在Pflash中 正常情况下&#xff0c;可以看到全局变量的结构体数据被存放在Pflash中 数字部分存在RAM中 2. 最小编程单位 8字节编程&#xff0c;因此如果结构体存放在Dfalsh中&#xff0c;进行写操作&#xff0c;需要写8字节的倍数 第一种办法&#xff1a;…

CSS 选择器入门

一、CSS 选择器基础&#xff1a;快速掌握核心概念 什么是选择器&#xff1f; CSS 选择器就像 “网页元素的遥控器”&#xff0c;用于定位 HTML 中的特定元素并应用样式。 /* 结构&#xff1a;选择器 { 属性: 值; } */ p { color: red; } /* 选择所有<p>元素&#xff0c;…

Anaconda3安装教程(附加安装包)Anaconda详细安装教程Anaconda3 最新版安装教程

多环境隔离 可同时维护生产环境、开发环境、测试环境&#xff0c;例如&#xff1a; conda create -n ml python3.10 # 创建机器学习环境 conda activate ml # 激活环境三、Anaconda3 安装教程 解压Anaconda3安装包 找到下载的 Anaconda3 安装包&#xff08;.ex…