pytorch使用GPU炼丹笔记

如何使用GPU训练/测试模型

  • 使用单GPU
    • 设置设备
    • 将数据转换成CUDA张量
    • 将模型参数转换成CUDA张量
    • 使用指定GPU
      • 1.使用CUDA_VISIBLE_DEVICES。
          • 1.1 直接在终端或shell脚本中设定:
          • 1.2 python代码中设定:
      • 2. 使用函数 set_device
  • 使用多GPU
    • DP方法
    • DDP方法
      • 需要先初始化
      • 数据集的处理
      • 模型初始化
      • 单节点多GPU分布式训练
  • 实验结果

原理:通过依靠GPU的并行计算能力,能够大大缩短模型训练时间。
在使用GPU跑代码的时候,只需要将模型参数和数据放到GPU上转换成CUDA张量即可。

所以,代码需要修改的地方只有两处:

  • 1.模型实例化处。
  • 2.数据迭代处(如果数据有更新或增加的话也需要在相应的地方改,只要确定在输入模型前的所有数据都转换为CUDA张量即可)。

使用单GPU

设置设备

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

cuda:0表示使用0号显卡进行训练,所以如果需要指定其他GPU,则可以直接修改该数字即可。如cuda:1,表示使用1号显卡进行训练。

将数据转换成CUDA张量

也就是将数据放到GPU上

for i,data in enumerate(trainloader,0):#从迭代器中获取数据输入inputs,labels = datainputs, labels = inputs.to(device), labels.to(device) #数据转换为CUDA张量outputs=net(inputs) #数据输入模型

to(device)的作用是将数据转换为CUDA张量。

以一个labels为例:

  • 加.to(device)之前,labels为tensor([0, 6, 2, 0])
  • 加.to(device)之后,labels为tensor([0, 6, 2, 0], device=‘cuda:0’)

将模型参数转换成CUDA张量

同理,也就是将模型放到GPU上

net=net.to(device)  #model为实例化的模型

使用指定GPU

PyTorch默认使用从0开始的GPU,如果GPU0正在运行程序,需要指定其他GPU。

除了上面直接修改cuda:0的方法之外还有以下两种方法来指定需要使用的GPU。

1.使用CUDA_VISIBLE_DEVICES。

1.1 直接在终端或shell脚本中设定:
CUDA_VISIBLE_DEVICES=1 python my_script.py
1.2 python代码中设定:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

2. 使用函数 set_device

import torch
torch.cuda.set_device(id)

该函数见 pytorch-master\torch\cuda_init_.py。
不过官方建议使用CUDA_VISIBLE_DEVICES,不建议使用 set_device 函数。

使用多GPU

DP方法

如果有多个GPU,使用nn.DataParallel来包装我们的模型。 然后通过model.to(device)把模型放到GPU上。
代码如下:

device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 
net=Net().to(device) #需要先加载到GPU上,将模型参数转换成CUDA向量
if torch.cuda.device_count() > 1:net=nn.DataParallel(net)

DDP方法

DDP方法比DP方法要好,解决了DP数据分配不平衡和训练速度慢的缺点。
需要添加的地方有:

需要先初始化

需初始化local_rank参数,这里通过argparse模块设置:

parser.add_argument('--local_rank', type=int, default=0,help='node rank for distributed training') 

local_rank参数代表要训练的机子,本来用于标记主机和从机的,如果是多机的话,不同的机器使用不同的local_rank标识机器,由于这里是单机多卡,只用0表示主机就可以了。
另外还需要初始化进程组,代码如下:

torch.distributed.init_process_group(backend="nccl", init_method='file://file:///DATA/wanghongzhi/first_dnn/temp/test',rank=0, world_size=1) 

其中rank是主机的编号,world_size表示分布式主机的个数。

数据集的处理

这步主要是保证一个batch里的数据被均摊到进程上,每个进程能获取到相应的数据。每个GPU产生一个进程。

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=False,sampler=train_sampler)

