torch.nn.embedding()

作者:top_小酱油
链接:https://www.jianshu.com/p/63e7acc5e890
来源:简书
内容:上述是以RNN为基础解析的

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)

意义

该方法定义了一个简单的存储固定大小的词典的嵌入向量的查找表,意思就是说,这就是一个词典,里面包含了各个单词的向量,如果要访问该查找表,需要你给定一个编号,嵌入层就能返回这个编号对应的嵌入向量,嵌入向量反映了各个编号代表的符号之间的语义关系。

注意这里是定义的查找表! 如果要访问还需要后面的操作! 也就是先定义后操作!

当在访问该查找表的时候:
输入为一个编号列表,输出为对应的符号嵌入向量列表

shape

  • input:(∗),包含提取的编号的intTensor或任意形状的LongTensor。
  • output:(∗,H),其中 * 是输入形状,H = embeddding_dim

参数:

  • Num_embeddings (int) -词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999)
  • embeddding_dim (int) - 嵌入向量的维度,即用多少维来表示一个符号(单词)。
  • padding_idx (int, optional) -填充id,比如,输入长度为100,但是每次的句子长度并不一样,后面就需要用统一的数字填充,而这里就是指定这个数字,这样,网络在遇到填充id时,就不会计算其与其它符号的相关性。(初始化为0)。
  • max_norm(float,optional)-最大范数,如果嵌入向量的范数超过了这个界限,就要进行再归一化。
  • norm_type(float,optional)-指定利用什么范数计算,并用于对比max_norm,默认为2范数。
  • scale_grad_by_freq(boolean,optional)—— 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.
  • sparse (bool,optional)-若为True,则与权重矩阵相关的梯度转变为稀疏张量。

变量:

~Embedding.weight (Tensor) –形状为 (num_embeddings, embedding_dim) 的学习权重采用标准正态分布N(0,1)进行初始化

NOTE 1:

请记住,只有有限数量的优化器支持稀疏梯度:目前是optim.SGD(CUDA和CPU), optim.SparseAdam (CUDA和CPU)和optim.Adagrad (CPU)

NOTE 2:

当max_norm≠None时,Embedding的前向方法将就地修改权重张量。由于梯度计算所需的张量不能被就地修改,所以在Embedding上执行可微运算。在调用Embedding的前向方法之前,需要克隆Embedding。max_norm不是None时的权重。例如:

n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=True)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t()  # modifies weight in-place(就地修改权重)
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()

整个举例:

实际上,Embedding通过随机初始化建立了词向量层后,建立了一个“二维表”,存储了词典中每个词的词向量。每个mini-batch的训练,都要从词向量表找到mini-batch对应的单词的词向量作为模型的输入放进网络。

>>> # an Embedding module 包含了10个张量,每个张量的大小为3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch 含有两个样本,每个样本长度为4,也就是四个索引,索引是词典中的index序号
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],[-0.6431,  0.0748,  0.6969],[ 1.4970,  1.3448, -0.9685],[-0.3677, -2.7265, -0.1685]],[[ 1.4970,  1.3448, -0.9685],[ 0.4362, -0.4004,  0.9400],[-0.6431,  0.0748,  0.6969],[ 0.9124, -2.3616,  1.1151]]])>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],[ 0.1535, -2.0309,  0.9315],[ 0.0000,  0.0000,  0.0000],[-0.1655,  0.9897,  0.0635]]])>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],[-0.7895, -0.7089, -0.0364],[ 0.6778,  0.5803,  0.2678]], requires_grad=True)
>>> with torch.no_grad():
...     embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000,  1.0000,  1.0000],[-0.7895, -0.7089, -0.0364],[ 0.6778,  0.5803,  0.2678]], requires_grad=True)

一些注意的点

  • nn.embedding的输入只能是编号,不能是隐藏变量,比如one-hot,或者其它,这种情况,可以自己建一个自定义维度的线性网络层,参数训练可以单独训练或者跟随整个网络一起训练(看实验需要)
  • 如果你指定了padding_idx,注意这个padding_idx也是在num_embeddings尺寸内的,比如符号总共有500个,指定了padding_idx,那么num_embeddings应该为501
  • embedding_dim的选择要注意,根据自己的符号数量,举个例子,如果你的词典尺寸是1024,那么极限压缩(用二进制表示)也需要10维,再考虑词性之间的相关性,怎么也要在15-20维左右,虽然embedding是用来降维的,但是>- 也要注意这种极限维度,结合实际情况,合理定义

