分布式执行引擎ray入门--(3)Ray Train

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfileimport torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Composeimport ray.train.torchdef train_func(config):# Model, Loss, Optimizermodel = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# model.to("cuda")  # This is done by `prepare_model`# [1] Prepare model.model = ray.train.torch.prepare_model(model)criterion = CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=0.001)# Datatransform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])data_dir = os.path.join(tempfile.gettempdir(), "data")train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=128, shuffle=True)# [2] Prepare dataloader.train_loader = ray.train.torch.prepare_data_loader(train_loader)# Trainingfor epoch in range(10):for images, labels in train_loader:# This is done by `prepare_data_loader`!# images, labels = images.to("cuda"), labels.to("cuda")outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# [3] Report metrics and checkpoint.metrics = {"loss": loss.item(), "epoch": epoch}with tempfile.TemporaryDirectory() as temp_checkpoint_dir:torch.save(model.module.state_dict(),os.path.join(temp_checkpoint_dir, "model.pt"))ray.train.report(metrics,checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),)if ray.train.get_context().get_world_rank() == 0:print(metrics)# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(train_func,scaling_config=scaling_config,# [5a] If running in a multi-node cluster, this is where you# should configure the run's persistent storage that is accessible# across all worker nodes.# run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))model = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpointdef train_func(config):...torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainertrainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL-视图:视图概述、使用视图注意点、视图是否影响基本表

视图 一、视图概述二、使用视图注意点三、视图操作是否影响基本表 一、视图概述 在数据库管理系统中,视图(View)是一种虚拟表,它并不实际存储数据,而是基于一个或多个实际表的查询结果。视图提供了一种对数据库中数据…

题目 2132: T1268-完全背包问题

题目描述: 设有n种物品,每种物品有一个重量及一个价值。但每种物品的数量是无限的,同时有一个背包,最大载重量为M,今从n种物品中选取若干件(同一种物品可以多次选取),使其重量的和小于等于M,而价值的和为最…

RabbitMQ备份交换机

1. 备份交换机 备份交换机可以理解为 RabbitMQ 中交换机的“备胎”,当我们为某一个交换机声明一个对应的备份交换机时,就是为它创建一个备胎,当交换机接收到一条不可路由消息时,将会把这条消息转发到备份交换机中,由备…

reids设计与实现(一)——数据对象

文章目录 1. 前言2. redis 动态字符串2.1. 字符串的数据结构:2.2. 剖析,length;2.3. 剖析,free;2.3. 使用c字符串函数; 3. redis 链表4. 字典5. 跳跃表 1. 前言 reids作为最常用的缓存数据库,深…

web3 DePIN赛道之OORT

文章目录 什么是DePIN什么是oort背景:去中心化云计算场景团队OORT AIOORT StorageOORT Compute 参考 什么是DePIN DePIN是Decentralized Physical Infrastructure Networks的简称,中文意思就是去中心化的网络硬件基础设施,是利用区块链技术和代币奖励来调动分散在世…

mysql笔记:2. 表操作

文章目录 创建表创建表主键约束外键约束非空约束唯一约束默认值约束 查看数据表结构查看表详细结构 修改数据表修改表名修改字段名修改字段数据类型添加字段删除字段更改表的存储引擎 删除数据表删除没被关联的表删除被关联的表 数据表是数据库中最重要、最基本的操作对象&…

单例九品--第九品[可用的设计]

单例九品--第九品[可用的设计] 上一品引入写在前边代码部分实现方式的评注和思考写在最后 上一品引入 自第五品以来,为解决第四品的静态初始化灾难问题,将全局对象设置为指针类型,但是指针是有被修改的风险。所以第八品将全局单例对象封装在…

2024.02.07 校招 实习 内推 面经

绿*泡*泡VX: neituijunsir 交流*裙 ,内推/实习/校招汇总表格 1、校招 | 腾讯TEG 2024 校招开放全新机会,持续热招(内推) 校招 | 腾讯TEG 2024 校招开放全新机会,持续热招(内推) …

【MATLAB】MATLAB学习笔记

MATLAB入门 基础操作变量命名数据类型逻辑和流程控制循环结构分支结构 绘图基本操作二维平面绘图绘图参数三位立体绘图图像窗口的分割 本文参考B站视频:BV13D4y1Q7RS 由于我对于C语言很熟悉,很多语法是会参考C来学 基础操作 清屏%% 清空环境变量及命令 …

图腾柱PFC工作原理:一张图

视屏链接: PFC工作原理

Java SE入门及基础(33)

final 修饰符 1. 应用范围 final 修饰符应该使用在类、变量以及方法上 2. final 修饰类 Note that you can also declare an entire class final. A class that is declared final cannot be subclassed. This is particularly useful, for example, when creating an imm…

docker学习笔记——Dockerfile

Dockerfile是一个镜像描述文件,通过Dockerfile文件可以构建一个属于自己的镜像。 如何通过Dockerfile构建自己的镜像: 在指定位置创建一个Dockerfile文件,在文件中编写Dockerfile相关语法。 构建镜像,docker build -t aa:1.0 .(指…

Effective C++ 学习笔记 条款22 将成员变量声明为private

下面是作者的规划。首先带你看看为什么成员变量不该是public,然后让你看看所有反对public成员变量的论点同样适用于protected成员变量。最后导出一个结论:成员变量应该是private。获得这个结论后,本条款也就大功告成了。 好,现在…

【Vue】生命周期

Vue生命周期 就是一个Vue实例从创建 到 销毁 的整个过程 生命周期四个阶段 1.创建阶段:创建响应式数据 2.挂载阶段:渲染模板 3.更新阶段:修改数据,更新视图 4.销毁阶段:销毁Vue实例 生命周期钩子 Vue生命周期过…

【每日一题】2834. 找出美丽数组的最小和-2024.3.8

题目: 2834. 找出美丽数组的最小和 给你两个正整数:n 和 target 。 如果数组 nums 满足下述条件,则称其为 美丽数组 。 nums.length n.nums 由两两互不相同的正整数组成。在范围 [0, n-1] 内,不存在 两个 不同 下标 i 和 j &…

阿里云实现两个VPC网络资源互通

背景 由于实际项目预算有限,两套环境虽然分别属于不同的专有网络即不同的VPC,但是希望借助一台运维机器实现对两个环境的监控和日常的运维操作 网络架构 如下是需要实现的外网架构图,其中希望实现UAT环境的一台windows的堡垒机可以访问生产…

前端软件工程师100问?

作为一个前端软件工程师,可能会遇到的问题非常多,以下是我为您精选的100个常见问题: 以上100个问题涵盖了前端开发的基础知识、框架应用、性能优化、安全性、兼容性、前沿技术等多个维度。作为前端软件工程师,了解这些问题有助于提…

第G3周:CGAN入门|生成手势图像

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 一、前置知识 CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息…

【linux】04 :linix实用操作

1.常用快捷键 ctrlc表示强制停止。linux某些程序的运行,如果想强制停止,可以使用;命令输入错误,也可以通过ctrlc,退出当前输入,重新输入。 ctrld表示退出登录,比如退出root以回到普通用户,或者…

Stable Diffusion 模型下载:ZavyChromaXL(现实、魔幻)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 作者述:该模型系列应该是用于 SDXL 的 ZavyMix SD1.5 模型的延续。主要重点是获…