pytorch lightning最简上手

pytorch lightning最简上手

pytorch lightning 是对原生 pytorch 的通用模型开发过程进行封装的一个工具库。本文不会介绍它的高级功能,而是通过几个最简单的例子来帮助读者快速理解、上手基本的使用方式。在掌握基础 API 和使用方式之后,读者可自行到 pytorch lightning 的官方文档,了解进阶 API。本文假设读者对原生 pytorch 训练脚本的搭建方法已经比较熟悉。

安装

pytorch lighning 的安装非常简单,直接使用 pip 安装即可:

pip install pytorch-lightning

最简例子

pytorch lightning 有两个最核心的 API:LigtningModuleTrainer

其中 LightningModule 是我们熟悉的 torch.nn.Module 的子类,可以通过

print(isinstance(pl.LightningModule(), torch.nn.Module))

来验证。这意味着该类同样需要实现 forward 方法,并可直接通过实例调用。

Trainer 则是开始执行模型训练、测试过程的类,传入一个 LightningModule 和对应控制参数来实例化即可开始训练。

我们从一个最简单的例子——MNIST 手写数字识别开始:

1 导入必要的库

导入 pytorch_lightning 和 pytorch 常用的库。

import osimport torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

2 实现最简LigntningModule

我们先实现一个最简的 LightningModule。

  • __init__

    构造函数中,像常见的 torch.nn.Module 一样,我们定义好模型的层。由于是最简实例,这里只有一层线性层,将手写数字图像映射为输出 logits。

  • forward

    由于是继承自 torch.nn.Module,因此实现 forward 方法是必须的。forward 方法要完成模型的前向过程,这里直接调用 __init__ 中定义好的线性层,完成模型前向过程。

  • train_dataloader

    train_dataloader 方法也是最简实现中必须的,它的功能是获取训练集的 DataLoader。这里我们返回 MNIST 数据集的 DataLoader。dataloader 的获取也可以不在类内实现,而是在 fit 时传入,后面会介绍。

  • training_step

    training_step 是是 LigtningModule 的核心方法,它定义了一个训练步中需要做的事情。在深度学习的训练步中,最核心的事情就是模型前向,得到结果,计算损失,反向传播,更新参数,这几步在 pytorch 中都有对应的方法供调用。但是在 pytorch lightning 中,我们只需要进行模型前向,并返回必要的信息即可。在最简实现中,我们只需返回损失。

  • configure_optimizer

    在 training_step 中,我们只需返回损失,这意味着模型的反向传播和参数更新过程由 pytorch lightning 帮我们完成了。虽然这个过程可以有框架自己完成,但是我们还是要指定参数更新所用的优化器,在很多模型中,优化器、学习率等超参数设置对结果影响很大。在最简实现中,我们设置好学习率,并返回一个 Adam 优化器。

class MNISTModel(pl.LightningModule):def __init__(self):super(MNISTModel, self).__init__()self.l1 = torch.nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.l1(x.view(x.size(0), -1)))def train_dataloader(self):return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def training_step(self, batch, batch_nb):x, y = batchloss = F.cross_entropy(self(x), y)return lossdef configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.02)

以上我们实现 training_step,train_dataloader, configure_optimizer,已经是最简单的 LightningModule 的实现了。如果连这三个方法都没有实现的话,将会报错:

 No `xxx` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined

3 开始训练

在实现好 LightningModule 之后,就可以开始训练了。

启动训练的最简实现非常简单,只需三行:实例化模型、实例化训练器、开始训练!

model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(model)

开始训练后,pytorch lightning 会打印出可用设备、模型参数等丰富的信息。

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]| Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 261.53it/s, loss=1.3, v_num=10]

总结

以上我们用 30 行左右代码,实现了一个最简的 pytorch lightning 训练过程。这足以体现出 pytorch lightning 的简洁、易用。但是,显然这个最简实现缺少了很多东西,比如验证、测试、日志打印、模型保存等。接下来,我们将实现相对完整但依旧简洁的 pytorch lightning 模型开发过程。

