PyTorch 张量与自动微分操作

笔记

1 张量索引操作

import torch
​
# 下标从左到右从0开始(0->第一个值), 从右到左从-1开始
# data[行下标, 列下标]
# data[0轴下标, 1轴下标, 2轴下标]
​
def dm01():# 创建张量torch.manual_seed(0)data = torch.randint(low=0, high=10, size=(4, 5))print('data->', data)# 根据下标值获取对应位置的元素# 行数据 第一行print('data[0] ->', data[0])# 列数据 第一列print('data[:, 0]->', data[:, 0])# 根据下标列表取值# 第二行第三列的值和第四行第五列值print('data[[1, 3], [2, 4]]->', data[[1, 3], [2, 4]])# [[1], [3]: 第二行第三列 第二行第五列值   第四行第三列 第四行第五列值print('data[[[1], [3]], [2, 4]]->', data[[[1], [3]], [2, 4]])# 根据布尔值取值# 第二列大于6的所有行数据print(data[:, 1] > 6)print('data[data[:, 1] > 6]->', data[data[:, 1] > 6])# 第三行大于6的所有列数据print('data[:, data[2]>6]->', data[:, data[2] > 6])# 根据范围取值  切片  [起始下标:结束下标:步长]# 第一行第三行以及第二列第四列张量print('data[::2, 1::2]->', data[::2, 1::2])
​# 创建三维张量data2 = torch.randint(0, 10, (3, 4, 5))print("data2->", data2)# 0轴第一个值print(data2[0, :, :])# 1轴第一个值print(data2[:, 0, :])# 2轴第一个值print(data2[:, :, 0])
​
​
if __name__ == '__main__':dm01()

2 张量形状操作

2.1 reshape

import torch
​
​
# reshape(shape=(行,列)): 修改连续或非连续张量的形状, 不改数据
# -1: 表示自动计算行或列   例如:  (5, 6) -> (-1, 3) -1*3=5*6 -1=10  (10, 3)
def dm01():torch.manual_seed(0)t1 = torch.randint(0, 10, (5, 6))print('t1->', t1)print('t1的形状->', t1.shape)# 形状修改为 (2, 15)t2 = t1.reshape(shape=(2, 15))t3 = t1.reshape(shape=(2, -1))print('t2->', t2)print('t2的形状->', t2.shape)print('t3->', t3)print('t3的形状->', t3.shape)
​
​
​
if __name__ == '__main__':dm01()

2.2 squeeze和unsqueeze

# squeeze(dim=): 删除值为1的维度, dim->指定维度, 维度值不为1不生效  不设置dim,删除所有值为1的维度
# 例如: (3,1,2,1) -> squeeze()->(3,2)  squeeze(dim=1)->(3,2,1)
# unqueeze(dim=): 在指定维度上增加值为1的维度  dim=-1:最后维度
def dm02():torch.manual_seed(0)# 四维t1 = torch.randint(0, 10, (3, 1, 2, 1))print('t1->', t1)print('t1的形状->', t1.shape)# squeeze: 降维t2 = torch.squeeze(t1)print('t2->', t2)print('t2的形状->', t2.shape)# dim: 指定维度t3 = torch.squeeze(t1, dim=1)print('t3->', t3)print('t3的形状->', t3.shape)# unsqueeze: 升维# (3, 2)->(1, 3, 2)# t4 = t2.unsqueeze(dim=0)# 最后维度 (3, 2)->(3, 2, 1)t4 = t2.unsqueeze(dim=-1)print('t4->', t4)print('t4的形状->', t4.shape)
​
​
if __name__ == '__main__':dm02()

2.3 transpose和permute

# 调换维度
# torch.permute(input=,dims=): 改变张量任意维度顺序
# input: 张量对象
# dims: 改变后的维度顺序, 传入轴下标值 (1,2,3)->(3,1,2)
# torch.transpose(input=,dim0=,dim1=): 改变张量两个维度顺序
# dim0: 轴下标值, 第一个维度
# dim1: 轴下标值, 第二个维度
# (1,2,3)->(2,1,3) 一次只能交换两个维度
def dm03():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(3, 4, 5))print('t1->', t1)print('t1形状->', t1.shape)# 交换0维和1维数据# t2 = t1.transpose(dim0=1, dim1=0)t2 = t1.permute(dims=(1, 0, 2))print('t2->', t2)print('t2形状->', t2.shape)# t1形状修改为 (5, 3, 4)t3 = t1.permute(dims=(2, 0, 1))print('t3->', t3)print('t3形状->', t3.shape)
​
​
if __name__ == '__main__':dm03()

2.4 view和contiguous

