PyTorch中的nn.Embedding应用详解

PyTorch


文章目录

  • PyTorch
  • 前言
  • 一、nn.Embedding的基本原理
  • 二、nn.Embedding的实际应用
    • 简单的例子
    • 自然语言处理任务


前言

在深度学习中,词嵌入(Word Embedding)是一种常见的技术,用于将离散的词汇或符号映射到连续的向量空间。这种映射使得相似的词汇在向量空间中具有相似的向量表示,从而可以捕捉词汇之间的语义关系。在PyTorch中,nn.Embedding模块提供了一种简单而高效的方式来实现词嵌入。

一、nn.Embedding的基本原理

nn.Embedding是一个存储固定大小的词典的嵌入向量的查找表。给定一个编号,嵌入层能够返回该编号对应的嵌入向量。这些嵌入向量反映了各个编号代表的符号之间的语义关系。在输入一个编号列表时,nn.Embedding会输出对应的符号嵌入向量列表。

在内部,nn.Embedding实际上是一个参数化的查找表,其中每一行都对应一个符号的嵌入向量。这些嵌入向量在训练过程中通过反向传播算法进行更新,以优化模型的性能。因此,nn.Embedding不仅可以用于降低数据的维度,减少计算和存储开销,还可以通过训练学习输入数据中的语义或结构信息。

二、nn.Embedding的实际应用

简单的例子

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 3)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)# print(emb_vec1)  ### 输出对同一组词汇的编码output = torch.einsum('ik, kj -> ij', emb_vec1, vec)return output
def simple_train():model = Model()vec = torch.randn((3, 1))label = torch.Tensor(5, 1).fill_(3)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)print('初始化emebding参数权重:\n',model.emb.weight)for iter_num in range(100):output = model(vec)loss = loss_fun(output, label)opt.zero_grad()loss.backward(retain_graph=True)opt.step()# print('第{}次迭代emebding参数权重{}:\n'.format(iter_num, model.emb.weight))print('训练后emebding参数权重:\n',model.emb.weight)torch.save(model.state_dict(),'./embeding.pth')return modeldef simple_test():model = Model()ckpt = torch.load('./embeding.pth')model.load_state_dict(ckpt)model=model.eval()vec = torch.randn((3, 1))print('加载emebding参数权重:\n', model.emb.weight)for iter_num in range(100):output = model(vec)print('n次预测后emebding参数权重:\n', model.emb.weight)if __name__ == '__main__':simple_train()  # 训练与保存权重simple_test()

训练代码

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 10)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)print(emb_vec1)  ### 输出对同一组词汇的编码output = torch.einsum('ik, kj -> ij', emb_vec1, vec)print(output)return outputdef simple_train():model = Model()vec = torch.randn((10, 1))label = torch.Tensor(5, 1).fill_(3)print(label)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)for iter_num in range(1):output = model(vec)loss = loss_fun(output, label)print('iter:%d loss:%.2f' % (iter_num, loss))opt.zero_grad()loss.backward(retain_graph=True)opt.step()if __name__ == '__main__':simple_train()

自然语言处理任务

在自然语言处理任务中,词嵌入是一种非常有用的技术。通过将每个单词表示为一个实数向量,我们可以将高维的词汇空间映射到一个低维的连续向量空间。这有助于提高模型的泛化能力和计算效率。例如,在文本分类任务中,我们可以使用nn.Embedding将文本中的每个单词转换为嵌入向量,然后将这些向量输入到神经网络中进行分类。

以下是一个简单的示例代码,演示了如何在PyTorch中使用nn.Embedding进行文本分类:

import torch
import torch.nn as nn
# 定义词嵌入层,词典大小为10000,嵌入向量维度为128
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=128)
# 假设我们有一个包含5个单词的文本,每个单词的编号分别为1, 2, 3, 4, 5
input_ids = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
# 通过词嵌入层将单词编号转换为嵌入向量
embedded = embedding(input_ids)
# 输出嵌入向量的形状:(5, 128)
print(embedded.shape)
# 定义神经网络模型
class TextClassifier(nn.Module):def __init__(self, embedding_dim, hidden_dim, num_classes):super(TextClassifier, self).__init__()self.embedding = embeddingself.fc1 = nn.Linear(embedding_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, num_classes)def forward(self, input_ids):embedded = self.embedding(input_ids)# 对嵌入向量进行平均池化,得到一个固定长度的向量表示整个文本pooled = embedded.mean(dim=0)# 通过全连接层进行分类logits = self.fc2(self.fc1(pooled))return logits
# 实例化模型并进行训练...