pytorch lightning更多功能

本节将介绍相对更完整的 pytorch lightning 模型开发过程。

LighningModeul需实现方法

在一个相对完整的 LightnintModule 中,用户应当实现以下方法:

1 模型定义 (__init__)

通常定义模型的各个层,在 forward 调用这些层,完成模型前向。与原生 pytorch 类似。

2 前向计算 (forward)

与 torch.nn.Module 的 forward 中做的事情一样,调用 _init_ 中定义的层。完成模型前向。与原生 pytorch 类似。

3 训练/验证/测试步 (training_step/validation_step/test_step)

定义训练/测试/训练每一步中要做的事情,一般是计算损失、指标并返回。

def training_step(self, batch, batch_idx):# ....return xxx # 如果是training_step, 则必须包含损失

通常有两个入参 batch 和 batch_idx。是 batch 是 dataloader 给出的输入数据和标签,batch_idx 是当前 batch 的索引。

注意训练步的返回值必须是损失值,或者是包含 ‘loss’ 字段的字典。验证/测试步的返回值不必包括损失,可以是任意结果。

4 训练/验证/测试步结束后 (training_step_end/validation_step_end/test_step_end)

只在使用多个node进行训练且结果涉及如softmax之类需要全部输出联合运算的步骤时使用该函数。

5 训练/验证/测试轮结束后 (training_epoch_end/validation_epoch_end/test_epoch_end)

以 training_epoch_end 为例,其他类似。

如果需要对整一轮的结果进行处理,比如计算一些平均指标等,可以通过 training_epoch_end 来实现。

def training_epoch_end(self, outputs):# ....return xxx

其中入参 outputs 是一个列表,包含了每一步 training_step 返回的内容。我们可以在每一轮结束后,对每一步的结果进行处理。

4 选用优化器 (configure_optimizers)

设置模型参数更新所用的优化器。值得一提的是如果需要多个优化器(比如在训练 GAN 时),可以返回优化器列表。也可以在优化器的基础上返回学习率调整器,那就要返回两个列表。

5 数据加载器 (train_dataloader, val_dataloader, test_dataloader)

返回 dataloader。

各个 dataloader 也可以在运行 fit/validation/test 时传入,如:

train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
model = MNISTModel()		# 不需要实现get_dataloader方法
trainer.fit(model, train_loader)

LightningModule自带工具

LightningModule 中提供了一些常用工具供用户直接使用:

log

Tensorboard 损失/指标日志保存和查看,不要自己定义,直接用即可。用法非常简单,将要记录的值传入:

self.log('train loss', loss)

当然一个功能完整的日志保存接口肯定提供了很多参数来控制,比如是按照 epoch 记录还是按照 step 记录、多卡训练时如何同步、指标是否要展示在进度条上、指标是否要保存在日志文件中等等。pytorch lightning 为这些选项都提供了控制参数,读者可以参考官方文档中 log 相关部分。

print

python 自带的 print 函数在进行多进程训练时会在每个进程都打印内容,这是原生 pytorch 进行分布式训练时一个很小但是很头疼的问题。LightningModule 提供的 print 只打印一次。

freeze

冻结所有权重以供预测时候使用。仅当已经训练完成且后面只测试时使用。

Trainer实例化参数

在实例化 Trainer 时,pytorch lightning 也提供了很多控制参数,这里介绍常用的几个,完整参数及含义请参考官方文档中 Trainer 相关部分。

  • default_root_dir:默认存储地址。所有的实验变量和权重全部会被存到这个文件夹里面。默认情况下,实验结果会存在 lightning_logs/version_x/
  • max_epochs:最大训练周期数,默认为 1000,如果不设上限 epoch 数,设置为 -1。
  • auto_scale_batch_size:在进行训练前自动选择合适的batch size。
  • auto_select_gpus:自动选择合适的GPU。尤其是在有GPU处于独占模式时候,非常有用。
  • gpus:控制使用的GPU数。当设定为None时,使用 cpu。
  • auto_lr_find:自动找到合适的初始学习率。使用了该论文的技术。当且仅当执行 trainer.tune(model) 代码时工作。
  • precision:浮点数精度。默认 32,即常规单精度 fp32 旬来呢。指定为 16 可以使用 fp16 精度加快模型训练并减少显存占用。
  • val_check_interval:进行验证的周期。默认为 1,如果要训练 10 个 epoch 进行一次验证,设置为 10。
  • fast_dev_run:如果设定为true,会只执行一个 batch 的 train, val 和 test,然后结束。仅用于debug。
  • callbacks:需要调用的 callback 函数列表,关于常用 callback 函数下面会介绍。

