【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

本小节主要实现了以下几部分内容:

  • 从一个句子中提取BERT输入序列以及相对的segments段落索引(因为BERT支持输入两个句子)
  • BERT使用的是Transformer的Encoder部分,所以需要需要使用Encoder进行前向传播:输出的特征等于词嵌入+位置编码+Encoder块
  • 用于BERT预训练时预测的掩蔽语言模型任务中的掩蔽标记< mask >
  • 用于预训练任务的下一个句子的预测——在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
  • 通过BERTModel整合代码

"""可学习的位置编码也需要进行初始化"""
import torch
import d2l.torch
from torch import nn
import transformers
"""将一个句子或者两个句子作为输入,然后返回BERT输入序列及其相应的序列对的片段索引segments"""
def get_tokens_segments(tokens_a,tokens_b=None):"""获取输入序列的词元及其片段索引"""tokens = ['<cls>'] + tokens_a + ['<sep>']# 利用0和1分别标记片段A和片段Bsegments = [0] * (len(tokens_a)+2)  #加上<cls>和sepif tokens_b is not None:# 如果是句子对tokens += tokens_b+['<sep>']segments += [1]*(len(tokens_b)+1)  # 加上<sep>return tokens,segments"""在原始的Transformer架构中,编码器的位置嵌入信息是直接加到了输入序列的每个位置,但是BERT使用的是可学习的位置嵌入"""
"""bert-input = tokens_embedding + position_embedding + segment_embedding"""
class BERTEncoder(nn.Module):"""BERT编码器"""def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,use_bias=True):super(BERTEncoder, self).__init__()self.token_embedding = nn.Embedding(vocab_size,num_hiddens)self.segment_embedding = nn.Embedding(2,num_hiddens)# 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入的参数self.pos_embedding = nn.Parameter(torch.randn(size=(1,max_len,num_hiddens)))# print('self.pos_embedding:',self.pos_embedding)"""self.pos_embedding.data : [1,1000,768]在下面与X相加时利用的是广播机制"""self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module(f'{i}',d2l.torch.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias))def forward(self,tokens,segments,valid_lens):# 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)X = self.token_embedding(tokens)+self.segment_embedding(segments)print('X.shape:',X.shape)   # [2,8,768]X += self.pos_embedding.data[:,:X.shape[1],:]  #[2,8,768]for blk in self.blks:X = blk(X,valid_lens)return X
"""演示BERTEncoder的前向传播--->词表大小:10000"""
vocab_size,num_hiddens,ffn_num_input,ffn_num_hiddens,num_heads,num_layers = 1000,768,768,1024,4,2
norm_shape,dropout = [768],0.2
encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
"""将tokens定义为长度为8的2个输入序列"""
tokens = torch.randint(0,vocab_size,(2,8))
print('tokens:',tokens)
print('tokens_shape:',tokens.shape)
"""其中每个词元由向量表示,其长度由超参数num_hiddens定义,此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)"""
segments = torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,1,1,1,1,1]])
print('segments:',segments)
enc_outputs = encoder(tokens,segments,None)
print('enc_outputs.shape',enc_outputs.shape)# 预训练任务---》双向编码上下文:掩蔽语言模型
"""预测BERT预训练的掩蔽语言模型任务中的掩蔽标记"""
#@save
class MaskLM(nn.Module):"""BERT的掩蔽语言模型任务"""def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):super(MaskLM, self).__init__(**kwargs)# 两层的MLP,同时使用激活函数ReLU  和 层归一化self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),nn.ReLU(),nn.LayerNorm(num_hiddens),nn.Linear(num_hiddens, vocab_size))# 前向传播时的输入信息包括:# 1 BERTEncoder编码结果# 2 用于预测词元的位置def forward(self, X, pred_positions):num_pred_positions = pred_positions.shape[1]# 将预测的位置压缩成一维向量空间pred_positions = pred_positions.reshape(-1)# BERTEncoder的输出特征形状:[batch_size,...]batch_size = X.shape[0]batch_idx = torch.arange(0, batch_size)# 假设batch_size=2,num_pred_positions=3# 那么batch_idx是np.array([0,0,0,1,1,1])# torch.repeat_interleave用于重复张量元素batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)print('输入的X形状:',X.shape)# batch_idx# pred_positions# 都是两个list其中batch_idx选择的是屏蔽的行# pred_positions选择的是屏蔽的列masked_X = X[batch_idx, pred_positions]print('masked后X的内容:',masked_X)# 最后把所有要屏蔽的数据拉成一个一维的向量masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))mlm_Y_hat = self.mlp(masked_X)# 最后返回的是利用MLP预测这些位置的结果return mlm_Y_hat"""将mlm_positions定义为在encoded_X的任一输如系列中预测3个值"""
"""而且对于每一个预测的结果都等于词表的大小"""
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(enc_outputs, mlm_positions)
mlm_Y_hat_shape = mlm_Y_hat.shape
print('mlm_Y_hat_shape:',mlm_Y_hat_shape)# 通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l_shape = mlm_l.shape
print('mlm_l_shape:',mlm_l_shape)# 预训练任务---》下一个句子的预测
"""在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
"""
#@save
class NextSentencePred(nn.Module):"""BERT的下一句预测任务"""def __init__(self, num_inputs, **kwargs):super(NextSentencePred, self).__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)def forward(self, X):# X的形状:(batchsize,num_hiddens)return self.output(X)
"""NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测"""
enc_outputs = torch.flatten(enc_outputs, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(enc_outputs.shape[-1])
nsp_Y_hat = nsp(enc_outputs)
print('nsp_Y_hat.shape',nsp_Y_hat.shape)
# 计算两个二元分类的交叉熵损失
nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l_shape = nsp_l.shape
print('nsp_l_shape:',nsp_l_shape)#@save
class BERTModel(nn.Module):"""BERT模型"""def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768,hid_in_features=768, mlm_in_features=768,nsp_in_features=768):super(BERTModel, self).__init__()self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,ffn_num_input, ffn_num_hiddens, num_heads, num_layers,dropout, max_len=max_len, key_size=key_size,query_size=query_size, value_size=value_size)self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),nn.Tanh())self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, segments, valid_lens=None,pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/103671.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【TA 工具积累】参考图展示 PureRef | 截图 Snipaste

