PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

在这里插入图片描述

二、完整 DDP 训练 Demo

  • 官方DDP Dem参考

1. 基础训练脚本 (ddp_demo.py)

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScalerdef setup(rank, world_size):"""初始化分布式环境"""os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():"""清理分布式环境"""dist.destroy_process_group()class SimpleModel(nn.Module):"""简单的CNN模型"""def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc = nn.Linear(9216, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = torch.flatten(x, 1)return self.fc(x)def prepare_dataloader(rank, world_size, batch_size=32):"""准备分布式数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)return loaderdef train(rank, world_size, epochs=2):"""训练函数"""setup(rank, world_size)# 设置当前设备torch.cuda.set_device(rank)# 初始化模型、优化器等model = SimpleModel().to(rank)ddp_model = DDP(model, device_ids=[rank])optimizer = optim.Adam(ddp_model.parameters())scaler = GradScaler()  # 混合精度训练criterion = nn.CrossEntropyLoss()train_loader = prepare_dataloader(rank, world_size)for epoch in range(epochs):ddp_model.train()train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shufflefor batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(rank), target.to(rank)optimizer.zero_grad()# 混合精度训练with torch.autocast(device_type='cuda', dtype=torch.float16):output = ddp_model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()if batch_idx % 100 == 0:print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")cleanup()if __name__ == "__main__":# 单机多卡启动时,torchrun会自动设置这些环境变量rank = int(os.environ['LOCAL_RANK'])world_size = int(os.environ['WORLD_SIZE'])train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):# 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    if rank == 0:  # 只在主进程保存torch.save(model.state_dict(), "model.pth")
    
  3. 多机训练

    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py# 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    for i, (data, target) in enumerate(train_loader):if i % 2 == 0:optimizer.zero_grad()# 前向和反向...if i % 2 == 1:optimizer.step()
    
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/73811.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

策略模式_行为型_GOF23

策略模式 策略模式(Strategy Pattern)是一种行为型设计模式,核心思想是将一组算法封装成独立对象,使它们可以相互替换,从而让算法的变化独立于使用它的客户端。这类似于游戏中的技能切换——玩家根据战况选择不同技能…

【Python】天气数据可视化

1. Python进行数据可视化 在数据分析和科学计算领域,Python凭借其强大的库和简洁的语法,成为了众多开发者和科研人员的首选工具。数据可视化作为数据分析的重要环节,能够帮助我们更直观地理解数据背后的规律和趋势。本文将详细介绍如何使用P…

深度学习4.4笔记

《动手学深度学习》-4.4-笔记 验证数据集:通常是从训练集中划分出来的一部分数据,不要和训练数据混在一起,评估模型好坏的数据集 测试数据集:只用一次的数据集 k-折交叉验证(k-Fold Cross-Validation)是…

vue 两种路由模式

一、两种模式比较 在vue.js中,路由模式分为两种:hash 模式和 history 模式。这两种模式决定了URL的结构和浏览器历史记录的管理方式。 1. hash 模式带 #,#后面的地址变化不会引起页面的刷新。换句话说,hash模式不会将#后面的地址…

Android生态大变革,谷歌调整开源政策,核心开发不再公开

“开源”这个词曾经是Android的护城河,如今却成了谷歌的烫手山芋。最近谷歌宣布调整Android的开源政策,核心开发将全面转向私有分支。翻译成人话就是:以后Android的核心更新,不再公开共享了。 这操作不就是开源变节吗,…

JavaScript中集合常用操作方法详解

JavaScript中集合常用操作方法详解 JavaScript中的集合主要包括数组(Array)、集合(Set)和映射(Map)。下面我将详细介绍这些集合类型的常用操作方法。 数组(Array) 数组是JavaScript中最常用的集合类型,提供了丰富的操作方法。 创建数组 // 字面量创建 const ar…

【HC-05】蓝牙串口通信模块调试与应用(1)

一、HC-05 基础学习视频 HC-05蓝牙串口通信模块调试与应用1 二、HC-05学习视频课件

【学Rust写CAD】18 定点数2D仿射变换矩阵结构体(MatrixFixedPoint结构别名)

源码 // matrix/fixed.rs use crate::fixed::Fixed; use super::generic::Matrix;/// 定点数矩阵类型别名 pub type MatrixFixedPoint Matrix<Fixed, Fixed, Fixed, Fixed, Fixed, Fixed>;代码解析 这段代码定义了一个定点数矩阵的类型别名 MatrixFixedPoint&#xff…

axios文件下载使用后端传递的名称

java后端通过HttpServletResponse 返回文件流 在Content-Disposition中插入文件名 一定要设置Access-Control-Expose-Headers&#xff0c;代表跨域该Content-Disposition返回Header可读&#xff0c;如果没有&#xff0c;前端是取不到Content-Disposition的&#xff0c;可以在统…

HarmonyOS之深入解析如何根据url下载pdf文件并且在本地显示和预览

一、文件下载 ① 网络请求配置 下载在线文件&#xff0c;需要访问网络&#xff0c;因此需要在 config.json 中添加网络权限&#xff1a; {"module": {"requestPermissions": [{"name": "ohos.permission.INTERNET","reason&qu…

鸿蒙前后端项目源码-点餐v3.0-原创!原创!原创!

鸿蒙前后端点餐项目源码含文档ArkTS语言. 原创作品.我半个月写的原创作品&#xff0c;请尊重原创。 原创作品&#xff0c;盗版必究&#xff01;&#xff01;&#xff01;&#xff01; 原创作品&#xff0c;盗版必究&#xff01;&#xff01;&#xff01;&#xff01; 原创作…

VUE3+TypeScript项目,使用html2Canvas+jspdf生成PDF并实现--分页--页眉--页尾

使用html2CanvasJsPDF生成pdf&#xff0c;并实现分页添加页眉页尾 1.封装方法htmlToPdfPage.ts /**path: src/utils/htmlToPdfPage.tsname: 导出页面为PDF格式 并添加页眉页尾 **/ /*** 封装思路* 1.将页面根据A4大小分隔边距&#xff0c;避免内容被中间截断* 所有元素层级不要…

5.Excel:从网上获取数据

一 用 Excel 数据选项卡获取数据的方法 连接。 二 要求获取实时数据 每1分钟自动更新数据。 A股市场_同花顺行情中心_同花顺财经网 用上面方法将数据加载进工作表中。 在表格内任意区域右键&#xff0c;刷新。 自动刷新&#xff1a; 三 缺点 Excel 只能爬取网页上表格类型的…

《深度剖析SQL之WHERE子句:数据过滤的艺术》

在当今数据驱动的时代&#xff0c;数据处理和分析能力已成为职场中至关重要的技能。SQL作为一种强大的结构化查询语言&#xff0c;在数据管理和分析领域占据着核心地位。而WHERE子句&#xff0c;作为SQL中用于数据过滤的关键组件&#xff0c;就像是一把精准的手术刀&#xff0c…

华为eNSP-配置静态路由与静态路由备份

一、静态路由介绍 静态路由是指用户或网络管理员手工配置的路由信息。当网络拓扑结构或者链路状态发生改变时&#xff0c;需要网络管理人员手工修改静态路由信息。相比于动态路由协议&#xff0c;静态路由无需频繁地交换各自的路由表&#xff0c;配置简单&#xff0c;比较适合…

Docker 快速入门指南

Docker 快速入门指南 1. Docker 常用指令 Docker 是一个轻量级的容器化平台&#xff0c;可以帮助开发者快速构建、测试和部署应用程序。以下是一些常用的 Docker 命令。 1.1 镜像管理 # 搜索镜像 docker search <image_name># 拉取镜像 docker pull <image_name>…

基础认证-单选题(一)

单选题 1、下列关于request方法和requestlnStream方法说法错误的是(C) A 都支持取消订阅响应事件 B 都支持订阅HTTP响应头事件 C 都支持HttpResponse返回值类型 D 都支持传入URL地址和相关配置项 2、如需修改Text组件文本的透明度可通过以下哪个属性方法进行修改 (C) A dec…

Logback使用和常用配置

Logback 是 Spring Boot 默认集成的日志框架&#xff0c;相比 Log4j&#xff0c;它性能更高、配置更灵活&#xff0c;并且天然支持 Spring Profile 多环境配置。以下是详细配置步骤及常用配置示例。 一、添加依赖&#xff08;非 Spring Boot 项目&#xff09; 若项目未使用 Sp…

MySQL基础语法DDLDML

目录 #1.创建和删除数据库 ​#2.如果有lyt就删除,没有则创建一个新的lyt #3.切换到lyt数据库下 #4.创建数据表并设置列及其属性,name是关键词要用name包围 ​编辑 #5.删除数据表 #5.查看创建的student表 #6.向student表中添加数据,数据要与列名一一对应 #7.查询studen…

在windows下安装windows+Ubuntu16.04双系统(下)

这篇文章的内容主要来源于这篇文章&#xff0c;为正式安装windowsUbuntu16.04双系统部分。在正式安装前&#xff0c;若还没有进行前期准备工作&#xff08;1.分区2.制作启动u盘&#xff09;&#xff0c;见《在windows下安装windowsUbuntu16.04双系统(上)》 二、正式安装Ubuntu …