PyTorch 简单易懂的 Embedding 和 EmbeddingBag - 解析与实践

目录

torch.nn子模块Sparse Layers详解

nn.Embedding

用途

主要参数

注意事项

使用示例

从预训练权重创建嵌入

nn.EmbeddingBag

功能和用途

主要参数

使用示例

从预训练权重创建

总结


torch.nn子模块Sparse Layers详解

nn.Embedding

torch.nn.Embedding 是 PyTorch 中一个重要的模块,用于创建一个简单的查找表,它存储固定字典和大小的嵌入(embeddings)。这个模块通常用于存储单词嵌入并使用索引检索它们。接下来,我将详细解释 Embedding 模块的用途、用法、特点以及如何使用它。

用途

  • 单词嵌入:在自然语言处理中,Embedding 模块用于将单词(或其他类型的标记)映射到一个高维空间,其中相似的单词在嵌入空间中彼此靠近。
  • 特征表示:在非自然语言处理任务中,嵌入可以用于任何类型的分类特征的密集表示。

主要参数

  • num_embeddings(int):嵌入字典的大小。
  • embedding_dim(int):每个嵌入向量的大小。
  • padding_idx(int,可选):如果指定,padding_idx 处的嵌入不会在训练中更新。
  • max_norm(float,可选):如果指定,将重新归一化超过此范数的嵌入向量。
  • norm_type(float,可选):用于max_norm选项的p-范数的p值,默认为2。
  • scale_grad_by_freq(bool,可选):如果为True,将按单词在批次中的频率的倒数来缩放梯度。
  • sparse(bool,可选):如果为True,权重矩阵的梯度将是一个稀疏张量。

注意事项

  • 当使用max_norm参数时,Embedding的前向方法会就地修改权重张量。如果需要对Embedding.weight进行梯度计算,则在调用前向方法前,需要在max_norm不为None时克隆它。
  • 仅有少数优化器支持稀疏梯度。

使用示例

import torch
import torch.nn as nn# 创建一个包含10个大小为3的嵌入的Embedding模块
embedding = nn.Embedding(10, 3)# 一个包含4个索引的2个样本的批次
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])# 通过Embedding模块获取嵌入
output = embedding(input)

此示例创建了一个嵌入字典大小为10、每个嵌入维度为3的 Embedding 模块。然后它接受一个包含索引的输入张量,并返回对应的嵌入向量。

从预训练权重创建嵌入

还可以使用from_pretrained类方法从预先训练的权重创建Embedding实例:

# 预训练的权重
weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])# 从预训练权重创建Embedding
embedding = nn.Embedding.from_pretrained(weight)# 获取索引1的嵌入
input = torch.LongTensor([1])
output = embedding(input)

在这个示例中,Embedding 模块是从一个给定的预训练权重张量创建的。这种方法在迁移学习或使用预先训练好的嵌入时非常有用。

nn.EmbeddingBag

torch.nn.EmbeddingBag 是 PyTorch 中一个高效的模块,用于计算“bags”(即序列或集合)的嵌入的总和或平均值,而无需实例化中间的嵌入。这个模块特别适用于处理具有不同长度的序列,如在自然语言处理任务中处理不同长度的句子或文档。下面我将详细介绍 EmbeddingBag 的功能、用法以及特点。

功能和用途

  • 高效计算EmbeddingBag 直接计算整个包的总和或平均值,比逐个嵌入后再求和或取平均更加高效。
  • 支持不同聚合方式:可以选择 "sum", "mean" 或 "max" 模式来聚合每个包中的嵌入。
  • 支持加权聚合EmbeddingBag 还支持为每个样本指定权重,在 "sum" 模式下进行加权求和。

主要参数

  • num_embeddings(int):嵌入字典的大小。
  • embedding_dim(int):每个嵌入向量的大小。
  • max_norm(float,可选):如果给定,将重新规范化超过此范数的嵌入向量。
  • mode(str,可选):聚合模式,可以是 "sum"、"mean" 或 "max"。
  • sparse(bool,可选):如果为True,权重矩阵的梯度将是一个稀疏张量。
  • padding_idx(int,可选):如果指定,padding_idx 处的嵌入将不会在训练中更新。

使用示例

