前一篇文章,Tensor 基本操作3 理解 shape, stride, storage, view,is_contiguous 和 reshape 操作 | PyTorch 深度学习实战
本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started
Tensor 基本使用
- 索引 indexing
- 示例代码
 
- 加减乘除
- 加法和减法
- 乘法和除法
 
- broadcasting 机制
- 更多运算
- Links
索引 indexing
Tensor 的索引类似于 Python List 的索引和分片。
比如一个 AxBxC 的三个维度的 Tensor a。
a[第0维的分片, 第1维的分片, 第2维的分片]
分片的语法和 Python List 分片语法一致,开始:结束:步进。
更多索引的高级语法介绍。
示例代码
    print("*" * 8, " a")a = torch.randn(5,4,3)print(a)print("*" * 8, " b")b = a[1,]     # 只要第 0 维的第一个成员print(b)print("*" * 8, " c")c = a[1:]   # 第 0 维从第一个成员开始都要,注意:这里索引从 0 开始print(c)print("*" * 8, " d")d = a[1:, 1] # 第 0 维从第一个成员开始都要,第二维只要第一个成员print(d)
Result
********  a
tensor([[[ 0.1874, -0.0980, -0.3815],[-0.8175,  1.5976, -1.4927],[-0.1507,  1.1806, -0.3685],[ 1.1583,  0.9419, -0.5540]],[[ 1.3078, -1.4250, -1.5981],[-0.0756,  2.0776,  0.7708],[ 1.6020, -1.9133,  1.2459],[-0.2817, -0.7238, -0.5413]],[[-0.8057, -0.4368, -1.2398],[ 0.8415,  1.7679,  0.6469],[ 0.7046, -0.4872,  1.1219],[-2.5866, -0.1263,  2.0684]],[[ 1.8756,  1.4231, -1.2082],[ 0.2111,  0.5244,  2.2242],[-0.9658, -1.3731, -0.9126],[-0.3850, -0.7273, -0.0519]],[[ 0.7949,  2.2807, -0.8793],[ 0.4037,  1.2422, -0.2393],[ 0.4786,  0.6107,  1.4225],[ 0.6104,  1.2682, -0.0801]]])
********  b = a[1,]
tensor([[ 1.3078, -1.4250, -1.5981],[-0.0756,  2.0776,  0.7708],[ 1.6020, -1.9133,  1.2459],[-0.2817, -0.7238, -0.5413]])
********  c = a[1:]
tensor([[[ 1.3078, -1.4250, -1.5981],[-0.0756,  2.0776,  0.7708],[ 1.6020, -1.9133,  1.2459],[-0.2817, -0.7238, -0.5413]],[[-0.8057, -0.4368, -1.2398],[ 0.8415,  1.7679,  0.6469],[ 0.7046, -0.4872,  1.1219],[-2.5866, -0.1263,  2.0684]],[[ 1.8756,  1.4231, -1.2082],[ 0.2111,  0.5244,  2.2242],[-0.9658, -1.3731, -0.9126],[-0.3850, -0.7273, -0.0519]],[[ 0.7949,  2.2807, -0.8793],[ 0.4037,  1.2422, -0.2393],[ 0.4786,  0.6107,  1.4225],[ 0.6104,  1.2682, -0.0801]]])
********  d = a[1:, 1]
tensor([[-0.0756,  2.0776,  0.7708],[ 0.8415,  1.7679,  0.6469],[ 0.2111,  0.5244,  2.2242],[ 0.4037,  1.2422, -0.2393]])
加减乘除
加法和减法
import torch# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3, 4)
b = torch.rand(4)c1 = a + b
c2 = torch.add(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))
乘法和除法
*, torch.mul, torch.mm, torch.matmul
参考: torch.Tensor的4种乘法
除法可以用乘法 API 完成。
broadcasting 机制
在 Tensor 的加减运算中,当两个 tensor 不能直接符合数学的运算规则时,PyTorch 会先尝试将 tensor 进行变换,再进行计算,这个变换的规则就是:broadcasting。
 
更多 broadcasting 机制的介绍。
更多运算
更多加法和其他运算,参考Pytorch Tensor基本数学运算:
- 减法运算
- 哈达玛积(对应元素相乘,也称为 element wise)
- 除法运算
- 幂运算
- 开方运算
- 指数与对数运算
- 近似值运算
- 裁剪运算
Links
- Tensor Broadcasting under the hood
- Mastering PyTorch Indexing: Simple Techniques with Practical Examples
- torch.Tensor的4种乘法
- Pytorch Tensor基本数学运算