Transformer模型:WordEmbedding实现

前言

        最近在学Transformer,学了理论的部分之后就开始学代码的实现,这里是跟着b站的up主的视频记的笔记,视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili


正文

        首先导入所需要的包:

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

         关于Word Embedding,这里以序列建模为例,考虑source sentence、target sentence,构建序列,序列的字符以其在词表中的索引的形式表示。

        首先使用定义batch_size的大小,并且使用torch.randint()函数随机生成序列长度,这里的src是生成原本的序列,tgt是生成目标的序列。

        以机器翻译实现英文翻译为中文来说,src就是英文句子,tgt就是中文句子,这也就是规定了要翻译的英文句子的长度和翻译出来的句子长度。(举个例子而已,不用纠结为什么翻译要限制句子的长度)

batch_size = 2src_len=torch.randint(2,5,(batch_size,))
tgt_len=torch.randint(2,5,(batch_size,))

        将生成的src_len、tgt_len输出:

tensor([2, 3])    生成的原序列第一个句子长度为2,第二个句子长度为3
tensor([4, 4])    生成的目标序列第一个句子长度为4,第二个句子长度为4

        因为随机生成的,所以每次运行都会有新的结果,也就是生成的src和tgt两个序列,其子句的长度每次都是随机的,这里改成生成固定长度的序列:

src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)

         将生成的src_len、tgt_len输出,此时就固定好了序列长度了:

tensor([11,  9], dtype=torch.int32)
tensor([10, 11], dtype=torch.int32)

        接着是要实现单词索引构成的句子,首先定义单词表的大小和序列的最大长度。

# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12

         以生成原序列为例,使用torch.randint()生成第一个句子和第二个句子,然后放到列表中:

src_seq = [torch.randint(1, max_num_src_words, (L,)) for L in src_len]
[tensor([5, 3, 7, 5, 6, 3, 4, 3]), tensor([1, 6, 3, 1, 1, 7, 4])]

         可以发现生成的两个序列长度不一样(因为我们自己定义的时候就是不一样的),在这里需要使用F.pad()函数进行padding保证序列长度一致:

src_seq = [F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len-L)) for L in src_len]
[tensor([8, 5, 2, 4, 6, 8, 1, 4, 0, 0, 0, 0]), tensor([5, 5, 5, 3, 7, 9, 3, 0, 0, 0, 0, 0])]

         此时已经填充为同样的长度了,但是不同的句子各为一个张量,需要使用torch.cat()函数把不同句子的tensor转化为二维的tensor,在此之前需要先把每个张量变成二维的,使用torch.unsqueeze()函数:

src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tensor([[9, 7, 7, 4, 7, 3, 9, 4, 7, 8, 8, 0],[1, 1, 5, 9, 5, 6, 2, 7, 4, 0, 0, 0]])
tensor([[3, 3, 2, 8, 3, 4, 1, 2, 9, 4, 0, 0],[1, 6, 3, 8, 5, 1, 5, 5, 1, 5, 3, 0]])

         这里把tgt的也补充了,得到的就是src和tgt的内容各自在一个二维张量里(batch_size,max_seg_len),batch_size也就是句子数,max_seg_len也就是句子的单词数(分为src的长度跟tgt两种)。

        补充:可以看到上面三次运行出来的结果都不一样,因为三次运行的时候,每次都是随机生成,所以结果肯定不一样,第三次为什么有两个二维的tensor是因为第三次把tgt的部分也补上去了,所以就有两个二维的tensor。

        接下来就是构造embedding了,这里nn.Embedding()传入了两个参数,第一个是embedding的长度,也就是单词个数+1,+1的原因是因为有个0是作为填充的,第二个参数就是embedding的维度,也就是一个单词会被映射为多少维度的向量。

        然后调用forward,得到我们的src和tgt的embedding