类方法(ClassMethod)

.from_pretrained(embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

和上面的使用是类似的,也是定义一个查找表!
从给定的2维FloatTensor创建嵌入实例。

参数

  • embeddings (Tensor)-包含嵌入权值的浮点数。第一个维度作为num_embeddings传递给Embedding,第二个维度作为embeding_dim。
  • freeze (boolean,optional)-如果为True,张量不会在学习过程中得到更新。相当于embedding.weight。requires_grad = False。默认值:真正的
  • padding_idx (int, optional) -如果指定了,则padding_idx上的项不会影响梯度;因此,在训练过程中,padding_idx处的嵌入向量并没有被更新,即它仍然是一个固定的“pad”。
  • max_norm (float,optional)-参见模块初始化文档。
  • norm_type (float,optional)——请参阅模块初始化文档。默认2。
  • scale_grad_by_freq (boolean,optional)-参见模块初始化文档。默认的错误。
  • sparse (bool,optional)-参见模块初始化文档。

举例

>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])

实际代码举例

if pretrained_word_embedding is None:  #如果没有预训练词典,那么就用第一个self.word_embedding = nn.Embedding(config.num_words,config.word_embedding_dim,padding_idx=0)
else:  #我们就用预训练的词典self.word_embedding = nn.Embedding.from_pretrained(pretrained_word_embedding, freeze=False, padding_idx=0)
if pretrained_entity_embedding is None:self.entity_embedding = nn.Embedding(config.num_entities,config.entity_embedding_dim,padding_idx=0)
else: #实体的嵌入也是一样的,同上self.entity_embedding = nn.Embedding.from_pretrained(pretrained_entity_embedding, freeze=False, padding_idx=0)
if config.use_context:   #上下文嵌入也是一样的,同上if pretrained_context_embedding is None:self.context_embedding = nn.Embedding(config.num_entities,config.entity_embedding_dim,padding_idx=0)else:self.context_embedding = nn.Embedding.from_pretrained(pretrained_context_embedding, freeze=False, padding_idx=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/476211.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

oracle杀死进程时权限不足_在oracle中创建函数时权限不足

我对oracle有一点了解。我试图创建一个如下所示的函数。在oracle中创建函数时权限不足CREATE OR REPLACE FUNCTION "BOOK"."CONVERT_TO_WORD" (totpayable IN NUMBER) RETURN VARCHARAStotlength NUMBER;num VARCHAR2(14);word VARCHAR2(70);word1 VARCHAR…

哇塞,打开一个页面访问了这么多次数据库??

用SQL Server 事件探查器看了一下,哇塞,每打开一个页面都select了n多次数据库,而且很多都是类似的代码?为啥? (1)、二级嵌套绑定数据源 (2)、二级联动 (3)、……多着呢! 解决方法: 对于数据不大…

torch.nn

torch.nn 与 torch.nn.functional 说起torch.nn,不得不说torch.nn.functional! 这两个库很类似,都涵盖了神经网络的各层操作,只是用法有点不同,比如在损失函数Loss中实现交叉熵! 但是两个库都可以实现神经网络的各层运算。其他包…

ORACLE使用JOB定时备份数据库

Oracle的备份一般都是在操作系统上完成,因此定时备份Oracle的功能一般都是由操作系统功能完成,比如crontab。但是Oracle的PIPE接口使得在Oracle数据库中通过JOB来备份Oracle变得可能。 这篇文章给出一个简单的例子,说明如何在JOB中定期备份数…

mysql 装载dump文件_mysql命令、mysqldump命令找不到解决

1、解决bash: mysql: command not found 的方法[rootDB-02 ~]# mysql -u root-bash: mysql: command not found原因:这是由于系统默认会查找/usr/bin下的命令,如果这个命令不在这个目录下,当然会找不到命令,我们需要做的就是映射一个链接到/u…

LeetCode 796. 旋转字符串

1. 题目 给定两个字符串, A 和 B。 A 的旋转操作就是将 A 最左边的字符移动到最右边。 例如, 若 A ‘abcde’,在移动一次之后结果就是’bcdea’ 。如果在若干次旋转操作之后,A 能变成B,那么返回True。 示例 1: 输入: A abcde, B cdeab …

【DKN】(一)KCN详解

_ init _()函数 参数: self, config, pretrained_word_embedding, pretrained_entity_embedding, pretrained_context_embedding config: 设置的固定的参数! pretrained_word_embedding: 根据下面的使用是…