callback函数

Callback 是一个自包含的程序,可以与训练流程交织在一起,而不会污染主要的研究逻辑。Callback 并不一定只能在 epoch 结尾调用。pytorch-lightning 提供了数十个hook(接口,调用位置)可供选择,也可以自定义callback,实现任何想实现的模块。

推荐使用方式是,随问题和项目变化的操作,实现到 lightning module里面。而独立的、可复用的内容则可以定义单独的模块,方便多个模型调用。

常见的内建 callback 如:EarlyStopping,根据某个值,在数个epoch没有提升的情况下提前停止训练。。PrintTableMetricsCallback,在每个epoch结束后打印一份结果整理表格等。更多内建 callbacks 可参考相关文档。

模型加载与保存

模型保存

ModelCheckpoint 是一个自动储存的 callback 模块。默认情况下训练过程中只会自动储存最新的模型与相关参数,而用户可以通过这个 module 自定义。如观测一个 val_loss 的值,并储存 top 3 好的模型,且同时储存最后一个 epoch 的模型,等等。例:

from pytorch_lightning.callbacks import ModelCheckpoint# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(monitor='val_loss',filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min',save_last=True
)trainer = pl.Trainer(gpus=1, max_epochs=3, callbacks=[checkpoint_callback])

ModelCheckpoint Callback中,如果 save_weights_only=True,那么将会只储存模型的权重,相当于 model.save_weights(filepath),反之会储存整个模型(包括模型结构),相当于model.save(filepath))。

另外,也可以手动存储checkpoint: trainer.save_checkpoint("example.ckpt")

模型加载

加载一个模型,包括它的模型权重和超参数:

model = MyLightingModule.load_from_checkpoint(PATH)print(model.learning_rate)
# 打印出超参数model.eval()
y_hat = model(x)

加载模型时替换一些超参数:

class LitModel(LightningModule):def __init__(self, in_dim, out_dim):super().__init__()self.save_hyperparameters()self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)# 如果在训练和保存模型时,超参数设置如下,在加载后可以替换这些超参数。
LitModel(in_dim=32, out_dim=10)# 仍然使用in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)# 替换为in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

完整加载训练状态,包括模型的一切,以及和训练相关的一切参数,如 model, epoch, step, LR schedulers, apex 等。

model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')# 自动恢复 model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)

实例

基于第三节介绍的更多功能,我们扩展第二节 MNIST 训练程序。代码如下。

import osimport torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as npclass MNISTModel(pl.LightningModule):def __init__(self):super().__init__()self.fc = nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.fc(x.view(-1, 28 * 28)))def training_step(self, batch, batch_nb):# REQUIREDx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss, on_step=False, on_epoch=True)return {'loss': loss}def validation_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)pred = y_hat.argmax(dim=1, keepdim=True)correct = pred.eq(y.view_as(pred)).sum().item()acc = correct / x.shape[0]self.log('val_acc', acc, on_step=False, on_epoch=True)self.log('val_loss', loss, on_step=False, on_epoch=True)return {'val_loss': loss, 'val_acc': acc}def validation_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()avg_acc = np.mean([x['val_acc'] for x in outputs])return {'val_loss': avg_loss, 'val_acc': avg_acc}def test_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)return {'test_loss': loss}def test_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()return {'test_loss': avg_loss}def configure_optimizers(self):# REQUIREDreturn torch.optim.Adam(self.parameters(), lr=0.02)def train_dataloader(self):# REQUIREDreturn DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def val_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)def test_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)model = MNISTModel()
trainer = pl.Trainer(gpus=1,max_epochs=10,callbacks=[pl.callbacks.EarlyStopping( monitor="val_loss", patience=3),]
)
trainer.fit(model)
trainer.test()