# tensor.view(shape=): 修改连续张量的形状, 操作等同于reshape()
# tensor.is_contiugous(): 判断张量是否连续, 返回True/False  张量经过transpose/permute处理变成不连续
# tensor.contiugous(): 将张量转为连续张量
def dm04():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(3, 4))print('t1->', t1)print('t1形状->', t1.shape)print('t1是否连续->', t1.is_contiguous())# 修改张量形状t2 = t1.view((4, 3))print('t2->', t2)print('t2形状->', t2.shape)print('t2是否连续->', t2.is_contiguous())# 张量经过transpose操作t3 = t1.transpose(dim0=1, dim1=0)print('t3->', t3)print('t3形状->', t3.shape)print('t3是否连续->', t3.is_contiguous())# 修改张量形状# view# contiugous(): 转换成连续张量t4 = t3.contiguous().view((3, 4))print('t4->', t4)t5 = t3.reshape(shape=(3, 4))print('t5->', t5)print('t5是否连续->', t5.is_contiguous())
​
​
if __name__ == '__main__':dm04()

3 张量拼接操作

3.1 cat/concat

import torch
​
​
# torch.cat()/concat(tensors=, dim=): 在指定维度上进行拼接, 其他维度值必须相同, 不改变新张量的维度, 指定维度值相加
# tensors: 多个张量列表
# dim: 拼接维度
def dm01():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(2, 3))t2 = torch.randint(low=0, high=10, size=(2, 3))t3 = torch.cat(tensors=[t1, t2], dim=0)print('t3->', t3)print('t3形状->', t3.shape)t4 = torch.concat(tensors=[t1, t2], dim=1)print('t4->', t4)print('t4形状->', t4.shape)​
if __name__ == '__main__':# dm01()

3.2 stack

# torch.stack(tensors=, dim=): 根据指定维度进行堆叠, 在指定维度上新增一个维度(维度值张量个数), 新张量维度发生改变
# tensors: 多个张量列表
# dim: 拼接维度
def dm02():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(2, 3))t2 = torch.randint(low=0, high=10, size=(2, 3))t3 = torch.stack(tensors=[t1, t2], dim=0)# t3 = torch.stack(tensors=[t1, t2], dim=1)print('t3->', t3)print('t3形状->', t3.shape)
​
​
if __name__ == '__main__':dm02()

4 自动微分模块

4.1 梯度计算

"""
梯度: 求导,求微分 上山下山最快的方向
梯度下降法: W1=W0-lr*梯度   lr是可调整已知参数  W0:初始模型的权重,已知  计算出W0的梯度后更新到W1权重
pytorch中如何自动计算梯度 自动微分模块
注意点: ①loss标量和w向量进行微分  ②梯度默认累加,计算当前的梯度, 梯度值是上次和当前次求和  ③梯度存储.grad属性中
"""
import torch
​
​
def dm01():# 创建标量张量 w权重# requires_grad: 是否自动微分,默认False# dtype: 自动微分的张量元素类型必须是浮点类型# w = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)# 创建向量张量 w权重w = torch.tensor(data=[10, 20], requires_grad=True, dtype=torch.float32)# 定义损失函数, 计算损失值loss = 2 * w ** 2print('loss->', loss)print('loss.sum()->', loss.sum())# 计算梯度 反向传播  loss必须是标量张量,否则无法计算梯度loss.sum().backward()# 获取w权重的梯度值print('w.grad->', w.grad)w.data = w.data - 0.01 * w.gradprint('w->', w)
​
​
if __name__ == '__main__':dm01()

4.2 梯度下降法求最优解

"""
① 创建自动微分w权重张量
② 自定义损失函数 loss=w**2+20  后续无需自定义,导入不同问题损失函数模块
③ 前向传播 -> 先根据上一版模型计算预测y值, 根据损失函数计算出损失值
④ 反向传播 -> 计算梯度
⑤ 梯度更新 -> 梯度下降法更新w权重
"""
import torch
​
​
def dm01():# ① 创建自动微分w权重张量w = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)print('w->', w)# ② 自定义损失函数 后续无需自定义, 导入不同问题损失函数模块loss = w ** 2 + 20print('loss->', loss)# 0.01 -> 学习率print('开始 权重x初始值:%.6f (0.01 * w.grad):无 loss:%.6f' % (w, loss))for i in range(1, 1001):# ③ 前向传播 -> 先根据上一版模型计算预测y值, 根据损失函数计算出损失值loss = w ** 2 + 20# 梯度清零 -> 梯度累加, 没有梯度默认Noneif w.grad is not None:w.grad.zero_()# ④ 反向传播 -> 计算梯度loss.sum().backward()# ⑤ 梯度更新 -> 梯度下降法更新w权重# W = W - lr * W.grad# w.data -> 更新w张量对象的数据, 不能直接使用w(将结果重新保存到一个新的变量中)w.data = w.data - 0.01 * w.gradprint('w.grad->', w.grad)print('次数:%d 权重w: %.6f, (0.01 * w.grad):%.6f loss:%.6f' % (i, w, 0.01 * w.grad, loss))
​print('w->', w, w.grad, 'loss最小值', loss)
​
​
if __name__ == '__main__':dm01()