搜索引擎优化经验谈

转自:http://blog.donews.com/zszwyds/archive/2009/08/24/1551179.aspx 费话少说,直入正题。 1. “白马非马”的关键字(词) 很多客户对于自己网站的关键词无从下手,大部分的客户选择都是大而全的关键词,很多的关键词如果选择…

iphone版 天行skyline_Skyline QT

应用标题Skyline QT应用描述An information and feedback gathering tool for our Skyline Queenstown visitor to discover the complex and its array of activities and food and beverage outlets.Welcome to the world of SkylineAre you looking for things to do in New…

LeetCode 788. 旋转数字

1. 题目 我们称一个数 X 为好数, 如果它的每位数字逐个地被旋转 180 度后,我们仍可以得到一个有效的,且和 X 不同的数。要求每位数字都要被旋转。 如果一个数的每位数字被旋转以后仍然还是一个数字, 则这个数是有效的。 0, 1, 和 8 被旋转后…

pycharm中无法识别相对路径的问题

这种情况如果在Windows下操作如下: 第一步: 往往拷贝下来的程序是在linux上运行的 第二步: 设置根路径 要调整有python.exe文件的地方! 这两个路径要设置成为自己的项目根目录!

vue变量传值_Vue各类组件之间传值的实现方式

1、父组件向子组件传值首先在父组件定义好数据,接着将子组件导入到父组件中。父组件只要在调用子组件的地方使用v-bind指令定义一个属性,并传值在该属性中即可,此时父组件的使命完成,请看下面关键代码::content"i…

Linux常用指令自己备用

~ 和 / 的区别: ~ 是当前用户的目录地址 / 是根目录的地址(一般称呼为root,/ 和 /root/ 是有区别的) 当用户是root用户时 ~ 代表/root/,即根目录下的root目录 / 代表/ ,即根目录 当用户是jack用户时 ~…

『号外』 排名进入3000,特致感谢!

开博半个月来,老孙项目管理成功地闯入了博客园3000名!! 谢谢博客园的朋友们!非常感谢!!“老孙项目管理”今日排名2975。这样的成绩,老孙没有预料到,开心极了。比奥巴马当选总统&…

qt如和调用linux底层驱动_擅长复杂硬件体系设计,多核系统设计,以及基于RTOS或者Linux,QT等进行相关底层驱动。...

双向可控硅在使用时,其触发限流电阻的阻值和封装应该怎么选取?(1)首先我们在进行TRIAC其驱动电路设计的时候,我们一般不直接进行驱动,而是通过DIAC或者Photo-TRIAC即光学的双向可控硅配合来使用进行驱动电路的设计,为什…

学习:Web安装项目创建桌面快捷方式及重写安装类(转)

一、WEB安装项目部署1、新建: 新建项目-安装和部署项目-WEB安装项目 2、部署: (1)进入文件系统视图,"项目-右键-视图-文件系统";也可以直接点"解决方案资源管理器"上部的快捷图标(2)在"WEB应用程序文件夹"添加文件,例如aspx文件,ico文…

12c oracle 激活_Oracle 12C 安装教程

Oracle 12c,全称Oracle Database 12c,是Oracle 11g的升级版,新增了很多新的特性。本章节就为大家介绍Oracle 12c的下载和安装步骤。Oracle 12c下载打开Oracle的官方中文网站,选择相应的版本即可。注意:下载时&#xff…

运行试错合集

试错: 在服务器训练好的参数直接被pycharm映射给覆盖了! 记得把这里取消掉! 如果在py文件中修改了代码,手动上传! 就是上面的upload! 运行结果: 运行train的结果 评估阶段: 出错…

LeetCode 806. 写字符串需要的行数

1. 题目 我们要把给定的字符串 S 从左到右写到每一行上,每一行的最大宽度为100个单位,如果我们在写某个字母的时候会使这行超过了100 个单位,那么我们应该把这个字母写到下一行。 我们给定了一个数组 widths ,这个数组 widths[0…

【转载】揭开硬件中断请求IRQ所有秘密(图解)

转载自:http://news.csdn.net/n/20040517/45868.html IRQ(Interrupt Request)的作用就是在我们所用的电脑中,执行硬件中断请求的动作,用来停止其相关硬件的工作状态。比如我们要打印一份文件,在打印结束时就需要由系统对打印机提出…