Pytorch基础:torch.expand() 和 torch.repeat()

在torch中,如果要改变某一个tensor的维度,可以利用viewexpandrepeattransposepermute等方法,这里对这些方法的一些容易混淆的地方做个总结。

expand和repeat函数是pytorch中常用于进行张量数据复制维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1

expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

import torcha = torch.tensor([1, 0, 2])     # a -> torch.Size([3])
b1 = a.expand(2, -1)            # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],[1, 0, 2]])
'''a = torch.tensor([[1], [0], [2]])   # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2)                 # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为  tensor([[1, 1],[0, 0],[2, 2]])
'''a = torch.Tensor([[1, 2, 3]])   # a -> torch.Size([1, 3])
b3 = a.expand(4, 3)              # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,# 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor([[1.,2.,3.],[1.,2.,3.],[1.,2.,3.],[1.,2.,3.]]
)
'''a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # a -> torch.Size([2, 3])
b4 = a.expand(4, 6)  # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''b5 = a.expand(1, 2, 3)  # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''b7 = a.expand(2, 3, 2)  # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''b8 = a.expand(2, -1, -1)  # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''# expand返回的张量与原版张量具有相同内存地址
print(b8.storage())  # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

沿着特定维度扩展张量,并返回扩展后的张量

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变
import torchif __name__ == '__main__':x = torch.rand(2, 3)y1 = x.repeat(4, 2)print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

  • torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

  • torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)y2 = x.repeat(2, 3)print(x.storage().data_ptr(), y1.storage().data_ptr())  # 52364352 52364352print(x.storage().data_ptr(), y2.storage().data_ptr())  # 52364352 8852096

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/9347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AcWing 835:Trie字符串统计 ← 字典树(Trie树)模板题

【题目来源】https://www.acwing.com/problem/content/837/【题目描述】 维护一个字符串集合,支持两种操作: ● I x 向集合中插入一个字符串 x; ● Q x 询问一个字符串在集合中出现了多少次。 共有 N 个操作,所有输入的字符…

如何在Python中实现文本相似度比较?

在Python中实现文本相似度比较可以通过多种方法,每种方法都有其适用场景和优缺点。以下是一些常见的文本相似度比较方法: 1. 余弦相似度(Cosine Similarity) 余弦相似度是通过计算两个向量之间夹角的余弦值来确定它们之间的相似…

Linux sudo 指令

sudo命令 概念: sudo是linux下常用的允许普通用户使用超级用户权限的工具,允许系统管理员让普通用户执行一些或者全部的root命令,如halt,reboot,su等。这样不仅减少了root用户的登录和管理时间,同样也提高…

解决Windows下共享文件夹的再三访问失败问题

今天遇到一个奇葩的情况, 明明局域网的两台电脑可以ping通, 明明共享文件夹的Everyone权限设置妥当, 明明高级网络共享中的文件共享,密码保护都已经开启了, 明明防火墙也关闭了, 明明输入连接地址后也…

22、Flink 背压下的 Checkpoint处理

1.概述 通常,对齐 Checkpoint 的时长主要受 Checkpointing 过程中的同步和异步两个部分的影响;但当 Flink 作业正运行在严重的背压下时,Checkpoint 端到端延迟的主要影响因子将会是传递 Checkpoint Barrier 到 所有的算子/子任务的时间&…

线程与进程

进程 进程是程序的一次执行过程,系统程序的基本单位。有自己的main方法,并且主要由主方法运行起来的基本上就是进程。 线程 线程与进程相似,但线程是一个比进程更小的执行单位。一个进程在其执行的过程中可以产生多个线程。与进程不同的是…

vty、带内/带外管理、带内/带外ip简介

1、vty是什么? (virtual teletype)虚拟电传打字机 漫谈VTY (qq.com) 视频链接:使用20世纪30年代的电传打字机作为Linux系统的终端 https://hackaday.com/2020/04/15/logging-into-linux-with-a-1930s-teletype/ 2、console端口和vty端口的区别&#xf…

局域网语音对讲系统_IP广播对讲系统停车场解决方案

局域网语音对讲系统_IP广播对讲系统停车场解决方案 需求分析: 随着国民经济和社会的发展, 选择坐车出行的民众越来越多。在保护交通安全的同时,也给停车场服务部门提出了更高的要求。人们对停车场系统提出了更高的要求与挑战, 需要…

部分设计模式概述

单例模式 工厂模式 适配器模式 模板方法模式 策略模式 责任链 观察者模式(又叫发布订阅模式)

容器化Jenkins远程发布java应用(方式一:pipline+ssh)

1.创建pipline工程 2.准备工程Jenkinsfile文件(java目录) 1.文件脚本内容 env.fileName "planetflix-app.jar" env.configName "planetflix_prod" env.remoteDirectory "/data/project/java" env.sourceFile "/…

LangChain基本概念

概述 LangChain是一个为大型语言模型(LLMs)应用程序开发设计的框架,它包含一些关键的基本概念,理解这些概念对于使用LangChain至关重要: 组件(Components): LangChain的构建块&…

Leetcode - 周赛396

目录 一,3136. 有效单词 二,3137. K 周期字符串需要的最少操作次数 三,3138. 同位字符串连接的最小长度 四,3139. 使数组中所有元素相等的最小开销 一,3136. 有效单词 本题就是一道阅读理解题: 字符串长…

NSSCTF | [SWPUCTF 2021 新生赛]easy_sql

打开题目,提示输入一些东西,很显眼的可以看到网站标题为“参数是wllm” 首先单引号判断闭合方式 ?wllm1 报错了,可以判断为单引号闭合。 然后判断字节数(注意‘--’后面的空格) ?wllm1 order by 3-- 接着输入4就…

Linux呈现数据

目录 一、理解输入和输出 1.1 标准文件描述符 1.1.1 STDIN 1.1.2 STDOUT 1.1.3 STDERR 1.2 重定向错误 二、在脚本中重定向输出 2.1 临时重定向 2.2 永久重定向 三、在脚本中重定向输入 四、创建自己的重定向 4.1 创建输出文件描述符 4.2 重定向文件描述符 4.3 关…

vue多选功能

废话不多说&#xff0c;直接上代码&#xff01;&#xff01;&#xff01; <template><div class"duo-xuan-page"><liv-for"(item, index) in list":key"index"click"toggleSelection(item)":class"{ active: sel…

ue引擎游戏开发笔记(36)——为射击落点添加特效

1.需求分析&#xff1a; 在debug测试中能看到子弹落点后&#xff0c;需要给子弹添加击中特效&#xff0c;更真实也更具反馈感。 2.操作实现&#xff1a; 1.思路&#xff1a;很简单&#xff0c;类似开枪特效一样&#xff0c;只要在头文件声明特效变量&#xff0c;在fire函数中…

Leetcode—232. 用栈实现队列【简单】

2024每日刷题&#xff08;131&#xff09; Leetcode—232. 用栈实现队列 实现代码 class MyQueue { public:MyQueue() {}void push(int x) {st.push(x);}int pop() {if(show.empty()) {if(empty()) {return -1;} else {int ans show.top();show.pop();return ans;}} else {i…

Burp Suite 抓包,浏览器提示有软件正在阻止Firefox安全地连接到此网站

问题现象 有软件正在阻止Firefox安全地连接到此网站 解决办法 没有安装证书&#xff0c;在浏览器里面安装bp的证书就可以了 参考&#xff1a;教程合集 《H01-启动和激活Burp.docx》——第5步

金蝶BI应收分析报表:关于应收,这样分析

这是一张出自奥威-金蝶BI方案的BI应收分析报表&#xff0c;是一张综合运用了筛选、内存计算等智能分析功能以及数据可视化图表打造而成的BI数据可视化分析报表&#xff0c;可以让企业运用决策层快速知道应收账款有多少&#xff1f;账龄如何&#xff1f;周转情况如何&#xff1f…

ARTS Week 27

Algorithm 本周的算法题为 58. 最后一个单词的长度 给你一个字符串 s&#xff0c;由若干单词组成&#xff0c;单词前后用一些空格字符隔开。返回字符串中 最后一个 单词的长度。 单词 是指仅由字母组成、不包含任何空格字符的最大子字符串。 示例 1&#xff1a;输入&#xff1a…