4.3 梯度计算注意点

# 自动微分的张量不能转换成numpy数组, 可以借助detach()方法生成新的不自动微分张量
import torch
​
​
def dm01():x1 = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)print('x1->', x1)# 判断张量是否自动微分 返回True/Falseprint(x1.requires_grad)# 调用detach()方法对x1进行剥离, 得到新的张量,不能自动微分,数据和原张量共享x2 = x1.detach()print(x2.requires_grad)print(x1.data)print(x2.data)print(id(x1.data))print(id(x2.data))# 自动微分张量转换成numpy数组n1 = x2.numpy()print('n1->', n1)
​
​
if __name__ == '__main__':dm01()

4.4 自动微分模块应用

import torch
import torch.nn as nn  # 损失函数,优化器函数,模型函数
​
​
def dm01():# todo:1-定义样本的x和yx = torch.ones(size=(2, 5))y = torch.zeros(size=(2, 3))print('x->', x)print('y->', y)# todo:2-初始模型权重 w b 自动微分张量w = torch.randn(size=(5, 3), requires_grad=True)b = torch.randn(size=(3,), requires_grad=True)print('w->', w)print('b->', b)# todo:3-初始模型,计算预测y值y_pred = torch.matmul(x, w) + bprint('y_pred->', y_pred)# todo:4-根据MSE损失函数计算损失值# 创建MSE对象, 类创建对象criterion = nn.MSELoss()loss = criterion(y_pred, y)print('loss->', loss)# todo:5-反向传播,计算w和b梯度loss.sum().backward()print('w.grad->', w.grad)print('b.grad->', b.grad)
​
​
if __name__ == '__main__':dm01()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/79486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口的基础定义与属性约束

在 TypeScript 中,接口(Interface)是一个非常强大且常用的特性。接口定义了对象的结构,包括对象的属性和方法,可以为对象提供类型检查和约束。通过接口,我们可以清晰地描述一个对象应该具备哪些属性和方法。…

高效全能PDF工具,支持OCR识别

软件介绍 PDF XChange Editor是一款功能强大的PDF编辑工具,支持多种操作功能,不仅可编辑PDF内容与图片,还具备OCR识别表单信息的能力,满足多种场景下的需求。 软件特点 这款PDF编辑器完全免费,用户下载后直接…

