Pytorch:张量的形状操作

文章目录

      • 一、维度改变
        • 1.flatten展开
          • a.函数的基本用法
          • b.示例
        • 2.unsqueeze增维
          • a.函数的基本用法
          • b.示例
        • 3.squeeze降维
          • a.函数的基本用法
          • b.示例
      • 二、张量变形
        • 1.view()
          • a.函数的基本用法
          • b.参数:
          • c.注意事项
          • d.示例
        • 2.reshape()
          • a.注意事项
          • b.示例
        • 3.reshape_as()
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
      • 三、维度重排
        • 1.permute
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
        • 2.transpose
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意

维度改变和张量变形都不改变内存中存储的结构,因此改变后的张量的值顺序和没改变前是一样的。

一、维度改变

1.flatten展开
  • torch.flatten(tensor)
  • tensor.flatten()

torch.flatten() 是一个在 PyTorch 中常用于张量(tensor)处理的函数,它将输入张量展开成一个一维张量。该函数通常用于准备数据,将多维数据转换为一维,以便用于机器学习模型,特别是在模型的全连接层(fully connected layers)之前。
常用于展开成一维

a.函数的基本用法

只给定一个张量,将直接展开成一维。
torch.flatten(input, start_dim=0, end_dim=-1) 的参数解释如下:

  • input: 输入的张量。
  • start_dim: 开始展开的维度,默认为 0。这意味着从哪个维度开始将张量展开。
  • end_dim: 结束展开的维度,默认为 -1,即最后一个维度。这意味着展开将持续到哪个维度。
b.示例

考虑一个三维张量,例如形状为 (2, 3, 4) 的张量。如果使用 torch.flatten() 将其展开,可以有多种方式处理:

  1. 完全展开: 将整个张量展开成一维数组。

    import torch
    x = torch.randn(2, 3, 4)
    flat_x = torch.flatten(x)
    # 结果形状为 [24]
    
  2. 从特定维度开始展开: 指定从哪个维度开始展开。例如,从第一维(索引为 0 的维度)开始展开。

    flat_x = torch.flatten(x, start_dim=1)
    # 结果形状为 [2, 12],保留了第一个维度,其余维度被展开
    
2.unsqueeze增维
  • torch.unsqueeze(tensor)
  • tensor.unsqueeze()

torch.unsqueeze() 是 PyTorch 中用来增加张量的维度的函数。该函数可以在张量的指定位置插入一个维度,它非常有用于调整张量的形状,以满足特定操作或模型的需求,例如在单样本张量上应用需要批处理的模型。
常用于在第0个维度上增加大小为1的维度

a.函数的基本用法

torch.unsqueeze(input, dim) 的参数解释如下:

  • input: 输入的张量。
  • dim: 要插入新维度的索引位置。这个位置遵循 Python 的索引规则,支持负索引。
b.示例

假设有一个二维张量 x 形状为 (3, 4),表示一个包含3个样本,每个样本4个特征的数据集。如果需要在特定维度增加一个维度,可以使用 torch.unsqueeze() 如下:

import torch
x = torch.randn(3, 4)# 在第0维增加一个维度
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
# 输出: torch.Size([1, 3, 4])# 在第1维增加一个维度
x_unsqueezed = torch.unsqueeze(x, 1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 1, 4])# 使用负索引,在最后一个维度后增加一个维度
x_unsqueezed = torch.unsqueeze(x, -1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 4, 1])
3.squeeze降维
  • torch.squeeze(tensor)
  • tensor.squeeze()

torch.squeeze() 是 PyTorch 中的一个函数,用于减少张量的维度,特别是去除那些维度大小为1的维度。这个函数非常有用于去除由于某些操作(比如 unsqueeze)产生的单一维度,从而使张量的形状更加紧凑。

a.函数的基本用法

只给定一个张量,将直接去掉所有大小为1的维度。
torch.squeeze(input, dim=None) 的参数解释如下:

  • input: 输入的张量。
  • dim: 指定要压缩的维度。如果指定的维度大小为1,则该维度会被去除如果大小不为1,则该维度不会被压缩如果不指定 dim 参数,那么所有大小为1的维度都会被去除。
b.示例

