从0搭建Transformer

1. 位置编码模块:

import torch
import torch.nn as nn
import mathclass PositonalEncoding(nn.Module):def __init__ (self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# [[1, 2, 3],# [4, 5, 6],# [7, 8, 9]]pe = torch.zeros(max_len, d_model)# [[0],# [1],# [2]]position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 位置编码固定,不更新参数# 保存模型时会保存缓冲区,在引入模型时缓冲区也被引入self.register_buffer('pe', pe)def forward(self, x):# 不计算梯度x = x + self.pe[:, :x.size(1)].requires_grad_(False)

2. 多头注意力模块

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_k = d_model // num_headsself.num_heads = num_headsself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)self.W_o = nn.Linear(d_model, d_model)def forward(self, query, key, value, mask=None):batch_size = query.size(0)Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn_weights = torch.softmax(scores, dim=-1)context = torch.matmul(attn_weights, V)context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads)return self.W_o(context)

3. 编码器层

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):super().__init__()self.atten = MultiHeadAttention(d_model, num_heads)self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),nn.ReLU(),nn.Linear(d_ff, d_model))self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):attn_output = self.attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x

4. 解码器层

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),nn.ReLU(),nn.Linear(d_ff, d_model))self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x

5. 模型整合

class Transformer(nn.module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):super(Transformer, self).__init__()self.encoder_embed = nn.Embedding(src_vocab_size, d_model)self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)self.pos_encoder = PositionalEncoding(d_model, dropout)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.fc_out = nn.Linear(d_model, tgt_vocab_size)def encode(self, src, src_mask):src_embeded = self.encoder_embed(src)src = self.pos_encoder(src_embeded)for layer in self.encoder_layers:src = layer(src, src_mask)return srcdef decode(self, tgt, enc_output, src_mask, tgt_mask):tgt_embeded = self.decoder_embed(tgt)tgt = self.pos_encoder(tgt_embeded)for layer in self.decoder_layers:tgt = layer(tgt, enc_output, src_mask, tgt_mask)return tgtdef forward(self, src, tgt, src_mask, tgt_mask):enc_output = self.encode(src, src_mask)dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)logits = self.fc_out(dec_output)return logits

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/79219.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Bootstrap V4系列】学习入门教程之 表格(Tables)和画像(Figure)

Bootstrap V4系列 学习入门教程之 表格(Tables)和画像(Figure) 表格(Tables)一、Examples二、Table head options 表格头选项三、Striped rows 条纹行四、Bordered table 带边框的表格五、Borderless table…

在C# WebApi 中使用 Nacos02: 配置管理、服务管理实战

一、配置管理 1.添加一个新的命名空间 这里我都填写为publicdemo 2.C#代码配置启动 appsetting.json加上: (nacos默认是8848端口) "NacosConfig": {"ServerAddresses": [ "http://localhost:8848" ], // Nacos 服务器地址"Na…

如何搭建spark yarn 模式的集群集群。