模型初始化

先执行net=Net().to(device)将模型放在GPU上后,再执行下面语句初始化模型:

net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank],output_device=args.local_rank)

单节点多GPU分布式训练

运行代码的shell脚本为:

python -m torch.distributed.launch --nproc_per_node=2  main.py

其中nproc_per_node为使用的显卡数量。

关于DDP的更多细节、参数的选择及其作用可以参考:
博客1
博客2
博客3
博客4
github资源

实验结果

显卡使用TITIAN Xp,从下面数据结果看到DDP效果最好。

  • 1.使用CPU耗时888.595s
  • 2.使用一块GPU耗时80.06s
  • 3.使用DP,两块GPU耗时60.7744s
  • 4.使用DDP,两块GPU耗时48.7986s

参考:
单GPU
多GPU

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/333714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java解决错误经验_在Java错误进入生产之前的新处理方式

java解决错误经验我们如何认识到解决预生产错误的旧方法还不够,以及我们如何能够改变它 第一次尝试就没有完美的代码,我们所有人都可以证明我们已经通过艰苦的努力学习了。 不管我们使用多少测试周期,代码审查或工具,总有至少一个…

vim 寄存器中的 ^@,^M,^J

首先,ASCII 码表示的字符不都是可打印字符(可显示字符),意味着,其中的控制字符本不是对应某个字形的,所以本没有办法看到他们。那么如果万一某个文件中出现了这些怎么办捏??这里我们…

2019怎么保存低版本_CAD发给客户没字体怎么办?快速打包外部参照、字体、打印样式...

CAD发给客户没字体怎么办?快速打包DWG外部参照、字体、图片、打印样式!有没有遇见过这样的情况:图纸发给客户,外部参照的文件没有一起打包发出去,被老板和客户臭骂一顿。图纸发给审图,没有字体,…

Python正则表达式笔记

正则表达式作用函数函数参数查找函数re.findall()re.search()re.match()re.finditer()re.compile()函数替换函数re.sub(pattern,repl,string,count0,flags0)re.subn()分割函数re.split()模式串字符字符类别表达(匹配单个字符)\d\D\s\S\w\W[a-z][^a-z].多次匹配字符*&#xff1f…

CentOS Linux 下的 vim 无法使用系统剪贴板,怎么解决呢?

