Pytorch nn.Module

一、torch.nn简介

torch.nn是 PyTorch 中用于构建神经网络的模块。它提供了一系列的类和函数,用于定义神经网络的各种层、损失函数、优化器等。

torch.nn提供的类:

  • Module: 所有神经网络模型的基类,用于定义自定义神经网络模型。
  • Linear: 线性层,进行线性变换。
  • Conv2d: 二维卷积层。
  • RNN, LSTM, GRU: 循环神经网络层,分别对应简单RNN、长短时记忆网络(LSTM)、门控循环单元(GRU)。
  • BatchNorm2d: 二维批归一化层。
  • CrossEntropyLoss, MSELoss: 分类交叉熵损失函数和均方误差损失函数等。
  • 等等

torch.nn提供的函数:

  • functional: 包含各种神经网络相关的函数,如激活函数 (relu, sigmoid, tanh 等)、池化函数 (max_pool2d, avg_pool2d)、归一化函数 (batch_norm) 等。
  • init: 参数初始化函数,如常用的均匀分布初始化 (uniform_)、正态分布初始化 (normal_)、Xavier 初始化 (xavier_uniform_, xavier_normal_) 等。
  • conv: 用于创建卷积层的函数。
  • linear: 用于创建线性层的函数。
  • dropout: 用于创建 dropout 层的函数。
  • batch_norm: 用于创建批归一化层的函数。

这些类和函数提供了构建、训练和使用神经网络模型所需的基本组件和功能,使得用户可以方便地定义和管理各种类型的神经网络结构,并实现各种机器学习任务。

二、如何继承torch.nn提供的内置模块

在这个示例中,我们创建了一个名为 MyLinearLayer 的类,它继承自 nn.Linear。在 __init__ 方法中,我们通过调用 super() 来调用父类的初始化方法,并传入适当的参数。在 forward 方法中,我们可以添加自定义的前向传播逻辑,然后调用父类的 forward 方法来执行线性层的计算。

import torch
import torch.nn as nnclass MyLinearLayer(nn.Linear):def __init__(self, in_features, out_features):super(MyLinearLayer, self).__init__(in_features, out_features)# 这里可以添加自定义的初始化操作或其他设置def forward(self, x):# 这里可以添加自定义的前向传播逻辑return super(MyLinearLayer, self).forward(x)# 创建自定义的线性层
my_linear_layer = MyLinearLayer(10, 5)

可以根据需要继承其他 PyTorch 内置模块,只需将相应的模块类作为父类即可。

三、torch.nn.Module的常用用法

1、继承和初始化:

为了定义自己的神经网络模型,通常需要创建一个子类,并将其继承自 torch.nn.Module。在子类的初始化函数 __init__() 中,可以定义模型的各个层和组件,并将它们作为模块的属性。

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = self.fc1(x)x = torch.relu(x)x = self.fc2(x)return x

2、前向传播

在子类中实现 forward() 方法来定义模型的前向传播过程。该方法接受输入数据 x,并根据模型的结构和参数计算输出结果。

3、参数管理

torch.nn.Module 可以跟踪模型的所有可学习参数,并提供方便的方法来访问和管理这些参数。可以使用 parameters() 方法获取模型的所有参数,也可以使用 named_parameters() 方法获取参数及其对应的名称。

model = MyModel()
params = list(model.parameters())
for name, param in model.named_parameters():print(name, param.size())

4、移动模型到不同设备

可以使用 to() 方法将模型移动到指定的设备(如 CPU 或 GPU)上进行计算。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

5、保存和加载模型

可以使用 torch.save()torch.load() 函数来保存和加载整个模型,也可以只保存和加载模型的参数。

torch.save(model.state_dict(), 'model.pth')
model.load_state_dict(torch.load('model.pth'))

6、自定义模型

可以通过继承 torch.nn.Module 来定义自己的模型结构,可以根据任务需求自由组合各种层和组件,也可以在 forward() 方法中实现自定义的复杂计算逻辑。

等等.......

四、torch.nn会自动后向传播