下载 App 如何搭建spark yarn 模式的集群集群。 搭建Spark on YARN集群的详细步骤 Spark on YARN模式允许Spark作业在Hadoop YARN资源管理器上运行,利用YARN进行资源调度。以下是搭建步骤: 一、前提条件 已安装并配置好的Hadoop集群(包括HDF…

C++--入门基础

C入门基础 1. C的第一个程序 C继承C语言许多大多数的语法,所以以C语言实现的hello world也可以运行,C中需要把文件定义为.cpp,vs编译器看是.cpp就会调用C编译器编译,linux下要用g编译,不再是gcc。 // test.cpp #inc…

从实列中学习linux shell9 如何确认 服务器反应迟钝是因为cpu还是 硬盘io 到底是那个程序引起的。cpu负载多高算高

在 Linux 系统中,Load Average(平均负载) 是衡量系统整体压力的关键指标,但它本身没有绝对的“高/低”阈值,需要结合 CPU 核心数 和 其他性能指标 综合分析。以下是具体判断方法: 一、Load Average 的基本含义 定义:Load Average 表示 单位时间内处于可运行状态(R)和不…

聊一聊接口测试更侧重于哪方面的验证

目录 一、功能性验证 输入与输出正确性 参数校验 业务逻辑覆盖 二、数据一致性验证 数据格式规范 数据完整性 数据类型与范围 三、异常场景验证 容错能力测试 边界条件覆盖 错误码与信息清晰度 四、安全与权限验证 身份认证 数据安全 防攻击能力 五、性能与可…

Fiddler抓取APP端,HTTPS报错全解析及解决方案(一篇解决常见问题)

环境:雷电模拟器Android9系统 ​ 你所遇到的fiddler中抓取HTTPS的问题可以分为三类:一类是你自己证书安装上逻辑错误,另一种是APP中使用了“证书固定”的手段。三类fiddler中生成证书时的参数过程。 1.Fiddler证书安装上的逻辑错误 更新Opt…

OpenGL-ES 学习(15) ----纹理

目录 纹理简介纹理映射纹理映射流程示例代码:纹理的环绕和过滤方式纹理的过滤方式 纹理简介 现实生活中,纹理(Texture) 类似于游戏中皮肤的概念,最通常的作用是装饰 3D 物体,它像贴纸一样贴在物体的表面,丰富物体的表…

OpenCV计算机视觉实战(2)——环境搭建与OpenCV简介

OpenCV计算机视觉实战(2)——环境搭建与OpenCV简介 0. 前言1. OpenCV 安装与配置1.1 安装 Python-OpenCV1.2 配置开发环境 2. OpenCV 基础2.1 图像读取与显示2.2 图像保存 3. 摄像头实时捕获小结系列链接 0. 前言 OpenCV (Open Source Computer Vision …

ubuntu22.04安装显卡驱动与cuda+cuDNN

背景: 紧接前文:Proxmox VE 8.4 显卡直通完整指南:NVIDIA 2080 Ti 实战。在R740服务器完成了proxmox的安装,并且安装了一张2080ti 魔改22g显存的的显卡。配置完了proxmox显卡直通,并将显卡挂载到了vm 301(…

A2A Python 教程 - 综合指南

目录 • 介绍• 设置环境• 创建项目• 代理技能• 代理卡片• A2A服务器• 与A2A服务器交互• 添加代理功能• 使用本地Ollama模型• 后续步骤 介绍 在本教程中,您将使用Python构建一个简单的echo A2A服务器。这个基础实现将向您展示A2A提供的所有功能。完成本教…

MySQL基础关键_005_DQL(四)

目 录 一、分组函数 1.说明 2.max/min 3.sum/avg/count 二、分组查询 1.说明 2.实例 (1)查询岗位和平均薪资 (2)查询每个部门编号的不同岗位的最低薪资 3.having (1)说明 (2&#xff…

GAMES202-高质量实时渲染(Assignment 2)

目录 作业介绍环境光贴图预计算传输项的预计算Diffuse unshadowedDiffuse shadowedDiffuse Inter-reflection(bonus) 实时球谐光照计算 GitHub主页:https://github.com/sdpyy1 作业实现:https://github.com/sdpyy1/CppLearn/tree/main/games202 作业介绍 物体在不同…

2025年- H21-Lc129-160. 相交链表(链表)---java版

1.题目描述 2.思路 当pa!pb的时候,执行pa不为空,遍历pa链表。执行pb不为空,遍历pb链表。 3.代码实现 // 单链表节点定义 class ListNode {int val;ListNode next;ListNode(int x){valx;nextnull;}}public class H160 {// 主方法…

win10系统安卓开发环境搭建

一 安装jdk 下载jdk17 ,下载路径:https://download.oracle.com/java/17/archive/jdk-17.0.12_windows-x64_bin.exe 下载完毕后,按照提示一步步完成,然后接着创建环境变量, 在cmd控制台输入java -version 验证: 有上面的输出代表jdk安装并配置成功。 二 安装Android stu…

【算法基础】选择排序算法 - JAVA

一、算法基础 1.1 什么是选择排序 选择排序是一种简单直观的排序算法,它的工作原理是:首先在未排序序列中找到最小(或最大)元素,存放到排序序列的起始位置,然后再从剩余未排序元素中继续寻找最小&#xf…

LabVIEW异步调用VI介绍

在 LabVIEW 编程环境里,借助结合异步 VI 调用,并使用 “Open VI Reference” 函数上的 “Enable simultaneous calls on reentrant VIs” 选项(0x40),达成了对多个 VI 调用执行效率的优化。以下将从多方面详细介绍该 V…

Leetcode刷题 | Day50_图论02_岛屿问题01_dfs两种方法+bfs一种方法

一、学习任务 99. 岛屿数量_深搜dfs代码随想录99. 岛屿数量_广搜bfs100. 岛屿的最大面积101. 孤岛的总面积 第一类DFS(主函数中处理第一个节点,DFS处理相连节点): 主函数中先将起始节点标记为已访问DFS函数中不处理起始节点&…

深入理解网络安全中的加密技术

1 引言 在当今数字化的世界中,网络安全已经成为个人隐私保护、企业数据安全乃至国家安全的重要组成部分。随着网络攻击的复杂性和频率不断增加,保护敏感信息不被未授权访问变得尤为关键。加密技术作为保障信息安全的核心手段,通过将信息转换为…