使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:state_dict = torch.load(reader, weights_only=True)model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, training_parameters: Dict[str, Any]) -> List[float]:if training_parameters['checkpoint']:checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \/{training_parameters["project_name"]} \/{training_parameters["experiment_name"]} \/{training_parameters["run_id"]} \/{training_parameters["model_name"]}'s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])loss_func = nn.NLLLoss()optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], momentum=training_parameters['momentum'])# Epoch loopcompute_time_by_epoch = []for epoch in range(training_parameters['epochs']):# Batch loopfor images, labels in loader:# Flatten MNIST images into a 784 long vector.# shape = [32, 784]images = images.view(images.shape[0], -1)# Training passoptimizer.zero_grad()output = model(images)loss = loss_func(output, labels)loss.backward()optimizer.step()# Save checkpoint to S3if training_parameters['checkpoint']:with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/69705.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 Python(Flask)、JavaScript、HTML 和 CSS 实现前后端交互的详细开发过程

以下是一个基于 Python(Flask)、JavaScript、HTML 和 CSS 实现前后端交互的详细开发过程: --- ### 一、技术选型 1. **后端**:Python Flask(轻量级Web框架) 2. **前端**:HTML/CSS JavaScript&…

细究 ES6 中多种遍历对象键名方式的区别

一、前言 说到遍历对象,第一反应是用 for...in..、和 Object.keys()。平常最多用的就是这俩个。 最近重新翻看 《ES6 标准入门》这本书,发现遍历对象键名的方式还是挺多的。 今天借此机会,以一个基本案例,总结五种遍历对象键名…

尚硅谷爬虫note004

一、urllib库 1. python自带,无需安装 # _*_ coding : utf-8 _*_ # Time : 2025/2/11 09:39 # Author : 20250206-里奥 # File : demo14_urllib # Project : PythonProject10-14#导入urllib.request import urllib.request#使用urllib获取百度首页源码 #1.定义一…

Spring 项目接入 DeepSeek,分享两种超简单的方式!

⭐自荐一个非常不错的开源 Java 面试指南:JavaGuide (Github 收获148k Star)。这是我在大三开始准备秋招面试的时候创建的,目前已经持续维护 6 年多了,累计提交了 5600 commit ,共有 550 多位贡献者共同参与…

批量查询linux下可执行程序缺少的依赖

方法一:使用 find 和 xargs find . -maxdepth 1 -type f -executable | xargs ldd方法二:使用 for 循环 直接复制下面内容粘贴到命令行即可 for file in *; doif [ -f "$file" ] && [ -x "$file" ]; thenecho "Depe…

日常知识点之面试后反思裸写string类

1:实现一个字符串类。 简单汇总 最简单的方案,使用一个字符串指针,以及实际字符串长度即可。 参考stl的实现,为了提升string的性能,实际上单纯的字符串指针和实际长度是不够了,如上,有优化方案…

phpipam1.7安装部署

0软件说明 phpipam是一个开源Web IP地址管理应用程序(IPAM) phpipam官网:https://www.phpipam.net/ 1安装环境 操作系统:Rocky Linux9.5x86_64 phpipam版本:1.7 php版本:8.0.30 数据库版本&#xff1a…

python卷积神经网络人脸识别示例实现详解

目录 一、准备 1)使用pytorch 2)安装pytorch 3)准备训练和测试资源 二、卷积神经网络的基本结构 三、代码实现 1)导入库 2)数据预处理 3)加载数据 4)构建一个卷积神经网络 5&#xff0…

网络安全总结

网络安全总结 网络安全第一篇 1. 防火墙必不可少(局域网与互联网之间必须隔离) 连接到Internet的每一个人都需要在其网络入口处采取一定的措施來阻止和丢弃恶意的网络通信,但是我们貌似没有这么做,这就需要我们在物理或者软件实现我们的防火墙&#xf…

【文本处理】如何在批量WORD和txt文本提取手机号码,固话号码,提取邮箱,删除中文,删除英文,提取车牌号等等一些文本提取固定格式的操作,基于WPF的解决方案