import torch
import torch.nn as nn# 创建一个包含10个大小为3的嵌入的EmbeddingBag模块
embedding_bag = nn.EmbeddingBag(10, 3, mode='mean')# 一个示例包含4个索引的输入
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)# 指定每个包的开始索引
offsets = torch.tensor([0, 4], dtype=torch.long)# 通过EmbeddingBag模块获取嵌入
output = embedding_bag(input, offsets)

在这个示例中,创建了一个嵌入字典大小为10、每个嵌入维度为3的 EmbeddingBag 模块,并设置为 "mean" 模式。输入是一个索引序列,offsets 指定了每个包的开始位置。EmbeddingBag 会计算每个包的平均嵌入向量。

从预训练权重创建

EmbeddingBag 也可以从预训练的权重创建:

# 预训练的权重
weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])# 从预训练权重创建EmbeddingBag
embedding_bag = nn.EmbeddingBag.from_pretrained(weight)# 获取索引1的嵌入
input = torch.LongTensor([[1, 0]])
output = embedding_bag(input)

 这种方法在需要使用预先训练好的嵌入或在迁移学习中非常有用。EmbeddingBag 通过高效地处理不同长度的序列数据,在自然语言处理等领域中发挥着重要作用。

总结

 本篇博客探讨了 PyTorch 中的 nn.Embeddingnn.EmbeddingBag 两个关键模块,它们是处理和表示离散数据特征的强大工具。nn.Embedding 提供了一种有效的方式来将单词或其他类型的标记映射到高维空间中,而 nn.EmbeddingBag 以其独特的方式处理变长序列,通过聚合嵌入来提高计算效率。这两个模块不仅在自然语言处理中发挥关键作用,也适用于其他需要稠密特征表示的任务。此外,这些模块支持从预训练权重初始化,使其在迁移学习和复杂模型训练中极为重要。综上所述,nn.Embeddingnn.EmbeddingBag 是理解和应用 PyTorch 中嵌入层的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/605419.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

06.函数和模块的使用

函数和模块的使用 在讲解本章节的内容之前,我们先来研究一道数学题,请说出下面的方程有多少组正整数解。 事实上,上面的问题等同于将8个苹果分成四组每组至少一个苹果有多少种方案。想到这一点问题的答案就呼之欲出了。 可以用Python的程序来…

基于ssm的大湾区旅游推荐系统的设计与实现+vue论文

摘 要 如今社会上各行各业,都喜欢用自己行业的专属软件工作,互联网发展到这个时候,人们已经发现离不开了互联网。新技术的产生,往往能解决一些老技术的弊端问题。因为传统大湾区旅游景点信息管理难度大,容错率低&…

美创科技葛宏彬:夯实安全基础,对医疗数据风险“逐个击破”

导读 解决医疗机构“临床业务数据合规流动”与“重要数据安全防护”两大难题。 2023年11月11日,在2023年南湖HIT论坛上,HIT专家网联合杭州美创科技股份有限公司(以下简称美创科技)发布《医疗数据安全风险分析及防范实践》白皮书…

完成python+neo4j+django踩坑记录

使用Django进行后端控制,Echarts进行前端显示 例子django安装1、django启动2、django初体验3、django踩坑【已解决】You have 18 unapplied migration(s). Your project may not work properly until you apply the migra【已解决】运行neo4j出现报错 Failed to sta…

回车事件怎样绑定?