考虑一个张量 x,其形状包括一些大小为1的维度。以下是如何使用 torch.squeeze() 来去除这些维度的示例:

import torch
x = torch.randn(1, 3, 1, 5)# 去除所有大小为1的维度
squeezed_x = x.squeeze()
print(squeezed_x.shape)
# 输出: torch.Size([3, 5])# 只压缩第0维(大小为1)
squeezed_x = x.squeeze(0)
print(squeezed_x.shape)
# 输出: torch.Size([3, 1, 5])# 只压缩第2维(大小为1)
squeezed_x = torch.squeeze(x, 2)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 5])# 尝试压缩一个不是大小为1的维度(没有变化)
squeezed_x = torch.squeeze(x, 1)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 1, 5])

二、张量变形

1.view()

在 PyTorch 中,.view() 方法是一个非常重要且常用的功能,用于改变张量的形状而不改变其数据内容。此方法提供了一种高效的方式来重新排列张量的维度,使其适应不同的需求,例如输入到一个模型或对数据进行不同的操作。
view是共享内存的!

a.函数的基本用法

.view() 方法的基本用法是 tensor.view(*shape),其中 *shape 是希望张量拥有的新形状,由一组维度大小组成。

b.参数:
  • shape: 新的形状,是一个由整数构成的元组,其中的每个整数指定相应维度的大小。你也可以在某个位置使用 -1,让 PyTorch 自动计算该维度的大小。(注意某个位置是任意的某个位置,但是只能有一个)
c.注意事项
  1. 连续性.view() 要求张量在内存中是连续的(即一维数组中的元素顺序与多维视图中的顺序相同)。如果张量不是连续的,你可能需要首先调用 .contiguous() 方法来使其连续。

  2. 自动计算维度:使用 -1 作为形状参数的一部分,PyTorch 将自动计算该维度的正确大小,以便保持元素总数不变。

  3. 大小不变.view()要求张量变换形状之后的大小和变换之前的大小是一样的。即维度大小之积相等。比如tensor.Size([2,4])tensor.Size([8])是一样的。

d.示例
import torch
x = torch.randn(4, 4)  # 创建一个 4x4 的张量# 改变形状为 2x8
y = x.view(2, 8)
print(y.shape)
# 输出: torch.Size([2, 8])# 改变形状为 16(一维)
z = x.view(-1)#z = x.view(16)
print(z.shape)
# 输出: torch.Size([16])# 使用 -1 自动计算维度
w = x.view(-1, 8)
print(w.shape)
# 输出: torch.Size([2, 8])
import torch
x = torch.randn(2, 1)  # 创建一个 2×1 的张量# 改变形状为 2x8
y = x.view(2)
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=2 #共享内存,y也会变
print(x)
print(y)
tensor([-0.5001,  0.5409])
tensor([[2.0000],[0.5409]])
tensor([2.0000, 0.5409])
2.reshape()

在 PyTorch 中,.reshape() 方法用于改变张量的形状而不改变其数据内容。
这一方法与 .view() 类似,都允许您重新排列张量的维度,但它们在处理非连续张量时的行为不同。
只有当非连续张量时,才会导致和.view不一样,如果是连续的,同样也是共享内存的。

a.注意事项
  1. 数据连续性:与 .view() 相比,.reshape() 可以处理非连续张量,如果必要,它会自动处理数据的内存复制。因此,如果原始张量不连续,而你尝试用 .view() 改变其形状可能会导致错误,但 .reshape() 会自动解决这个问题。

  2. 自动计算维度:使用 -1 作为形状参数的一部分时,PyTorch 会自动计算该维度的大小,以确保总元素数量与原张量相同。

b.示例
import torch
x = torch.randn(2, 3, 4)  # 创建一个 2x3x4 的张量# 改变形状为 6x4
y = x.reshape(6, 4)
print(y.shape)
# 输出: torch.Size([6, 4])# 改变形状为 1x24
z = x.reshape(1, 24)
print(z.shape)
# 输出: torch.Size([1, 24])# 使用 -1 自动计算维度
w = x.reshape(-1, 2)
print(w.shape)
# 输出: torch.Size([12, 2])
import torch
x = torch.randn(2, 2)  # 创建一个 2x1 的张量
x=x.transpose(0,1)
# 改变形状为 2x8
y = x.reshape(4)#转置后的x不是连续的,使用reshape产生复制,此时不能用.view()
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=100
print(x)
print(y)
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
tensor([[100.0000,  -0.1661],[ -0.3646,  -0.2516]])
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
3.reshape_as()