上述代码中,我们首先定义了一个词嵌入层embedding,词典大小为10000,嵌入向量维度为128。然后,我们创建了一个包含5个单词的文本,每个单词的编号分别为1到5。通过调用embedding(input_ids),我们将单词编号转换为嵌入向量。最后,我们定义了一个文本分类器模型TextClassifier,其中包含了词嵌入层、全连接层等组件。在模型的前向传播过程中,我们首先对嵌入向量进行平均池化,得到一个固定长度的向量表示整个文本,然后通过全连接层进行分类。

除了自然语言处理任务外,nn.Embedding还可以用于图像处理任务。例如,在卷积神经网络(CNN)中,嵌入层可以将图像的像素值映射到一个高维的空间,从而更好地捕捉图像中的复杂特征和结构。这有助于提高模型的性能和泛化能力。

需要注意的是,在图像处理任务中,我们通常使用卷积层(nn.Conv2d)或像素嵌入层(nn.PixelEmbed)等模块来处理图像数据,而不是直接使用nn.Embedding。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/905375.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI 检测原创论文:技术迷思与教育本质的悖论思考

当高校将 AI 写作检测工具作为学术诚信的 "电子判官",一场由技术理性引发的教育异化正在悄然上演。GPT-4 检测工具将人类创作的论文误判为 AI 生成的概率高达 23%(斯坦福大学 2024 年研究数据),这种 "以 AI 制 AI&…

langchain4j集成QWen、Redis聊天记忆持久化

langchain4j实现聊天记忆默认是基于进程内存的方式,InMemoryChatMemoryStore是具体的实现了,是将聊天记录到一个map中,如果用户大的话,会造成内存溢出以及数据安全问题。位了解决这个问题 langchain4提供了ChatMemoryStore接口&am…

Tomcat 日志体系深度解析:从访问日志配置到错误日志分析的全链路指南

一、Tomcat 核心日志文件架构与核心功能 1. 三大基础日志文件对比(权威定义) 日志文件数据来源核心功能典型场景catalina.out标准输出 / 错误重定向包含 Tomcat 引擎日志与应用控制台输出(System.out/System.err)排查 Tomcat 启…

万物互联时代:ONVIF协议如何重构安防监控系统架构

