torch.nn.Embedding
flyfish
This module is typically used to store word embeddings and retrieve them by index. The input to the module is a list of indices, and the output is the corresponding word embeddings.
import torch
import torch.nn as nn
torch.manual_seed(0)
embedding = nn.Embedding(10, 3)  # an Embedding module containing 10 tensors of size 3
print(embedding.weight)
# tensor([[-0.7588, -0.0094, -0.8549],
#         [-1.9320, -0.1008,  1.1125],
#         [-0.7327,  0.5621,  0.2356],
#         [-1.6812, -0.2477,  0.1624],
#         [ 0.5170,  0.0979, -0.3463],
#         [ 0.4478, -1.3857,  1.8448],
#         [-1.2102, -0.5387, -1.8527],
#         [-0.1349, -0.6765, -2.4845],
#         [-1.5077,  0.4549, -0.9425],
#         [ 1.9715,  0.9959,  0.0415]], requires_grad=True)
# 10 is the number of embeddings (rows of the weight matrix); 3 is the dimension of each embedding vector
input = torch.LongTensor([[1, 2, 4, 5]])  # a batch of 1 sample with 4 indices
e = embedding(input)
print(e)
# Counting from index 0, each index selects the corresponding row of embedding.weight:
# index 1 -> [-1.9320, -0.1008,  1.1125]
# index 2 -> [-0.7327,  0.5621,  0.2356]
# index 4 -> [ 0.5170,  0.0979, -0.3463]
# index 5 -> [ 0.4478, -1.3857,  1.8448]
# tensor([[[-1.9320, -0.1008,  1.1125],
#          [-0.7327,  0.5621,  0.2356],
#          [ 0.5170,  0.0979, -0.3463],
#          [ 0.4478, -1.3857,  1.8448]]], grad_fn=<EmbeddingBackward0>)
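In other words, the lookup is just row indexing into the weight matrix. A minimal sketch to check this, reusing the embedding, input, and e variables defined above:

# Calling the module gives the same values as indexing embedding.weight directly
print(torch.equal(e, embedding.weight[input]))  # True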
Where do the values of embedding.weight come from?
They come from nn.init.normal_, which fills the tensor with values drawn from a normal distribution.
With the random seed fixed, you get the same values.
torch.manual_seed(0)
w = torch.empty(10, 3)
print(nn.init.normal_(w,0,1))
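A quick way to confirm this claim, as a minimal sketch: it assumes nn.Embedding's default initialization draws its weight from the standard normal in the same order as the nn.init.normal_ call above, so with the same seed both produce identical values.

torch.manual_seed(0)
embedding = nn.Embedding(10, 3)

torch.manual_seed(0)
w = nn.init.normal_(torch.empty(10, 3), 0, 1)

# Same seed, same draws from N(0, 1), so the tensors match
print(torch.equal(embedding.weight.detach(), w))  # True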
Reference
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html