在 PyTorch 中,.reshape_as() 是一个方便的方法,用于将一个张量重新塑形为与另一个张量相同的形状。这个方法实质上是 .reshape() 方法的一个简化版本,它以另一个张量的形状为目标形状。
换句话说,.reshape_as()相当于是省略了自指定参数的.reshape(),而可以直接用目标张量形状作为形状。

a.函数的基本用法

.reshape_as() 的基本用法非常直接:tensor1.reshape_as(tensor2)。这会将 tensor1 的形状修改为与 tensor2 相同的形状。

b.参数:
  • tensor2: 这是模型张量,tensor1 将改变形状以匹配 tensor2 的形状。
c.示例
import torch
x = torch.randn(2, 3, 4)  # 原始张量,形状为 2x3x4
y = torch.randn(6, 4)     # 目标张量,形状为 6x4# 将 x 的形状改变为与 y 相同
z = x.reshape_as(y)
print(z.shape)
# 输出: torch.Size([6, 4])
d.注意

虽然 .reshape_as() 很方便,但使用它时应确保两个张量具有相同的元素总数,因为改变形状的操作不会改变数据的总量。如果两个张量的总元素数量不匹配,尝试使用 .reshape_as() 将抛出错误。此外,如果原始张量在内存中是非连续的,.reshape_as() 会像 .reshape() 一样处理,可能需要在内部进行数据复制以确保连续性。

三、维度重排

permute方法可以按照指定顺序重新排列维度,而transpose方法可以交换张量的两个维度。用于需要进行维度重排或转置操作。如矩阵转置。

1.permute

在 PyTorch 中,.permute() 方法用于重新排列张量的维度,这是处理多维数据时一个非常有用的功能,尤其在需要对维度进行特定的重排序操作时。

a.函数的基本用法

.permute() 方法的调用格式为 tensor.permute(*dims),其中 *dims 是一个整数序列,代表新的维度排列顺序。

b.参数:
  • dims: 这个参数定义了张量的每个维度应该如何重新排列。序列中的每个整数都代表原始张量中一个维度的索引,这些索引的排列顺序确定了输出张量的形状。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量# 改变维度的排列顺序为 [2, 0, 1]
y = x.permute(2, 0, 1)
print(y.shape)
# 输出: torch.Size([5, 2, 3])# 将维度的排列顺序改为 [1, 2, 0]
z = x.permute(1, 2, 0)
print(z.shape)
# 输出: torch.Size([3, 5, 2])
d.注意
import torch
x = torch.tensor([[1,2,3,4],[2,4,2,4],[5,6,7,8]]) 
x = x.permute(1,0)
'''
tensor([[1, 2, 5],[2, 4, 6],[3, 2, 7],[4, 4, 8]])
'''

在 PyTorch 中,当使用 .permute() 方法重排张量维度时,张量的数据实际上在内存中的位置并没有改变。更准确地说.permute() 改变的是张量访问这些数据的方式,通过调整形状(shape)步长(stride) 的元信息,而不是数据本身。

  • 步长(Stride)
    • 步长是一个定义在每一维上的整数数组,表示为了在数据中从当前维度的一个元素移动到下一个元素,需要跨过的内存位置数。对于一个连续的张量,步长决定了元素在内存中的布局。

形状(Shape)和步长的调整当调用 .permute(1,0) 时,你实际上是告诉 PyTorch 以一个新的顺序来解释原始数据的内存布局。例如:

x = torch.tensor([[1, 2, 3, 4],[2, 4, 2, 4],[5, 6, 7, 8]])

原始的 x 的形状为 (3, 4),即有 3 行和 4 列。在 PyTorch 中,这意味着其步长为 (4, 1),其中 4 表示要从一行的开始移动到下一行的开始,在内存中需要跨过 4 个元素位置;1 表示在同一行中从一个元素移动到下一个元素,只需要跨过 1 个元素位置。