src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)   
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)    # 每一行代表一个embedding向量,第0行让给pad,从第1行到第行分配给各个单词,单词的索引是多少就取对应的行位置的向量
print(src_embedding)    # 根据src_seq,从src_embedding_table获取得到的embedding vector,三维张量:batch_size、max_seq_len、model_dim
print(tgt_embedding)

        此时src_embedding_table.weight的输出内容如下,第一行为填充(0)的向量:

tensor([[-0.3412,  1.5198, -1.7252,  0.6905, -0.3832, -0.8586, -2.0788,  0.3269],
        [-0.5613,  0.3953,  1.6818, -2.0385,  1.1072,  0.2145, -0.9349, -0.7091],
        [ 1.5881, -0.2389, -0.0347,  0.3808,  0.5261,  0.7253,  0.8557, -1.0020],
        [-0.2725,  1.3238, -0.4087,  1.0758,  0.5321, -0.3466, -0.9051, -0.8938],
        [-1.5393,  0.4966, -1.4887,  0.2795, -1.6751, -0.8635, -0.4689, -0.0827],
        [ 0.6798,  0.1168, -0.5410,  0.5363, -0.0503,  0.4518, -0.3134, -0.6160],
        [-1.1223,  0.3817, -0.6903,  0.0479, -0.6894,  0.7666,  0.9695, -1.0962],
        [ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],
        [-0.9118, -1.4221, -2.4675, -0.1321,  0.7458, -0.8015,  0.5114, -0.5023],
        [-1.7504,  0.0824,  2.2088, -0.4486,  0.7324,  1.8790,  1.7644,  1.2731],
        [-0.3791,  1.9915, -1.0117,  0.8238, -2.1784, -1.2824, -0.4275,  0.3202]],
       requires_grad=True)

        src_embedding的输出结果如下所示,往前看src_seq的第一个句子前三个为9  7  7,往前看第9+1行与第7+1行的向量,就是现在输出的前3个向量:

tensor([[[-1.7504,  0.0824,  2.2088, -0.4486,  0.7324,  1.8790,  1.7644, 1.2731],[ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],[ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],[-1.5393,  0.4966, -1.4887,  0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],[-0.2725,  1.3238, -0.4087,  1.0758,  0.5321, -0.3466, -0.9051, -0.8938],[-1.7504,  0.0824,  2.2088, -0.4486,  0.7324,  1.8790,  1.7644, 1.2731],[-1.5393,  0.4966, -1.4887,  0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],[-0.9118, -1.4221, -2.4675, -0.1321,  0.7458, -0.8015,  0.5114, -0.5023],[-0.9118, -1.4221, -2.4675, -0.1321,  0.7458, -0.8015,  0.5114, -0.5023],[-0.3412,  1.5198, -1.7252,  0.6905, -0.3832, -0.8586, -2.0788, 0.3269]],[[-0.5613,  0.3953,  1.6818, -2.0385,  1.1072,  0.2145, -0.9349, -0.7091],[-0.5613,  0.3953,  1.6818, -2.0385,  1.1072,  0.2145, -0.9349, -0.7091],[ 0.6798,  0.1168, -0.5410,  0.5363, -0.0503,  0.4518, -0.3134, -0.6160],[-1.7504,  0.0824,  2.2088, -0.4486,  0.7324,  1.8790,  1.7644, 1.2731],[ 0.6798,  0.1168, -0.5410,  0.5363, -0.0503,  0.4518, -0.3134, -0.6160],[-1.1223,  0.3817, -0.6903,  0.0479, -0.6894,  0.7666,  0.9695, -1.0962],[ 1.5881, -0.2389, -0.0347,  0.3808,  0.5261,  0.7253,  0.8557, -1.0020],[ 0.9608,  0.0764,  0.0914,  1.1949, -1.3853,  1.1089, -0.9282, -0.9793],[-1.5393,  0.4966, -1.4887,  0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[-0.3412,  1.5198, -1.7252,  0.6905, -0.3832, -0.8586, -2.0788, 0.3269],[-0.3412,  1.5198, -1.7252,  0.6905, -0.3832, -0.8586, -2.0788, 0.3269],[-0.3412,  1.5198, -1.7252,  0.6905, -0.3832, -0.8586, -2.0788, 0.3269]]], grad_fn=<EmbeddingBackward>)

         同理tgt_embedding的输出结果如下所示:

tensor([[[-1.3681, -0.1619, -0.3676,  0.4312, -1.3842, -0.6180,  0.3685, 1.6281],[-1.3681, -0.1619, -0.3676,  0.4312, -1.3842, -0.6180,  0.3685, 1.6281],[-2.6519, -0.8566,  1.2268,  2.6479, -0.2011, -0.1394, -0.2449, 1.0309],[-0.8919,  0.5235, -3.1833,  0.9388, -0.6213, -0.5146,  0.7913, 0.5126],[-1.3681, -0.1619, -0.3676,  0.4312, -1.3842, -0.6180,  0.3685, 1.6281],[-0.4984,  0.2948, -0.2804, -1.1943, -0.4495,  0.3793, -0.1562, -1.0122],[ 0.8976,  0.5226,  0.0286,  0.1434, -0.2600, -0.7661,  0.1225, -0.7869],[-2.6519, -0.8566,  1.2268,  2.6479, -0.2011, -0.1394, -0.2449, 1.0309],[ 2.2026,  1.8504, -0.6285, -0.0996, -0.0994, -0.0828,  0.6004, -0.3173],[-0.4984,  0.2948, -0.2804, -1.1943, -0.4495,  0.3793, -0.1562, -1.0122],[ 0.3637,  0.4256,  0.7674,  1.4321, -0.1164, -0.6032, -0.8182, -0.6119],[ 0.3637,  0.4256,  0.7674,  1.4321, -0.1164, -0.6032, -0.8182, -0.6119]],[[ 0.8976,  0.5226,  0.0286,  0.1434, -0.2600, -0.7661,  0.1225, -0.7869],[-1.0356,  0.8212,  1.0538,  0.4510,  0.2734,  0.3254,  0.4503, 0.1694],[-1.3681, -0.1619, -0.3676,  0.4312, -1.3842, -0.6180,  0.3685, 1.6281],[-0.8919,  0.5235, -3.1833,  0.9388, -0.6213, -0.5146,  0.7913, 0.5126],[-0.4783, -1.5936,  0.5033,  0.3483, -1.3354,  1.4553, -1.1344, -1.9280],[ 0.8976,  0.5226,  0.0286,  0.1434, -0.2600, -0.7661,  0.1225, -0.7869],[-0.4783, -1.5936,  0.5033,  0.3483, -1.3354,  1.4553, -1.1344, -1.9280],[-0.4783, -1.5936,  0.5033,  0.3483, -1.3354,  1.4553, -1.1344, -1.9280],[ 0.8976,  0.5226,  0.0286,  0.1434, -0.2600, -0.7661,  0.1225, -0.7869],[-0.4783, -1.5936,  0.5033,  0.3483, -1.3354,  1.4553, -1.1344, -1.9280],[-1.3681, -0.1619, -0.3676,  0.4312, -1.3842, -0.6180,  0.3685, 1.6281],[ 0.3637,  0.4256,  0.7674,  1.4321, -0.1164, -0.6032, -0.8182, -0.6119]]], grad_fn=<EmbeddingBackward>)

        实际想要把文本句子嵌入到Embedding中,需要先根据自己的词典,将文本信息转化为每个词在词典中的位置,然后第0个位置依旧要让给Padding,得到索引然后构建Batch再去构造Embedding,以索引为输入得到每个样本的Embedding。 

代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F# 句子数
batch_size = 2# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12# 模型的维度
model_dim = 8# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)# 构造embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)  
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)    
print(src_embedding)    
print(tgt_embedding)

参考

 torch.randint — PyTorch 2.3 documentation

torch.nn.functional.pad — PyTorch 2.3 文档

F.pad 的理解_domain:luyixian.cn-CSDN博客

嵌入 — PyTorch 2.3 文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/871044.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

useRef和useState的区别

