input.unsqueeze(0) 是 PyTorch 张量(Tensor)的方法之一,用于增加张量的维度。具体来说,它会在索引为 0 的位置上插入一个维度。
假设 input 是一个形状为 (n,) 的一维张量,其中 n 是任意长度。调用 unsqueeze(0) 后,它会返回一个形状为 (1, n) 的二维张量,新插入的维度的大小为 1。
以下是一个示例:
import torchinput = torch.tensor([1, 2, 3, 4])# 调用 unsqueeze(0) 增加维度
output = input.unsqueeze(0)print(input.shape) # 输出: torch.Size([4])
print(output.shape) # 输出: torch.Size([1, 4])
在上述示例中,input 是一个长度为 4 的一维张量。通过 unsqueeze(0) 将其转换为一个形状为 (1, 4) 的二维张量 output。新插入的维度位于索引 0 的位置。
unsqueeze(0) 的应用场景通常是在需要对张量进行运算或与其他张量进行操作时,需要调整张量的维度匹配。例如,将一维张量作为输入传递给大小为 (batch_size, ...) 的神经网络,就通常需要在维度上插入一个批次大小的维度。
需要注意的是,unsqueeze(0) 并不会在原地修改输入张量,而是返回一个新的张量。因此,我们在示例中将结果赋值给 output,以便进行打印输出。