贴两个平常看图和截图比较方便的工具&#xff1a; PureRef 官网指路&#xff1a;PureRef 油管简单的使用教程视频&#xff1a;Free Download | PureRef 知乎上大佬总结的快捷键&#xff1a; PureRef 快捷键 提炼总结 - 知乎 (zhihu.com) b站大佬总结的快捷键&#xff1a;…

一文告知HTTP GET是否可以有请求体

HTTP GET是否可以有请求体 先说结论&#xff1a; HTTP协议没有规定GET请求不能携带请求体&#xff0c;但是部分浏览器会不支持&#xff0c;因此不建议GET请求携带请求体。 HTTP 协议没有为 GET 请求的 body 赋予语义&#xff0c;也就是即不要求也不禁止 GET 请求带 body。大多数…

Hazelcast系列(八):数据结构

系列文章 Hazelcast系列(一)&#xff1a;初识hazelcast Hazelcast系列(二)&#xff1a;hazelcast集成&#xff08;嵌入式&#xff09; Hazelcast系列(三)&#xff1a;hazelcast集成&#xff08;服务器/客户端&#xff09; Hazelcast系列(四)&#xff1a;hazelcast管理中心 …

Idea集成Docker

1、前言 上一节中&#xff0c;我们介绍了Dockerfile的方式构建自己的镜像。但是在实际开发过程中&#xff0c;一般都会和开发工具直接集成&#xff0c;如Idea。今天就介绍下idea和Docker如何集成。 2、开启docker远程 要集成之前&#xff0c;需要我们本机能够访问docker服务…

在liunx下读取串口的数据

1. 设置串口参数 首先是通过stty工具设置串口参数&#xff1a; sudo stty -F /dev/ttyUSB0 比特率 cs8 -cstopb如&#xff1a;sudo stty -F /dev/ttyUSB0 115200 cs8 -cstopb. 注意&#xff1a; 需要注意的是这里需要sudo权限&#xff1b; 2. 读取串口数据 然后读取串口的…

vue3 组件化的优势

Vue3是一款前端框架&#xff0c;其最大的特点是支持组件化开发。组件化开发可以将页面拆分成多个模块&#xff0c;每个模块都是一个独立的组件&#xff0c;便于开发和维护。 Vue3是一款全新的前端框架&#xff0c;相比于Vue2&#xff0c;它有很多优势&#xff0c;包括以下几个…

基于LoRa的远程气象站:实现远程气象监测与数据传输

随着物联网技术的不断发展&#xff0c;基于无线通信的远程气象监测系统得以广泛应用。本文将介绍一种基于LoRa技术的远程气象站&#xff0c;通过LoRa模块实现气象数据的远程采集和传输&#xff0c;为气象监测提供了一种高效、低功耗的解决方案。 LoRa技术概述 LoRa&#xff08…

ssm+vue的课程网络学习平台管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频&#xff1a; ssmvue的课程网络学习平台管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;ssm vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体…

《从程序员到架构师》:从现在开始培养架构思维,一点都不晚