当你调用 x.permute(1, 0) 时,你是在指示 PyTorch 将原来的列视为行,将原来的行视为列。这就改变了形状为 (4, 3)。这时,步长变为 (1, 4)。这意味着:

  • 要从列的一个元素到下一个元素(现在变成了“行”移动),你只需要移动一个数据位置(原来的行移动)。
  • 要从一行移动到下一行(现在是原来的列跨行移动),你需要跨过 4 个数据位置。
2.transpose

在 PyTorch 中,.transpose() 方法用于交换张量中的两个维度,这是处理多维数组时一个常用的功能,尤其是在需要对特定的维度进行转置操作时。

a.函数的基本用法

.transpose() 方法的调用格式为 tensor.transpose(dim0, dim1),其中 dim0dim1 是要交换的维度的索引。

b.参数:
  • dim0: 第一个要交换的维度的索引。
  • dim1: 第二个要交换的维度的索引。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量# 交换维度 0 和 1
y = x.transpose(0, 1)
print(y.shape)
# 输出: torch.Size([3, 2, 5])# 交换维度 1 和 2
z = x.transpose(1, 2)
print(z.shape)
# 输出: torch.Size([2, 5, 3])
d.注意

.permute() 类似,.transpose() 也是返回原始数据的一个新视图,并不复制数据。因此,输出张量与输入张量共享同一块内存空间,只是它们的形状和步长(stride)不同。同样,.transpose() 会导致张量在内存中可能变为非连续,因此在某些情况下,可能需要调用 .contiguous() 来使张量在内存中连续。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/824570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解 pytest Fixture 方法及其应用

当涉及到编写自动化测试时,测试框架和工具的选择对于测试用例的设计和执行非常重要。在Python 中,pytest是一种广泛使用的测试框架,它提供了丰富的功能和灵活的扩展性。其中一个很有用的功 能是fixture方法,它允许我们初始化测试环…

HTML5漫画风格个人介绍源码

源码介绍 HTML5漫画风格个人介绍源码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 效果截图 源码下载 HTML5漫画风格…

设计模式———单例模式

单例也就是只能有一个实例,即只创建一个实例对象,不能有多个。 可能会疑惑,那我写代码的时候注意点,只new一次不就得了。理论上是可以的,但在实际中很难实现,因为你无法预料到后面是否会脑抽一下~~因此我们…

【Pytorch】Conv1d