在React中&#xff0c;useRef和useState都是Hooks&#xff0c;它们用于在函数组件中添加React状态&#xff0c;但它们的用途和行为有所不同&#xff1a; useState useState用于在函数组件中添加可变状态。它让你能够保存和管理随时间变化的数据。它返回一个数组&#xff0c;包…

离线安装docker-compse

离线安装 Docker Compose 可以通过以下步骤完成&#xff1a; 下载 Docker Compose 二进制文件&#xff1a; 首先&#xff0c;你需要在一个可以访问互联网的机器上下载 Docker Compose 的二进制文件。你可以使用以下命令来下载&#xff1a; sudo curl -L "https://github.c…

云WAF在电子商务领域具体能提供哪些安全功能?

云WAF&#xff08;Cloud Web Application Firewall&#xff09;在电子商务领域提供了一系列关键的安全功能&#xff0c;以保护在线交易平台免受各种网络攻击和威胁。以下是云WAF能够提供的具体安全功能&#xff1a; 实时流量监控与分析&#xff1a;云WAF能够对电子商务网站的流…

Matlab结合ChatGPT—如何计算置信区间?

​前面分享了带置信区间的折线图和带置信区间的折线散点图的绘图教程&#xff1a; 很多人表示&#xff0c;昆哥&#xff0c;图是很好看啦&#xff0c;但咱不会求置信区间啊&#xff0c;咋办嘞&#xff1f; 说实话&#xff0c;这种事情属于数据处理&#xff0c;一般都是在画图前…

家政服务小程序:提高家政服务,新商机!

当下&#xff0c;社会生活的节奏非常快&#xff0c;人们忙于工作&#xff0c;在日常生活家务清洁中面临着时间、精力不足的问题&#xff0c;因此对家政服务的需求日益增加&#xff0c;这也推动了家政行业的迅速发展。目前不少年轻人都开始涌入到了家政行业中&#xff0c;市场的…

HTTP协议。(HTTP-概述和特点、HTTP-请求协议、HTTP-请求数据格式、浏览器访问服务器的几种方式)

2.1 HTTP-概述 HTTP协议又分为&#xff1a;请求协议和响应协议 请求协议&#xff1a;浏览器将数据以请求格式发送到服务器 包括&#xff1a;请求行、请求头 、请求体 响应协议&#xff1a;服务器将数据以响应格式返回给浏览器 包括&#xff1a;响应行 、响应头 、响应体 2.…

重要!!!MySQL 9.0存在重大BUG!!

7/11日开源数据库软件服务商percona发布重要警告&#xff0c;最新的mysql版本存在重大bug&#xff0c;原文如下 Do Not Upgrade to Any Version of MySQL After 8.0.37 Warning! Recently, Jean-Franois Gagn opened a bug on bug.mysql.com #115517; unfortunately, the bug…

CT金属伪影去除的去噪扩散概率模型| 文献速递-基于深度学习的多模态数据分析与生存分析

Title 题目 A denoising diffusion probabilistic model for metal artifact reduction in CT CT金属伪影去除的去噪扩散概率模型 01 文献速递介绍 CT图像中的金属伪影是在CT扫描视野内存在金属物体&#xff08;如牙科填充物、骨科假体、支架、手术器械等&#xff09;时出…

探索Java网络编程精髓:UDP与TCP的实战魔法!

Java 中提供了专门的网络编程程序包 java.net&#xff0c;提供了两种通信协议&#xff1a;UDP&#xff08;数据报协议&#xff09;和 TCP&#xff08;传输控制协议&#xff09;&#xff0c;本文对两种通信协议的开发进行详细介绍。 1 UDP 介绍 UDP&#xff1a;User Datagram Pr…

css横向滚动条支持鼠标滚轮

在做视频会议的时候&#xff0c;标准模式视图会有顶部收缩的一种交互方式&#xff0c;用到了横向滚动&#xff1b;一般情况下鼠标滚轮只支持竖向滚动&#xff0c;这次写个demo是适配横向滚动&#xff1b; 效果图展示 实现横向滚动条顶部显示 <div className{style.remote_u…