《从程序员到架构师》&#xff1a;从现在开始培养架构思维&#xff0c;一点都不晚 尽管大家都明白软件架构非常重要&#xff0c;但是能够真正理解并应用软件架构的核心思维去解决实战的商业项目&#xff0c;确实大多数程序员所欠缺的。本文将从一个全新的视角&#xff0c;重新带…

【算法练习Day19】二叉搜索树的最近公共祖先二叉搜索树中的插入操作删除二叉搜索树中的节点

​&#x1f4dd;个人主页&#xff1a;Sherry的成长之路 &#x1f3e0;学习社区&#xff1a;Sherry的成长之路&#xff08;个人社区&#xff09; &#x1f4d6;专栏链接&#xff1a;练题 &#x1f3af;长路漫漫浩浩&#xff0c;万事皆有期待 文章目录 二叉搜索树的最近公共祖先叉…

Avalonia常用小控件Svg

1.项目下载地址&#xff1a;https://gitee.com/confusedkitten/avalonia-demo 2.UI库Semi.Avalonia&#xff0c;项目地址 https://github.com/irihitech/Semi.Avalonia 3.SVG库&#xff0c;Avalonia.Svg.Skia&#xff0c;项目地址 https://github.com/wieslawsoltes/Svg.Ski…

【数据库——MySQL(实战项目1)】(4)图书借阅系统——触发器

目录 1. 简述2. 功能代码2.1 创建两个触发器&#xff0c;分别在借出或归还图书时&#xff0c;修改借阅人表中的已借数目(附加&#xff1a;借阅人表的总借书数、图书表的借阅次数以及更新图书表的图书状态为(已借出/在架上))字段&#xff1b;2.2 创建触发器&#xff0c;当借阅者…

redis 哨兵 sentinel(一)配置

sentinel巡查监控后台master主机是否故障&#xff0c;如果故障根据投票数自动将某一个从库转换为新主库&#xff0c;继续对外服务 sentinel 哨兵的功能 监控 监控主从redis库运行是否正常消息通知 哨兵可以将故障转移的结果发送给客户端故障转移 如果master异常&#xff0c;则…

IOS17 轻松签全能签还能不能用?多开能否使用?升级后微信底栏消失怎么办?BY:后厂村路灯

从iphone15还没出就有小伙伴们追着问&#xff0c; 到现在也有人一直再问iOS17能不能用&#xff0c;看来换手机的人很多呀。 这里统一回答一下&#xff1a;“iOS17苹果签名可以用&#xff0c;多开也可以用”但是还是有些地方注意。 如果你是16系统直接升级刀17就可以&#xff…

使用 Splashtop 驾驭未来媒体和娱乐

在当今时代&#xff0c;数字转型不再是可选项&#xff0c;而是必选项。如今&#xff0c;媒体与娱乐业处于关键时刻&#xff0c;正在错综复杂的创意、技术和远程协作迷宫之中摸索前进。过去几年发生的全球事件影响了我们的日常生活&#xff0c;不可逆转地改变了行业的运作方式&a…

给你一个项目,你将如何开展性能测试工作?

一、性能三连问 1、何时进行性能测试&#xff1f; 性能测试的工作是基于系统功能已经完备或者已经趋于完备之上的&#xff0c;在功能还不够完备的情况下没有多大的意义。因为后期功能完善上会对系统的性能有影响&#xff0c;过早进入性能测试会出现测试结果不准确、浪费测试资…

【C++STL基础入门】list基本使用

文章目录 前言一、list简介1.1 list是什么1.2 list的头文件 二、list2.1 定义对象2.2 list构造函数2.3 list的属性函数 总结 前言 STL&#xff08;Standard Template Library&#xff09;是C标准库的一个重要组成部分&#xff0c;提供了一套丰富的数据结构和算法&#xff0c;可…

【Lombok的Bug记录】前端传的有值,但是到后端就全为空了

项目场景&#xff1a; 项目背景&#xff1a;使用Data注解标注类 问题描述 前端传的有值&#xff0c;但是到后端就全为空了 原因分析&#xff1a; AName和aName生成的set方法名是一样的&#xff0c;所以换名字就行了&#xff01; 解决方案&#xff1a; 属性不要写成xXxx的形式…

【伪彩色图像处理】将灰度图像转换为彩色图像研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

YB4014是可以对单节磷酸铁锂电池进行恒流/恒压充电管理的集成电路。

概述&#xff1a; YB4014是可以对单节磷酸铁锂电池进行恒流/恒 压充电管理的集成电路。该器件内部包括功率晶 体管&#xff0c;不需要外部的电流检测电阻和阻流二极管 YB4014只需要极少的外围元器件&#xff0c;非常适合于 便携式应用的领域。热调制电路可以在器件的功 耗比较大…