conv1d 先看看官方文档 再来个简单的例子 import torch import numpy as np import torch.nn as nndata np.arange(1, 13).reshape([1, 4, 3]) data torch.tensor(data, dtypetorch.float) print("[data]:\n", data) conv nn.Conv1d(in_channels4, out_channels1…

启明智显应用分享|基于ESP32-S3方案的SC01PLUS彩屏与chatgpt融合应用DEMO

今天将带大家真实体验科技与智慧的完美融合——SC01PLUS与ChatGPT的深度融合DEMO效果呈现。 彩屏的清晰显示与ChatGPT的精准回答,将为我们带来前所未有的便捷与高效。 SC01PLUS是启明智显基于ESP32-S3打造的一款3.5寸480*320分辨率的彩屏产品,您可以看…

【Git】git命令大全(持续更新)

本文架构 0.描述git简介术语 1.常用命令2. 信息管理新建git库命令更改存在库设置获取当前库信息 3.工作空间相关将工作空间文件添加到缓存区(增)从工作空间中移除文件(删)撤销提交 4.远程仓库相关同步远程仓库分支 (持…

高版本Android studio 使用Markdown无法预览(已解决)

目录 概述 解决方法 概述 本人升级Android studio 当前版本为Android Studio Jellyfish | 2023.3.1 RC 2导致Markdown无法预览。 我尝试了很多网上的方法都无法Markdown解决预览问题,包括升级插件、安装各种和Markdown相关的插件及使用“Choose Boot Java Runtim…

一文了解OCI标准、runC、docker、contianerd、CRI的关系

docker和contanerd都是流行的容器运行时(container runtime);想讲清楚他们两之间的关系,让我们先从runC和OCI规范说起。 一、OCI标准和runC 1、OCI(open container initiative) OCI是容器标准化组织为了…

利用动态规划优化10年投资回报:策略、证明与算法分析

利用动态规划优化10年投资回报:策略、证明与算法分析 a. 存在最优投资策略的证明b. 最优子结构性质的证明c. 最优投资策略规划算法设计d. 新限制条款下最优子结构性质的证明 在面对投资策略规划问题时,我们的目标是在10年后获得最大的回报。Amalgamated投…

牛客 NC205 跳跃游戏(三)【中等 贪心 Java,Go,PHP】

题目 题目链接: https://www.nowcoder.com/practice/14abdfaf0ec4419cbc722decc709938b 思路 参考答案Java import java.util.*;public class Solution {/*** 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可*** …

Go 单元测试之Mysql数据库集成测试

文章目录 一、 sqlmock介绍二、安装三、基本用法四、一个小案例五、Gorm 初始化注意点 一、 sqlmock介绍 sqlmock 是一个用于测试数据库交互的 Go 模拟库。它可以模拟 SQL 查询、插入、更新等操作,并且可以验证 SQL 语句的执行情况,非常适合用于单元测试…

基于SpringBoot+Vue社区医院服务平台(源码+文档+包运行)

一.系统概述 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了社区医院信息平台的开发全过程。通过分析社区医院信息平台管理的不足,创建了一个计算机管理社区医院信息平台的方案。文章介绍了社区医院信息…

如何在Linux CentOS部署宝塔面板并实现固定公网地址访问内网宝塔

文章目录 一、使用官网一键安装命令安装宝塔二、简单配置宝塔,内网穿透三、使用固定公网地址访问宝塔 宝塔面板作为建站运维工具,适合新手,简单好用。当我们在家里/公司搭建了宝塔,没有公网IP,但是想要在外也可以访问内…

QAnything部署Mac m1环境

本次安装时Qanything已经更新到了v1.3.3,支持纯python安装。安装过程比较简单,如下: QAnything/README_zh.md at qanything-python-v1.3.1 netease-youdao/QAnything GitHub 首先需要用Anaconda3创建隔离环境,简要说明下Anaco…

春藤实业启动SAP S/4HANA Cloud Public Edition项目,与工博科技携手数字化转型之路

3月11日,广东省春藤实业有限公司(以下简称“春藤实业”)SAP S/4HANA Cloud Public Edition(以下简称“SAP ERP公有云”)项目正式启动。春藤实业董事长陈董、联络协调项目经理慕总、内部推行项目经理陈总以及工博董事长…

酒店水电能源计量管理系统

酒店水电能源计量管理系统是一种针对酒店行业设计的能源管理系统,旨在实现对水电能源的计量、监测和管理。本文将从系统特点、构成以及带来的效益三个方面展开介绍。 系统特点 1.多元化计量:该系统能够对酒店内的水、电能源进行多元化计量,…

软件项目总体测试计划(Word原件2024)

一、 前言 (一) 背景 (二) 目的 (三) 测试目标 (四) 适用范围与读者对象 (五) 术语与缩写 二、 软件测试实施流程 (一) 测试工作总体流…

2024年MathorCup数学建模C题物流网络分拣中心货量预测及人员排班解题文档与程序

2024年第十四届MathorCup高校数学建模挑战赛 C题 物流网络分拣中心货量预测及人员排班 原题再现: 电商物流网络在订单履约中由多个环节组成,图1是一个简化的物流网络示意图。其中,分拣中心作为网络的中间环节,需要将包按照不同流…

【Python基础】MySQL

文章目录 [toc]创建数据库创建数据表数据插入数据查询数据更新 个人主页:丷从心 系列专栏:Python基础 学习指南:Python学习指南 创建数据库 import pymysqldef create_database():db pymysql.connect(hostlocalhost, userroot, passwordr…

Maven多模块管理

Maven多模块管理 在了解怎么进行Maven多模块管理之前,先聊聊为什么要进行Maven多模块管理 为什么要Maven多模块管理? 在传统的单体架构开发下,一个项目中的依赖只需要使用一个pom.xml文件管理即可。但是随着微服务的流行,将原有…