拼接: cat, stack …
-
使用 cat 在
指定维度dim上拼接:torch.cat(element_list, dim)>>> a = torch.rand(2,3) >>> b = torch.rand(1,3) >>> c = torch.cat([a,b], dim=0) >>> c.shape torch.Size([3, 3]) -
使用 stack 在
新增维度dim上拼接:torch.cat(element_list, dim),- 注:
element_list中 element 的 shape 必须完全一致
>>> a = torch.rand(2,3) >>> b = torch.rand(2,3) >>> c = torch.stack([a,b], dim=0) >>> c.shape torch.Size([2, 2, 3]) - 注:
拆分:split,chunk …
- 使用 split 根据
长度拆分:a.split(l, dim)- 注:长度不一样时:
a.split(l_list, dim)
>>> a.split(1, dim=0) # 或 a.split([1,1], dim=0) (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]])) - 注:长度不一样时:
- 使用 chunk根据
数量拆分:a.chunk(n, dim)>>> a.chunk(2, dim=0) (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
- B站视频参考资料