首先要记住一点,回车事件只能在js代码中绑定,在HTML中绑定是获取不到的,下面是我自己写的一个示例,大家可以参考一下。 importt.onkeydown function(event) {let val document.getElementById(importt).value;let e event || …

目标检测COCO数据集与评价体系mAP

1.mAP 2.IoU IoU也就是交并比,也称为 Jaccard 指数,用于计算真实边界框与预测边界框之间的重叠程度。它是真值框与预测边界框的交集和并集之间的比值。Ground Truth边界框是测试集中手工标记的边界框,用于指定目标图像的位置以及预测的边界框…

浅讲人工智能,初识人工智能几个重要领域。

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

设置5台SSH互免的虚拟机服务器配置

搭建一套集群虚拟机,往往都需要互免设置,过程很简单,避免以后再搭建还得网上搜索,我直接将这一个步骤写成笔记,记录下来,方便后续查阅。 步骤如下—— 1、准备五台机器 服务器名字服务器IPhadoop1192.16…

昇腾多卡通信教程【配置网络检测对象IP】

无法通信会出现的错误如下 一、网络健康状态报错 命令原型 hccn_tool [-i %d] -netdetect -s [address %s]命令功能 本功能支持用户执行命令获取网络健康状态(本端与所配置的检测IP之间的连通状态),用户可指定上报的状态信息名称。 状态信…

Javaweb之Mybatis的XML配置文件的详细解析

2. Mybatis的XML配置文件 Mybatis的开发有两种方式: 注解 XML 2.1 XML配置文件规范 使用Mybatis的注解方式,主要是来完成一些简单的增删改查功能。如果需要实现复杂的SQL功能,建议使用XML来配置映射语句,也就是将SQL语句写在…

kotlin的抽象类和抽象方法

在 Kotlin 中,抽象类和抽象方法是面向对象编程中的概念,用于实现抽象和多态性。以下是有关 Kotlin 抽象类和抽象方法的详细信息: 抽象类: 定义: 抽象类是用 abstract 关键字声明的类,不能直接实例化。它可…

从零开始构建区块链:我的区块链开发之旅

1.引言 1.区块链技术的兴起和重要性 区块链技术,作为数字化时代的一项颠覆性创新,已经成为当今世界最令人瞩目的技术之一。自比特币的问世以来,区块链技术已经从仅仅支持加密货币发展成为一种具有广泛应用前景的分布式账本技术。其核心优势…

顺序表的实现(C语言)

本文章主要对顺序表的介绍以及数据结构的定义,以及几道相关例题,帮助大家更好理解顺序表. 文章目录 前言 一、顺序表的静态实现 二、顺序表的动态实现 三.定义打印顺序表函数 四.定义动态增加顺序表长度函数 五.创建顺序表并初始化 六.顺序表的按位查找 七.顺序表的按值…

如何下载 GOES(Geostationary Operational Environmental Satellite)卫星数据

GOES是指地球静止轨道卫星(Geostationary Operational Environmental Satellite)系统,它是美国国家海洋和大气管理局(NOAA)和美国国家航空航天局(NASA)合作开发和运营的一系列气象卫星。这些卫星…

如何编写高效的正则表达式?

正则表达式(Regular Expression,简称regex)是一种强大的文本处理技术,广泛应用于各种编程语言和工具中。本文将从多个方面介绍正则表达式的原理、应用和实践,帮助你掌握这一关键技术。 正则可视化 | 一个覆盖广泛主题…

19、Kubernetes核心技术 - 资源限制

目录 一、概述 二、Kubernetes 中的资源单位 2.1、CPU资源单位 2.2、内存资源单位 三、Pod资源限制 四、namespace资源限制 4.1、为命名空间配置内存和 CPU 配额 4.2、为命名空间配置默认的内存请求和限制 4.3、为命名空间配置默认的CPU请求和限制 五、超过容器限制的…

240107-RHEL8+RHEL9配置安装:NVIDIA驱动(15步)+CUDA(4步)+CUDNN(5步)+GPU压力测试

Section 0: 基础知识 CUDA、cuDNN 和 PyTorch 版本的选择与搭配指南 安装优先级: 显卡驱动 → CUDA → CUDA Toolkit → cuDNN → Pytorch 即显卡驱动决定了CUDA版本,CUDA版本决定了CUDA Toolkit、cuDNN、Pytorch各自的版本提前下载 | CUDA提前下载 &am…

超自动化助力企业财务转型升级

在快节奏的财务规划与分析环境中,传统的预算方法虽长期以来一直是企业制定有效决策的支柱,但已不足以驾驭当今复杂的商业环境。不断的经济变化、市场的不确定性以及利益相关者的需求增加促使企业寻求更敏捷的解决方案。如今,部分企业开始尝试…

lc 140. 单词拆分 II

回溯算法查询匹配单词 class Solution { public:unordered_map<string, int> word_map;void mapping(vector<string>& wordDict){for(auto &a : wordDict)word_map[a];}vector<string> ret;// s: 原始字符串// tmp: 已查询到的单词// …

CSS 彩虹按钮效果

<template><view class"content"><button class"btn">彩虹按钮</button></view> </template><script></script><style>body{background-color: #000;}.content {margin-top: 300px;}.btn {width: 1…