第TR3周:Pytorch复现Transformer

  •    🍨 本文为🔗365天深度学习训练营中的学习记录博客
  •    🍖 原作者:K同学啊

Transformer通过自注意力机制,改变了序列建模的方式,成为AI领域的基础架构

编码器:理解输入,提取上下文特征。

解码器:基于编码特征,按顺序生成输出。


1.多头注意力机制

import math
import torch
import torch.nn as nndevice=torch.device('cuda' if torch.cuda.is_available() else 'cpu')class MultiHeadAttention(nn.Module):# n_heads:多头注意力的数量# hid_dim:每个词输出的向量维度def __init__(self,hid_dim,n_heads):super(MultiHeadAttention,self).__init__()self.hid_dim=hid_dimself.n_heads=n_heads#强制hid_dim必须整除 hassert hid_dim % n_heads == 0#定义W_q矩阵ceself.w_q=nn.Linear(hid_dim,hid_dim)#定义W_k矩阵self.w_k=nn.Linear(hid_dim,hid_dim)#定义W_v矩阵self.w_v=nn.Linear(hid_dim,hid_dim)self.fc =nn.Linear(hid_dim,hid_dim)#缩放self.scale=torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))def forward(self,query,key,value,mask=None):#Q,K,V的在句子这长度这一个维度的数值可以不一样,可以一样#K:[64,10,300],假设batch_size为64,有10个词,每个词的Query向量是300维bsz=query.shape[0]Q  =self.w_q(query)K  =self.w_k(key)V  =self.w_v(value)#这里把K Q V 矩阵拆分为多组注意力#最后一维就是是用self.hid_dim // self.n_heads 来得到的,表示每组注意力的向量长度,每个head的向量长度是:300/6=50#64表示batch size,6表示有6组注意力,10表示有10词,50表示每组注意力的词的向量长度#K: [64,10,300] 拆分多组注意力 -> [64,10,6,50] 转置得到 -> [64,6,10,50]#转置是为了把注意力的数量6放在前面,把10和50放在后面,方便下面计算Q=Q.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)K=K.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)V=V.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)#Q乘以K的转置,除以scale#[64,6,12,50]*[64,6,50,10]=[64,6,12,10]#attention:[64,6,12,10]attention=torch.matmul(Q,K.permute(0,1,3,2)) / self.scale#如果mask不为空,那么就把mask为0的位置的attention分数设置为-1e10,这里用‘0’来指示哪些位置的词向量不能被attention到,比如padding位置if mask is not None:attention=attention.masked_fill(mask==0,-1e10)#第二步:计算上一步结果的softmax,再经过dropout,得到attention#注意,这里是对最后一维做softmax,也就是在输入序列的维度做softmax#attention: [64,6,12,10]attention=torch.softmax(attention,dim=-1)#第三步,attention结果与V相乘,得到多头注意力的结果#[64,6,12,10] * [64,6,10,50] =[64,6,12,50]# x: [64,6,12,50]x=torch.matmul(attention,V)#因为query有12个词,所以把12放在前面,把50和6放在后面,方便下面拼接多组的结果#x: [64,6,12,50] 转置 -> [64,12,6,50]x=x.permute(0,2,1,3).contiguous()#这里的矩阵转换就是:把多头注意力的结果拼接起来#最后结果就是[64,12,300]# x:[64,12,6,50] -> [64,12,300]x=x.view(bsz,-1,self.n_heads*(self.hid_dim//self.n_heads))x=self.fc(x)return x

2.前馈传播

class Feedforward(nn.Module):def __init__(self,d_model,d_ff,dropout=0.1):super(Feedforward,self).__init__()#两层线性映射和激活函数self.linear1=nn.Linear(d_model,d_ff)self.dropout=nn.Dropout(dropout)self.linear2=nn.Linear(d_ff,d_model)def forward(self,x):x=torch.nn.functional.relu(self.linear1(x))x=self.dropout(x)x=self.linear2(x)return x

3.位置编码

class PositionalEncoding(nn.Module):"实现位置编码"def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# 初始化Shape为(max_len, d_model)的PE (positional encoding)pe = torch.zeros(max_len, d_model).to(device)# 初始化一个tensor [[0, 1, 2, 3, ...]]position = torch.arange(0, max_len).unsqueeze(1)# 这里就是sin和cos括号中的内容,通过e和ln进行了变换div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term) # 计算PE(pos, 2i)pe[:, 1::2] = torch.cos(position * div_term) # 计算PE(pos, 2i+1)pe = pe.unsqueeze(0) # 为了方便计算,在最外面在unsqueeze出一个batch# 如果一个参数不参与梯度下降,但又希望保存model的时候将其保存下来# 这个时候就可以用register_bufferself.register_buffer("pe", pe)def forward(self, x):"""x 为embedding后的inputs,例如(1,7, 128),batch size为1,7个单词,单词维度为128"""# 将x和positional encoding相加。x = x + self.pe[:, :x.size(1)].requires_grad_(False)return self.dropout(x)