企业的应用场景 数据清洗:在进行数据导入或分析之前,往往需要对大量文本数据进行预处理,比如去除文本中的无关字符(中文、英文),只保留需要的联系信息(手机号码、固话号码、邮箱)。…

【Cocos TypeScript 零基础 15.1】

目录 见缝插针UI脚本针脚本球脚本心得_旋转心得_更改父节点心得_缓动动画成品展示图 见缝插针 本人只是看了老师的大纲,中途不明白不会的时候再去看的视频 所以代码可能与老师代码有出入 SIKI_学院_点击跳转 UI脚本 import { _decorator, Camera, color, Component, directo…

pdf.js默认显示侧边栏和默认手形工具

文章目录 默认显示侧边栏(切换侧栏)默认手形工具(手型工具) 大部分的都是在viewer.mjs中的const defaultOptions 变量设置默认值,可以使用数字也可以使用他们对应的变量枚举值 默认显示侧边栏(切换侧栏) 在viewer.mjs中找到defaultOptions,大概在732行,或则搜索sidebarViewOn…

基于 ollama 在linux 私有化部署DeepSeek-R1以及使用RESTful API的方式使用模型

由于业务需求部署的配置 deepseek:32b,linux配置GPU L20 4卡 ,SSD 200g,暂未发现有什么问题,持续观察中 ##通用写法,忽略就行,与deepseek无关 import pandas as pd from openai.embeddings_utils import get_embedding, cosine_s…

基于 STM32 的病房监控系统

标题:基于 STM32 的病房监控系统 内容:1.摘要 基于 STM32 的病房监控系统摘要:本系统采用 STM32 微控制器作为核心,通过传感器实时监测病房内的环境参数,如温度、湿度、光照等,并将数据上传至云端服务器。医护人员可以通过手机或…

Java分布式幂等性怎么设计?

在高并发的场景的架构中,幂等性是必须得保证的。比如说支付功能,用户发起支付,如果后台没有坐幂等性校验,刚好用户手抖多点了几下,于是后台就有可能多次收到同一个请求,不做幂等性校验很容易就让用户重复支…

Pdf手册阅读(1)--数字签名篇

原文阅读摘要 PDF支持的数字签名, 不仅仅是公私钥签名,还可以是指纹、手写、虹膜等生物识别签名。PDF签名的计算方式,可以基于字节范围进行计算,也可以基于Pdf 对象(pdf object)进行计算。 PDF文件可能包…

Debezium系列之:时区转换器,时间戳字段转换到指定时区

Debezium系列之:时区转换器,时间戳字段转换到指定时区 示例:基本配置应用TimezoneConverter SMT的效果示例:高级配置配置选项当Debezium发出事件记录时,记录中的时间戳字段的时区值可能会有所不同,这取决于数据源的类型和配置。为了在数据处理管道和应用程序中保持数据一…

Zabbix-监控SSL证书有效期

背景 项目需要,需要监控所有的SSL证书的有效期,因此需要自定义一个监控项 实现 创建自定义脚本 在Zabbix的scripts目录(/etc/zabbix/scripts/)下创建一个新的shell脚本check_ssl.sh,内容如下 #!/bin/bash time$(echo | openssl s_client…

【AI知识点】大模型开源的各种级别和 deepseek 的开源级别

【AI论文解读】【AI知识点】【AI小项目】【AI战略思考】【AI日记】【读书与思考】【AI应用】 大模型开源的各种级别 大模型的“开源”程度不同,通常可以分为以下几个主要级别: 1. 权重不开源(Closed-source) 特点:仅…

java安全中的类加载

java安全中的类加载 提前声明: 本文所涉及的内容仅供参考与教育目的,旨在普及网络安全相关知识。其内容不代表任何机构、组织或个人的权威建议,亦不构成具体的操作指南或法律依据。作者及发布平台对因使用本文信息直接或间接引发的任何风险、损失或法律纠…