torch.nn 模块中的大多数层(如线性层、卷积层等)以及损失函数都会自动支持反向传播。在使用这些层构建神经网络模型时,不需要手动实现反向传播算法。PyTorch 提供了自动求导(Autograd)机制,能够根据定义的前向传播过程自动计算梯度,并通过反向传播算法自动求解梯度。具体来说:

  1. 在模型的前向传播过程中,PyTorch 会跟踪所有参与计算的张量,并构建计算图。
  2. 一旦得到了输出结果,可以调用backword()方法自动计算梯度。
  3. 在调用 backword()方法后,PyTorch 会沿着计算图执行反向传播算法,计算所有需要的梯度。
  4. 梯度被存储在每个张量的.grad 属性中,可以用于参数更新。
    loss.backward()#自动计算损失函数关于模型参数的梯度,即反向传播过程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/748187.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring-1

目录 概念 优点 Autowired和Resource关键字 相同点 不同点 依赖注入的三种方式 概念 Spring 是个java企业级应用的开源开发框架。Spring主要用来开发Java应用,但是有些扩展是针对构建J2EE(Java平台企业版)平台的web应用。Spring 框架目…

前端算法 - 面试记录

1. 实现两个字符串相加(不能直接转成数字使用加法运算,因为js精度规定不能超出一定长度) 现场写法完善: function addStr(a, b) {let res let j 0const add (x, y) > {let numif (x y > 10) {num x y j - 10j 1}e…

AI推介-多模态视觉语言模型VLMs论文速览(arXiv方向):2024.03.10-2024.03.15

论文目录~ 1.3D-VLA: A 3D Vision-Language-Action Generative World Model2.PosSAM: Panoptic Open-vocabulary Segment Anything3.Anomaly Detection by Adapting a pre-trained Vision Language Model4.Introducing Routing Functions to Vision-Language Parameter-Efficie…

java-ssm-jsp-基于java的客户管理系统的设计与实现

java-ssm-jsp-基于java的客户管理系统的设计与实现 获取源码——》公主号:计算机专业毕设大全

自习室预订系统|基于springboot框架+ Mysql+Java+B/S架构的自习室预订系统设计与实现(可运行源码+数据库+设计文档+部署说明)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 学生功能模块 管理员功能登录前台功能效果图 系统功能设计 数据库E-R图设计 lunwen参…

DirectShowPlayerService::doSetUrlSource: Unresolved error code 0x800c000d

报出这个问题,应该是对给的url解析不正确,我给的是rtsp的视频流地址,应该是对该格式解析异常。 所以参考两篇文: QT无法播放视频:报错:DirectShowPlayerService::doRender: Unresolved error code 0x8004…

OCP NVME SSD规范解读-12.Telemetry日志要求

以NVME SSD为例,通常大家想到的是观察SMAR-log定位异常,但是这个信息在多数情况下无法只能支撑完整的定位链路。 定位能力的缺失和低效是数据中心问题解决最大的障碍。 为了解决这个问题,Meta的做法是推进OCP组织加入延迟记录页面。同时NVME协…

练习题手撕总结

基础篇 1.基础知识(时间复杂度、空间复杂度等) 2.线性表(顺序表、单链表) 3.双链表、循环链表 4.队列 5.栈 6.递归算法 7.树、二叉树(递归、非递归遍历) 8.二叉搜索树(BST) 9.二分查…

Android Studio实现内容丰富的安卓宠物医院管理系统

获取源码请点击文章末尾QQ名片联系,源码不免费,尊重创作,尊重劳动 项目编号128 1.开发环境android stuido jdk1.8 eclipse mysql tomcat 2.功能介绍 安卓端: 1.注册登录 2.系统公告 3.宠物社区(可发布宠物帖子&#xf…

Boyer Moore 算法介绍

1. Boyer Moore 算法介绍 Boyer Moore 算法:简称为 BM 算法,是由它的两位发明者 Robert S. Boyer 和 J Strother Moore 的名字来命名的。BM 算法是他们在 1977 年提出的高效字符串搜索算法。在实际应用中,比 KMP 算法要快 3~5 倍。 BM 算法思…

数据结构 之 队列(Queue)

​​​​​​​ 🎉欢迎大家观看AUGENSTERN_dc的文章(o゜▽゜)o☆✨✨ 🎉感谢各位读者在百忙之中抽出时间来垂阅我的文章,我会尽我所能向的大家分享我的知识和经验📖 🎉希望我们在一篇篇的文章中能够共同进步&#xff0…

IO流——缓冲流

缓冲流 缓冲流作用:对原始流进行包装,以提高原始流读写数据的性能 字节缓冲流 作用:提高字节流读写数据的性能原理:字节缓冲输入流自带了8KB的缓冲池,字节缓冲输出流也自带了8KB的缓冲池 构造器说明public Buffer…

JAVA爬虫系列

目录 准备工作 yml 1.入门程序(获取到静态页面) 2.HttpClient---Get 2.1 修改成连接池 3.HttpClient---Get带参数 3.1 修改成连接池 4.HttpClient---Post 4.1 修改成连接池 5.HttpClient---Post带参数 6.HttpClient-连接池 7.设置请求信息 …

蓝桥真题——-小蓝重组质数(全排列和质数判断)

小蓝有一个十进制正整数n&#xff0c;其不包含数码0&#xff0c;现在小蓝可以任意打乱数码的顺序&#xff0c;小蓝想知道通过打乱数码顺序,n 可以变成多少个不同的质数。 #include <iostream> #include<bits/stdc.h> using namespace std; bool isprime(int n) {if…

讯鹏Andon系统解决方案帮助工厂打造生产过程透明化

在现代制造业中&#xff0c;高效透明的生产管理模式对企业的发展至关重要。Andon系统作为一种解决方案&#xff0c;通过软硬件结合的方式&#xff0c;为企业打造了高效透明的生产管理模式&#xff0c;帮助企业实现生产过程的优化和管理的可视化。 Andon系统的软硬件结合为企业提…

swiftUI中的可变属性和封装

swiftUI的可变属性 关于swift中的属性&#xff0c;声明常量使用let &#xff0c; 声明变量使用var 如果需要在swiftUI中更改视图变化那么就需要在 var前面加上state 。 通过挂载到state列表 &#xff0c;从而让xcode找到对应的改变的值 例子&#xff1a; import SwiftUIstruc…

【兆易创新GD32H759I-EVAL开发板】图像处理加速器(IPA)的应用

GD32H7系列的IPA&#xff08;Image Pixel Accelerator&#xff09;是一个高效的图像处理硬件加速器&#xff0c;专门设计用于加速图像处理操作&#xff0c;如像素格式转换、图像旋转、缩放等。它的优势在于能够利用硬件加速来实现这些操作&#xff0c;相比于软件实现&#xff0…

BLE---Service interoperability requirements

0 Preface/Foreword references: Bluetooth core specification V5.4 definition&#xff1a;定义 declaration&#xff1a;声明 1 service definition&#xff08;服务定义&#xff09; 服务定义&#xff08;definition&#xff09;&#xff1a;必须包含服务声明(declara…

【JavaScript】JavaScript 运算符 ① ( 运算符分类 | 算术运算符 | 浮点数 的 算术运算 精度问题 )

文章目录 一、JavaScript 运算符1、运算符分类2、算术运算符3、浮点数 的 算术运算 精度问题 一、JavaScript 运算符 1、运算符分类 在 JavaScript 中 , 运算符 又称为 " 操作符 " , 可以实现 赋值 , 比较 > < , 算术运算 -*/ 等功能 , 运算符功能主要分为以下…

安卓UI面试题 6-10

6. SurfaceView & View 的区别?SurfaceView是在一个新启的单独线程中可以重新绘制画面,而View必须在UI的主线程中更新画面。在UI的主线程中更新画面,可能会引发一些问题,比如你更新画面的时间过长,那么你的主UI线程会被你正在画的函数阻塞。那么将无法响应按键,触屏等…