4.编码层

class EncoderLayer(nn.Module):def __init__(self,d_model,n_heads,d_ff,dropout=0.1):super(EncoderLayer,self).__init__()#编码器层包含自注意机制和前馈神经网络self.self_attn=MultiHeadAttention(d_model,n_heads)self.feedforward=Feedforward(d_model,d_ff,dropout)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)self.dropout=nn.Dropout(dropout)def forward(self,x,mask):#自注意力机制atten_output=self.self_attn(x,x,x,mask)x=x+self.dropout(atten_output)x=self.norm1(x)#前馈神经网络ff_output=self.feedforward(x)x=x+self.dropout(ff_output)x=self.norm2(x)return x

5.解码层

class DecoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()# 解码器层包含自注意力机制、编码器-解码器注意力机制和前馈神经网络self.self_attn   = MultiHeadAttention(d_model, n_heads)self.enc_attn    = MultiHeadAttention(d_model, n_heads)self.feedforward = Feedforward(d_model, d_ff, dropout)self.norm1   = nn.LayerNorm(d_model)self.norm2   = nn.LayerNorm(d_model)self.norm3   = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, self_mask, context_mask):# 自注意力机制attn_output = self.self_attn(x, x, x, self_mask)x           = x + self.dropout(attn_output)x           = self.norm1(x)# 编码器-解码器注意力机制attn_output = self.enc_attn(x, enc_output, enc_output, context_mask)x           = x + self.dropout(attn_output)x           = self.norm2(x)# 前馈神经网络ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm3(x)return x

6.Transformer模型构建

class Transformer(nn.Module):def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):super(Transformer, self).__init__()# Transformer 模型包含词嵌入、位置编码、编码器和解码器self.embedding           = nn.Embedding(vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.encoder_layers      = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])self.decoder_layers      = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])self.fc_out              = nn.Linear(d_model, vocab_size)self.dropout             = nn.Dropout(dropout)def forward(self, src, trg, src_mask, trg_mask):# 词嵌入和位置编码src = self.embedding(src)src = self.positional_encoding(src)trg = self.embedding(trg)trg = self.positional_encoding(trg)# 编码器for layer in self.encoder_layers:src = layer(src, src_mask)# 解码器for layer in self.decoder_layers:trg = layer(trg, src, trg_mask, src_mask)# 输出层output = self.fc_out(trg)return output

vocab_size = 10000
d_model    = 128
n_heads    = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff             = 2048
dropout          = 0.1device = torch.device('cpu')transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout)# 定义输入
src = torch.randint(0, vocab_size, (32, 10))  # 源语言句子
trg = torch.randint(0, vocab_size, (32, 20))  # 目标语言句子
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # 掩码,用于屏蔽填充的位置
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2)  # 掩码,用于屏蔽填充的位置# 模型前向传播
output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/71577.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FreeRTOS 任务间通信机制:队列、信号量、事件标志组详解与实验

1. FreeRTOS 消息队列 1.1 简介 ​ 队列是 任务间通信的主要形式,可用于在任务之间以及中断与任务之间传递消息。队列在 FreeRTOS 中具有以下关键特点: 队列默认采用 先进先出 FIFO 方式,也可以使用 xQueueSendToFront()实现 LIFO。FreeRT…

【虚拟化】Docker Desktop 架构简介

在阅读前您需要了解 docker 架构:Docker architecture WSL 技术:什么是 WSL 2 1.Hyper-V backend 我们知道,Docker Desktop 最开始的架构的后端是采用的 Hyper-V。 Docker daemon (dockerd) 运行在一个 Linux distro (LinuxKit build) 中&…

Unity光照之Halo组件

简介 Halo 组件 是一种用于在游戏中创建光晕效果的工具,主要用于模拟光源周围的发光区域(如太阳、灯泡等)或物体表面的光线反射扩散效果。 核心功能 1.光晕生成 Halo 组件会在光源或物体的周围生成一个圆形光晕,模拟光线在空气…

Flink深入浅出之01:应用场景、基本架构、部署模式

Flink 1️⃣ 一 、知识要点 📖 1. Flink简介 Apache Flink — Stateful Computations over Data StreamsApache Flink 是一个分布式大数据处理引擎,可对有界数据流和无界数据流进行有状态的计算。Flink 能在所有常见集群环境中运行,并能以…

2025年【高压电工】报名考试及高压电工考试总结