已知经纬度坐标,评价数据空间分布均匀性

文章目录 基本介绍1. 可视化分析使用Python的matplotlib和Basemap库&#xff1a; 2. 统计检验使用Python的scipy库进行Kolmogorov-Smirnov检验&#xff1a; 3. 空间分析技术使用Python的geopandas和sklearn库进行核密度估计&#xff1a; 调用函数1. 可视化分析函数2. 统计检验函…

如何在Linux系统下安装Anaconda

安装步骤 一、在Linux服务器下获取Anaconda安装包二、启动Anaconda安装程序三、修改PATH环境变量四、验证Anaconda是否安装成功 最近课题组实验室又新购了两台服务器&#xff0c;需要重新部署深度学习环境才能使用&#xff0c;但我突然发现自己不太记得Anaconda具体的安装过程了…

【YOLO格式的数据标签,目标检测】

标签为 YOLO 格式&#xff0c;每幅图像一个 *.txt 文件&#xff08;如果图像中没有对象&#xff0c;则不需要 *.txt 文件&#xff09;。*.txt 文件规格如下: 每个对象一行 每一行都是 class x_center y_center width height 格式。 边框坐标必须是 归一化的 xywh 格式&#x…

nginx正向代理和反向代理

nginx正向代理和反向代理 正向代理以及缓存配置 代理&#xff1a;客户端不再是直接访问服务器&#xff0c;通过代理服务器访问服务端。 正向代理&#xff1a;面向客户端&#xff0c;我们通过代理服务器的IP地址访问目标服务端。 服务端只知道代理服务器的地址&#xff0c;真…

CRC32简述

CRC32简述 crc32 通常指的是 CRC-32&#xff08;Cyclic Redundancy Check 32-bit,即循环冗余检查&#xff09;算法&#xff0c;而 foobar 是一个示例字符串&#xff0c;用来作为 CRC-32 算法的输入。CRC-32 是一种广泛使用的循环冗余校验&#xff08;CRC&#xff09;算法&#…

面试题 21. 调整数组顺序使奇数位于偶数前面

调整数组顺序使奇数位于偶数前面 题目描述示例 题解 题目描述 输入一个整数数组&#xff0c;实现一个函数来调整该数组中数字的顺序&#xff0c;使得所有奇数在数组的前半部分&#xff0c;所有偶数在数组的后半部分。 示例 输入&#xff1a;nums [1,2,3,4] 输出&#xff1a;…

每日一练 - OSPF邻居关系建立故障排查

01 真题题目 OSPF邻居关系建立出现故障&#xff0c;通过display ospf error命令查看&#xff0c;显示如下信息&#xff0c;则邻居建立失败的原因可能是&#xff1a; A. Router ID冲突 B.区域ID不匹配 C.网络掩码不一致 D.MTU不一致 02 真题答案 B 03 答案解析 从图片中可以…

爬虫学习日记

引言&#xff1a; 1.语言&#xff1a;python 2.预备知识——python&#xff1a;爬虫学习前记----Python-CSDN博客 3.学习资源&#xff1a;【Python爬虫】 html&#xff1a; <!DOCTYPE html> <html><head><title>czy_demo</title><meta c…

数据丢失?不存在的!

今年3月份&#xff0c;AT&T遭遇了严重的数据泄露事件&#xff0c;导致7300万客户账户信息被泄露。泄露的信息包括客户的姓名、电话号码、邮寄地址等敏感资料&#xff0c;甚至部分客户的加密密码也被泄露&#xff0c;使得约760万AT&T用户的账户面临被劫持的风险。 此次…

android inflate 参数含义

在Android开发中&#xff0c;inflate 方法用于将 XML 布局文件转换为相应的 View 对象。在调用 inflate 方法时&#xff0c;有几个参数需要特别注意&#xff1a; resource (int resId): 布局资源文件的ID。通常是通过 R.layout.layout_name 这种形式指定的。 root (ViewGroup …