文章目录查看系统当前的 vim 是否支持剪贴板安装 gvim 来支持系统剪贴板gvim 和 vim 的区别SSH 连接远程主机遇到的问题查看系统当前的 vim 是否支持剪贴板 首先查看下系统的 vim 是否支持系统剪贴板,在命令终端输入如下命令: [roothtlwk0001host test…

jwt令牌_jwt-cli:用于解码JSON Web令牌(JWT令牌)的Shell库

jwt令牌当我开始经常需要解码JSON Web令牌时,我感到迫切需要编写允许我快速进行操作的程序。 有很多不错的选项,例如jwt.io ,但是一旦您需要执行此操作,它通常就会变得笨拙。 并且,如果您需要处理多个令牌或进一步处理…

日历对象导哪个包_java.util的的Date类和Calendar类

Datejava.util.Date类的对象用来表示时间和日期,用得最多的是获取系统当前日期和时间,精确到毫秒。Java中有两个Date类,还有一个是java.sql.Date,这个类一般不用,即使在数据库中也不推荐使用。写代码时注意导入的包一定…

Debian Linux 的 vim 如何使用系统剪贴板

以 ubuntu 为例,ubuntu 默认是没有 vim 的,需要自己安装一下: 更新源: apt-get update安装 vim : apt-get install vim此时,系统不支持剪切板,我们使用命令 vim --version|grep clipboard 查…

Python中replace()函数

replace()函数 功能:类似正则表达式的sub()函数,使用新的字符串替换主串中的内容。 函数需要通过字符串来调用,replace(old, new, max)函数参数依次为: old表示主串中要被替换的字符串。new表示新的字符串。max表示替换次数,默…

lombok和maven_Lombok,AutoValue和Immutables,或如何编写更少,更好的代码返回

lombok和maven在上一篇有关Lombok库的文章中 ,我描述了一个库,该库有助于处理Java中的样板代码( 是的,我知道这些问题已经在Kotlin中解决了 ,但这是现实生活,我们不能一味地坐下来,一旦出现较新…

上传文件显示进度条_文件上传带进度条进阶-断点续传

说明 1. 把文件按大小1M分割成N份 2. 每次上传时,告诉后台大文件的md5、当前第几份(从0开始)、总共几份 3. 并行上传,前端同时开启5个请求进行传输增加速度 4. 上传失败或出错后,继续上传下一份,把出错的份…

Unix 下的 vim 如何使用系统剪贴板

在 Unix 环境下," 寄存器需要 xterm-clipboard feature 的 VIM 软件才能使用,具有这个 feature 的 VIM 可以安装 vim-gtk(包含gvim和vim),使用 gvim 可以正常调用 " 寄存器。

python中关键字global的简单理解

python用global关键字来标识函数里或类里的全局变量,下面以例子来看看global关键字的作用。 未使用global关键字 a10 #全局变量 def sum(x):a2 #局部变量xa*xreturn x xsum(3) print("a:",a) #10 输出的是全局变量a10 print("x:",x) #6使用…

apache kafka_2018年机器学习趋势与Apache Kafka生态系统相结合

apache kafka在慕尼黑举行的OOP 2018大会上,我介绍了有关使用Apache Kafka生态系统和诸如TensorFlow,DeepLearning4J或H2O之类的深度学习框架构建可扩展,关键任务微服务的演讲的更新版本。 我想分享更新后的幻灯片,并讨论一些有关…

cookies丢失 同域名_后端设置Cookie前端跨域获取丢失问题(基于springboot实现)

1.跨域问题说明:后端域名为A.abc.com,前端域名为B.abc.com。2.后端设置一个cookie发送给前台,domain应该是setDomain(“abc.com”),而不是setDomain(“B.abc.com”)3.另外,还要实现WebMvcConfigurerr配置加入Cors的跨域…

shell脚本--使用for循环逐行访问txt文件

方法1 export text_pathdata/1.txt for line in $(cat $text_path) doecho $line done方法2 export text_pathdata/1.txt for line in cat $text_path doecho $line done

vertx rest 跨域_Vertx编程风格:您的React式Web Companion REST API解释了

vertx rest 跨域Vertx提供了许多在轻量级环境中进行编程的选项,例如node.js。 但是,对于新用户来说,选择采用哪种方法来创建REST API几乎不会造成混淆。 在vertx中进行编程时,可以采用不同的模型。 下面通过易于理解的图表进行说…

输出节点位移_绝对值信号的编码器有哪些信号输出(一、二)

绝对值信号的编码器有哪些信号输出(一、二)之前介绍过很多次拉线位移传感器输出是有两大类的,数字信号输出和模拟量信号输出,而数字信号输出还分为增量型脉冲信号输出和绝对值信号输出,今天就系统的介绍一下绝对值信号…

vim 中的 quickfix 指令

用 quickfix 可以快速修改编译错误。 运行了 make 命令编译之后,如果有编译错误 Vim 会以列表形式把编译错误列出,并使用 quickfix 工具快速帮你定位出错的行。 指令说明cc显示编译错误的详细信息,这些信息显示在状态行里cn下一个编译错误cp前一个编译…

Python第三方库的安装,升级以及版本查看

方法:通过电脑的cmd命令行来进行python第三方库的安装,升级以及版本查看 安装和升级pip 安装pip方法1 在cmd命令行输入以下命令: python -m ensurepip #当提示不存在pip时使用这行代码进行安装安装pip方法2 在终端输入以下命令&#xf…