随着电力行业的快速发展,高压电工成为确保电力系统安全稳定运行的重要一环。为了提高高压电工的专业技能和安全意识,“安全生产模拟考试一点通”平台特别整理了2025年高压电工报名考试的相关信息及考试总结,并提供了一套完整的题库&#xff0…

网络HTTP

HTTP Network Request Library A Retrofit-based HTTP network request encapsulation library that provides simple and easy-to-use API interfaces with complete network request functionality. 基于Retrofit的HTTP网络请求封装库,提供简单易用的API接口和完…

os-copilot安装和使用体验测评

简介: OS Copilot是阿里云基于大模型构建的Linux系统智能助手,支持自然语言问答、命令执行和系统运维调优。本文介绍其产品优势、功能及使用方法,并分享个人开发者在云服务器资源管理中的实际应用体验。通过-t/-f/管道功能,OS Cop…

Python Flask框架学习汇编

1、入门级: 《Python Flask Web 框架入门》 这篇博文条理清晰,由简入繁,案例丰富,分十五节详细讲解了Flask框架,强烈推荐! 《python的简单web框架flask【附例子】》 讲解的特别清楚,每一步都…

【HarmonyOS Next之旅】DevEco Studio使用指南(一)

目录 1 -> 工具简介 1.1 -> 概述 1.2 -> HarmonyOS应用/服务开发流程 1.2.1 -> 开发准备 1.2.2 -> 开发应用/服务 1.2.3 -> 运行、调试和测试应用/服务 1.2.4 -> 发布应用/服务 2 -> 工程介绍 2.1 -> APP包结构 2.2 -> 切换工程视图 …

Manus开源平替-开源通用智能体

原文链接:https://i68.ltd/notes/posts/250306-opensource-agi-agent/ OWL-比Manus还强的全能开源Agent OWL: Optimized Workforce Learning for General Multi-Agent Assistance in Real-World Task Automation,现实世界中执行自动化任务的通用多代理辅助优化学习…

【3.2-3.8学习周报】

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 摘要Abstract一、方法介绍1.任务适应性持续预训练(TACP)2.领域自适应连续预训练(DACP)3.ETS-DACP和ETA-DACP 二、实验…

【Linux】用户和组

思考 使用useradd在Linux下面创建一个用户,默认情况下,会自动创建一个同名组,并且加入其中,那么是先创建用户呢?还是先创建组呢? 很简单,我们实践一下不就知道了,如下所示&#xff…

新编大学应用英语综合教程2 U校园全套参考答案

全套答案获取: 链接:https://pan.quark.cn/s/389618f53143

SAP 顾问的五年职业规划

SAP 顾问的职业发展受到技术进步、企业需求变化和全球经济环境的影响,因此制定长远规划充满挑战。面对 SAP 产品路线图的不确定性,如向 S/4HANA 和 Business Technology Platform (BTP) 的转变,顾问必须具备灵活性,以保持竞争力和…

图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image

图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image 文章目录 图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image主要创新点模型架构图生成器生成器源码 判别器判别器源码 损失函数需要源码讲解的私信我 S…

Networking Based ISAC Hardware Testbed and Performance Evaluation

文章目录 Applications and Challenges of Networked SensingCooperation Mechanism in Networked SensingChallenges and Key Enabling Technologies 5G NR Frame Structure Based ISAC ApproachSignals Available for Radio SensingMulti-Dimensiona Resource Optimization S…

2025年主流原型工具测评:墨刀、Axure、Figma、Sketch

2025年主流原型工具测评:墨刀、Axure、Figma、Sketch 要说2025年国内产品经理使用的主流原型设计工具,当然是墨刀、Axure、Figma和Sketch了,但是很多刚入行的产品经理不了解自己适合哪些工具,本文将从核心优势、局限短板、协作能…

我代表中国受邀在亚马逊云科技全球云计算大会re:Invent中技术演讲

大家好我是小李哥,本名叫李少奕,目前在一家金融行业公司担任首席云计算工程师。去年5月很荣幸在全球千万名开发者中被选为了全球亚马逊云科技认证技术专家(AWS Hero),是近10年来大陆地区仅有的第9名大陆专家。同时作为…

LeetCode 解题思路 12(Hot 100)

解题思路: 定义三个指针: prev(前驱节点)、current(当前节点)、nextNode(临时保存下一个节点)遍历链表: 每次将 current.next 指向 prev,移动指针直到 curre…

Ubuntu搭建最简单WEB服务器

安装apache2 sudo apt install apache2 检查状态 $ sudo systemctl status apache2 ● apache2.service - The Apache HTTP ServerLoaded: loaded (/lib/systemd/system/apache2.service; enabled; vendor prese>Active: active (running) since Thu 2025-03-06 09:51:10…