前言 一、ONVIF协议是什么 ONVIF(Open Network Video Interface Forum,开放式网络视频接口论坛)是一种全球性的开放行业标准,由安讯士(AXIS)、博世(BOSCH)和索尼(SONY&…

leetcode - 双指针问题

文章目录 前言 题1 移动零: 思路: 参考代码: 题2 复写零: 思考: 参考代码: 题3 快乐数: 思考: 参考代码: 题4 盛最多水的容器: 思考:…

从概念表达到安全验证:智能驾驶功能迎来系统性规范

随着辅助驾驶事故频发,监管机制正在迅速补位。面对能力表达、使用责任、功能部署等方面的新要求,行业开始重估技术边界与验证能力,数字样机正成为企业合规落地的重要抓手。 2025年以来,围绕智能驾驶功能的争议不断升级。多起因辅…

java数组题(5)

(1): 思路: 1.首先要对数组nums排序,这样两数之间的差距最小。 2.题目要求我们通过最多 k 次递增操作,使数组中某个元素的频数(出现次数)最大化。经过上面的排序,最大数…

Python(1) 做一个随机数的游戏

有关变量的,其实就是 可以直接打印对应变量。 并且最后倒数第二行就是可以让两个数进行交换。 Py快捷键“ALTP 就是显示上一句的代码。 —————————————————————————————— 字符串 用 双引号或者单引号 。 然后 保证成双出现即可 要是…

【认知思维】验证性偏差:认知陷阱的识别与克服

什么是验证性偏差 验证性偏差(Confirmation Bias)是人类认知中最普遍、最根深蒂固的心理现象之一,指的是人们倾向于寻找、解释、偏爱和回忆那些能够确认自己已有信念或假设的信息,同时忽视或贬低与之相矛盾的证据。这种认知偏差影…

Wpf学习片段

IRegionManager 和IContainerExtension IRegionManager 是 Prism 框架中用于管理 UI 区域(Regions)的核心接口,它实现了模块化应用中视图(Views)的动态加载、导航和生命周期管理。 IContainerExtension 是依赖注入&…

消息~组件(群聊类型)ConcurrentHashMap发送

为什么选择ConcurrentHashMap? 在开发聊天应用时,我们需要存储和管理大量的聊天消息数据,这些数据会被多个线程频繁访问和修改。比如,当多个用户同时发送消息时,服务端需要同时处理这些消息的存储和查询。如果用普通的…

Stapi知识框架

一、Stapi 基础认知 1. 框架定位 自动化API开发框架:专注于快速生成RESTful API 约定优于配置:通过标准化约定减少样板代码 企业级应用支持:适合构建中大型API服务 代码生成导向:显著提升开发效率 2. 核心特性 自动CRUD端点…

基于深度学习的水果识别系统设计

一、选择YOLOv5s模型 YOLOv5:YOLOv5 是一个轻量级的目标检测模型,它在 YOLOv4 的基础上进行了进一步优化,使其在保持较高检测精度的同时,具有更快的推理速度。YOLOv5 的网络结构更加灵活,可以根据不同的需求选择不同大…

Spring Security与SaToken的对比

Spring Security与SaToken的详细对照与优缺点分析 1. 核心功能与设计理念 对比维度Spring SecuritySaToken核心定位企业级安全框架,深度集成Spring生态,提供全面的安全解决方案(认证、授权、攻击防护等)轻量级权限认证框架&#…

【docker】--镜像管理

文章目录 拉取镜像启动镜像为容器连接容器法一法二 保存镜像加载镜像镜像打标签移除镜像 拉取镜像 docker pull mysql:8.0.42启动镜像为容器 docker run -dp 8080:8080 --name container_mysql8.0.42 -e MYSQL_ROOT_PASSWORD123123123 mysql:8.0.42 连接容器 法一 docker e…

力扣HOT100之二叉树:543. 二叉树的直径

这道题本来想到可以用递归做,但是还是没想明白,最后还是去看灵神题解了,感觉这道题最大的收获就是巩固了我对lambda表达式的掌握。 按照灵神的思路,直径可以理解为从一个叶子出发向上,在某个节点处拐弯,然后…

web 自动化之 yaml 数据/日志/截图

文章目录 一、yaml 数据获取二、日志获取三、截图 一、yaml 数据获取 需要安装 PyYAML 库 import yaml import os from TestPOM.common import dir_config as Dirdef read_yaml(key,file_name"test_datas.yaml"):file_path os.path.join(Dir.testcases_dir, file_…

rtty操作记录说明

rtty操作记录说明 前言 整理资料发现了几年前做的操作记录,分享出来,希望对大家有用。 rtty-master:rtty客户端程序,其中buffer\log\ssl为源码的子目录,从git上下载https://github.com/zhaojh329, rtty…

mybatis中${}和#{}的区别

先测试&#xff0c;再说结论 userService.selectStudentByClssIds(10000, "wzh or 11");List<StudentEntity> selectStudentByClssIds(Param("stuId") int stuId, Param("field") String field);<select id"selectStudentByClssI…

【运维】MacOS蓝牙故障排查与修复指南

在日常使用macOS系统过程中&#xff0c;蓝牙连接问题时有发生。无论是无法连接设备、连接不稳定还是蓝牙功能完全失效&#xff0c;这些问题都会严重影响我们的工作效率。本文将分享一些实用的排查方法和修复技巧&#xff0c;帮助你解决macOS系统上的蓝牙故障。 问题症状 常见…