OpenCV 中用于背景分割的一个类cv::bgsegm::BackgroundSubtractorGMG

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::bgsegm::BackgroundSubtractorGMG 是 OpenCV 中用于背景分割的一个类,它实现了基于贝叶斯推理的背景建模算法(Bayesi…

MongoDB知识框架

简介:MongoDB 是一个基于分布式文件存储的数据库,属于 NoSQL 数据库产品,以下是其知识框架总结: 一、数据模型 文档:MongoDB 中的数据以 BSON(二进制形式的 JSON)格式存储在集合中,…

WEBSTORM前端 —— 第2章:CSS —— 第8节:网页制作2(小兔鲜儿)

目录 1.项目目录 2.SEO 三大标签 3.Favicon 图标 4.版心 5.快捷导航(shortcut) 6.头部(header) 7.底部(footer) 8.banner 9.banner – 圆点 10.新鲜好物(goods) 11.热门品牌(brand) 12.生鲜(fresh) 13.最新专题(topic) 1.项目目录 【xtx-pc】 ima…

1、RocketMQ 核心架构拆解

1. 为什么要使用消息队列? 消息队列(MQ)是分布式系统中不可或缺的中间件,主要解决系统间的解耦、异步和削峰填谷问题。 解耦:生产者和消费者通过消息队列通信,彼此无需直接依赖,极大提升系统灵…

[Linux网络_71] NAT技术 | 正反代理 | 网络协议总结 | 五种IO模型

目录 1.NAT技术 NAPT 2.NAT和代理服务器 3.网线通信各层协议总结 补充说明 4.五种 IO 模型 1.什么是IO?什么是高效的IO? 2.有那些IO的方式?这么多的方式,有那些是高效的? 异步 IO 🎣 关键缺陷类比…

Unity基础学习(八)时间相关内容Time

众所周知,每一个游戏都会有自己的时间。这个时间可以是内部,从游戏开始的时间,也可以是外部真实的物理时间,时间相关内容 主要用于游戏中 参与位移计时 时间暂停等。那么我们今天就来看看Unity中和时间相关的内容。 Unity时间功能…

Java游戏服务器开发流水账(1)游戏服务器的架构浅析

新项目立项停滞,头大。近期读老项目代码看到Java,笔记记录一下。 为什么要做服务器的架构 游戏服务器架构设计具有多方面的重要意义,它直接关系到游戏的性能、可扩展性、稳定性以及用户体验等关键因素 确保游戏的流畅运行 优化数据处理&a…

计算机视觉与深度学习 | 基于Transformer的低照度图像增强技术

基于Transformer的低照度图像增强技术通过结合Transformer的全局建模能力和传统图像增强理论(如Retinex),在保留颜色信息、抑制噪声和平衡亮度方面展现出显著优势。以下是其核心原理、关键公式及典型代码实现: 一、原理分析 1. 全局依赖建模与局部特征融合 Transformer的核…

Linux 文件目录管理常用命令

pwd 显示当前绝对路径 cd 切换目录 指令备注cd -回退cd …返回上一层cd ~切换到用户主目录 ls 列出目录的内容 指令备注ls -a显示当前目录中的所有文件和目录,包括隐藏文件ls -l以长格式显示当前目录中的文件和目录ls -hl以人类可读的方式显示当前目录中的文…

【Linux 系统调试】性能分析工具perf使用与调试方法

目录 一、perf基本概念 1‌. 事件类型‌ 2‌. 低开销高精度 3‌. 工具定位‌ 二、安装与基础配置 1. 安装方法 2. 启用符号调试 三、perf工作原理 1. 数据采集机制 2. 硬件事件转化流程 四、perf应用场景 1. 系统瓶颈定位 2. 锁竞争优化 3. 缓存优化 五、perf高级…

嵌入式中屏幕的通信方式

LCD屏通信方式详解 LCD屏(液晶显示屏)的通信方式直接影响其数据传输效率、显示刷新速度及硬件设计复杂度。根据应用场景和需求,LCD屏的通信方式主要分为以下三类,每种方式在协议类型、数据速率、硬件成本及适用场景上存在显著差异…

【el-admin】el-admin关联数据字典

数据字典使用 一、新增数据字典1、新增【图书状态】和【图书类型】数据字典2、编辑字典值 二、代码生成配置1、表单设置2、关联字典3、验证关联数据字典 三、查询操作1、模糊查询2、按类别查询(下拉框) 四、数据校验 一、新增数据字典 1、新增【图书状态…

【Spring】Spring MVC笔记

文章目录 一、SpringMVC简介1、什么是MVC2、什么是SpringMVC3、SpringMVC的特点 二、HelloWorld1、开发环境2、创建maven工程a>添加web模块b>打包方式:warc>引入依赖 3、配置web.xmla>默认配置方式b>扩展配置方式 4、创建请求控制器5、创建springMVC…

如何在大型项目中解决 VsCode 语言服务器崩溃的问题

在大型C/C项目中,VS Code的语言服务器(如C/C扩展)可能因内存不足或配置不当频繁崩溃。本文结合系统资源分析与实战技巧,提供一套完整的解决方案。 一、问题根源诊断 1.1 内存瓶颈分析 通过top命令查看系统资源使用情况&#xff…

LeetCode百题刷002摩尔投票法

遇到的问题都有解决的方案,希望我的博客可以为你提供一些帮助 图片源自leetcode 题目:169. 多数元素 - 力扣(LeetCode) 一、排序法 题目要求需要找到多数值(元素个数>n/2)并返回这个值。一般会想到先…

Android Studio Gradle 中 只显示 Tasks 中没有 build 选项解决办法

一、问题描述 想把项目中某一个模块的代码单独打包成 aar ,之前是点击 AndroidStudio 右侧的 Gradle 选项,然后再点击需要打包的模块找到 build 进行打包,但是却发现没有 build 选项。 二、解决办法 1、设置中勾选 Configure all Gradle tasks… 选项 …

深入浅出之STL源码分析2_stl与标准库,编译器的关系

引言 在第一篇博客中,深入浅出之STL源码分析1_vector基本操作-CSDN博客 我们将引出下面的几个问题 1.刚才我提到了我的编译器版本是g 11.4.0,而我们要讲解的是STL(标准模板库),那么二者之间的关系是什么?…

(十二)深入了解AVFoundation-采集:人脸识别与元数据处理

(一)深入了解AVFoundation:框架概述与核心模块解析-CSDN博客 (二) 深入了解AVFoundation - 播放:AVFoundation 播放基础入门-CSDN博客 (三)深入了解AVFoundation-播放&#xff1…