Ref

  • pytorch lightning 的官方文档
  • Pytorch Lightning 完全攻略
  • 参考代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532367.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RT-Smart 官方 ARM 32 平台 musl gcc 工具链下载

前言 RT-Smart 的开发离不开 musl gcc 工具链&#xff0c;用于编译 RT-Smart 内核与用户态应用程序 RT-Smart musl gcc 工具链代码当前未开源&#xff0c;但可以下载到 RT-Thread 官方编译好的最新的 musl gcc 工具链 ARM 32位 平台 比如 RT-Smart 最好用的 ARM32 位 qemu 平…

java list翻转_JAVA实现两种方法反转单列表

/***authorluochengcheng* 定义一个单链表*/classNode {//变量private intrecord;//指向下一个对象privateNode nextNode;public Node(intrecord) {super();this.record record;}public intgetRecord() {returnrecord;}public void setRecord(intrecord) {this.record record;}…

OpenAI Whisper论文笔记

OpenAI Whisper论文笔记 OpenAI 收集了 68 万小时的有标签的语音数据&#xff0c;通过多任务、多语言的方式训练了一个 seq2seq &#xff08;语音到文本&#xff09;的 Transformer 模型&#xff0c;自动语音识别&#xff08;ASR&#xff09;能力达到商用水准。本文为李沐老师…

mysql 工具 08s01_Mysql管理必备工具Maatkit详解之十四(mk-kill)

mk-kill - 顾名思义&#xff0c;杀mysql线程。安装方法查看这里。在一个OLTP的生产环境&#xff0c;一般不会让sql执行过长的时间&#xff0c;特别是myisam这样表锁的引擎&#xff0c;如果出现长时间执行的sql一般是误操作&#xff0c;要不就是出现问题了。出现这种情况&#x…

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 转自&#xff1a;【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 作者&#xff1a;潘小小 知识蒸馏是一种模型压缩方法&#xff0c;是一种基于“教师-学生网络思想”的训练方法&#xff0c;由于其简单&#xf…

深度学习三大谜团:集成、知识蒸馏和自蒸馏

深度学习三大谜团&#xff1a;集成、知识蒸馏和自蒸馏 转自&#xff1a;https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA 原文&#xff08;英&#xff09;&#xff1a;https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge…

在墙上找垂直线_墙上如何快速找水平线

在装修房子的时候&#xff0c;墙面的面积一般都很大&#xff0c;所以在施工的时候要找准水平线很重要&#xff0c;那么一般施工人员是如何在墙上快速找水平线的呢&#xff1f;今天小编就来告诉大家几种找水平线的方法。一、如何快速找水平线1、用一根透明的软管&#xff0c;长度…

百度地图mysql打点_关于百度地图连接MYSQL的问题,谢谢啦!

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼大家好&#xff0c;刚使用百度地图API&#xff0c;请教大家一个问题&#xff0c;谢啦&#xff01;我需要从我的数据库中取出字段为"city"的所有数据&#xff0c;然后通过bdGEO()函数在地图上标注这些城市&#xff0c;我是…

PyTorch中的torch.nn.Parameter() 详解

PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数&#xff0c;笔者第一次见的时候也是大概能理解函数的用途&#xff0c;但是具体实现原理细节也是云里雾里&#xff0c;在参考了几篇博文&#xff0c;做过几个实验之后算是清晰了&am…

Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer&#xff08;ViT&#xff09;PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来&#xff0c;屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文&#xff0c;及其PyTorch实现&#xff0c;将整个ViT的代码做一…

hdfs的副本数为啥增加了_HDFS详解之块大小和副本数

1.HDFSHDFS : 伪分布式(学习)NNDNSNNsbin/start-dfs.sh(开启hdfs使用的脚本)bin/hdfs dfs -ls (输入命令加前缀bin/hdfs dfs)2.block(块)dfs.blocksize &#xff1a; 134217728(字节) / 128M 官网默认一个块的大小128M*举例理解块1个文件 130M&#xff0c;默认一个块的大小128M…

Linux下的ELF文件、链接、加载与库(含大量图文解析及例程)

Linux下的ELF文件、链接、加载与库 链接是将将各种代码和数据片段收集并组合为一个单一文件的过程&#xff0c;这个文件可以被加载到内存并执行。链接可以执行与编译时&#xff0c;也就是在源代码被翻译成机器代码时&#xff1b;也可以执行于加载时&#xff0c;也就是被加载器加…

mysql gender_Mysql第一弹

1、创建数据库pythoncreate database python charsetutf8;2、设计班级表结构为id、name、isdelete&#xff0c;编写创建表的语句create table classes(id int unsigned auto_increment primary key not null,name varchar(10),isdelete bit default 0);向班级表中插入数据pytho…

python virtualenv nginx_Ubuntu下搭建Nginx+supervisor+pypy+virtualenv

系统&#xff1a;Ubuntu 14.04 LTS搭建python的运行环境&#xff1a;NginxSupervisorPypyVirtualenv软件说明&#xff1a;Nginx&#xff1a;通过upstream进行负载均衡Supervisor&#xff1a;管理python进程Pypy&#xff1a;用Python实现的Python解释器PyPy is a fast, complian…

如何设置mysql表中文乱码_php mysql表中文乱码问题如何解决

为避免mysql中出现中文乱码&#xff0c;建议在创建数据库时指定编码格式&#xff1a;复制代码 代码示例:create database zzjz CHARACTER SET gbk COLLATE gbk_chinese_ci;create table zz_employees (employeeid int unsigned not null auto_increment primary key,name varch…

java 按钮 监听_Button的四种监听方式

Button按钮设置点击的四种监听方式注&#xff1a;加粗放大的都是改变的代码1.使用匿名内部类的形式进行设置使用匿名内部类的形式&#xff0c;直接将需要设置的onClickListener接口对象初始化&#xff0c;内部的onClick方法会在按钮被点击的时候执行第一个活动的java代码&#…

java int转bitmap_Java Base64位编码与String字符串的相互转换,Base64与Bitmap的相互转换实例代码...

首先是网上大神给的类package com.duanlian.daimengmusic.utils;public final class Base64Util {private static final int BASELENGTH 128;private static final int LOOKUPLENGTH 64;private static final int TWENTYFOURBITGROUP 24;private static final int EIGHTBIT …

linux查看java虚拟机内存_深入理解java虚拟机(linux与jvm内存关系)

本文转载自美团技术团队发表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux与进程内存模型要理解jvm最重要的一点是要知道jvm只是linux的一个进程,把jvm的视野放大,就能很好的理解JVM细分的一些概念下图给出了硬件系统进程三个层面内存之间的关系.从硬件上…

java 循环stringbuffer_java常用类-----StringBuilder和StringBuffer的用法

一、可变字符常用方法package cn.zxg.PackgeUse;/*** 测试StringBuilder,StringBuffer可变字符序列常用方法*/public class TestStringBuilder2 {public static void main(String[] args) {StringBuilder sbnew StringBuilder();for(int i0;i<26;i){char temp(char)(ai);sb.…

java function void_Java8中你可能不知道的一些地方之函数式接口实战

什么时候可以使用 Lambda&#xff1f;通常 Lambda 表达式是用在函数式接口上使用的。从 Java8 开始引入了函数式接口&#xff0c;其说明比较简单&#xff1a;函数式接口(Functional Interface)就是一个有且仅有一个抽象方法&#xff0c;但是可以有多个非抽象方法的接口。 java8…