GPT - GPT(Generative Pre-trained Transformer)模型框架

本节代码主要为实现了一个简化版的 GPT(Generative Pre-trained Transformer)模型。GPT 是一种基于 Transformer 架构的语言生成模型,主要用于生成自然语言文本。
 

1. 模型结构

初始化部分
class GPT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, dff, dropout):super().__init__()self.emb = nn.Embedding(vocab_size, d_model)self.pos = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerDecoderBlock(d_model, dff, dropout)for i in range(N_blocks)])self.fc = nn.Linear(d_model, vocab_size)
  • vocab_size:词汇表的大小,即模型可以处理的唯一词元(token)的数量。

  • d_model:模型的维度,表示嵌入和内部表示的维度。

  • seq_len:序列的最大长度,即输入序列的最大长度。

  • N_blocks:Transformer 解码器块的数量。

  • dff:前馈网络(Feed-Forward Network, FFN)的维度。

  • dropout:Dropout 的概率,用于防止过拟合。

组件说明
  1. self.emb:词嵌入层,将输入的词元索引映射到 d_model 维的向量空间。

  2. self.pos:位置嵌入层,将序列中每个位置的索引映射到 d_model 维的向量空间。位置嵌入用于给模型提供序列中每个词元的位置信息。

  3. self.layers:一个模块列表,包含 N_blocksTransformerDecoderBlock。每个块是一个 Transformer 解码器层,包含多头注意力机制和前馈网络。

  4. self.fc:一个线性层,将解码器的输出映射到词汇表大小的维度,用于生成最终的词元概率分布。

2. 前向传播

def forward(self, x, attn_mask=None):emb = self.emb(x)pos = self.pos(torch.arange(x.shape[1]))x = emb + posfor layer in self.layers:x = layer(x, attn_mask)return self.fc(x)
步骤解析
  1. 词嵌入和位置嵌入

    • self.emb(x):将输入的词元索引 x 转换为词嵌入表示 emb,形状为 (batch_size, seq_len, d_model)

    • self.pos(torch.arange(x.shape[1])):生成位置嵌入 pos,形状为 (seq_len, d_model)torch.arange(x.shape[1]) 生成一个从 0 到 seq_len-1 的序列,表示每个位置的索引。

    • x = emb + pos:将词嵌入和位置嵌入相加,得到最终的输入表示 x。位置嵌入的加入使得模型能够区分序列中不同位置的词元。

  2. Transformer 解码器层

    • for layer in self.layers:将输入 x 逐层传递给每个 TransformerDecoderBlock

    • x = layer(x, attn_mask):每个解码器块会处理输入 x,并应用因果掩码 attn_mask(如果提供)。因果掩码确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。

  3. 输出层

    • self.fc(x):将解码器的输出 x 传递给线性层 self.fc,生成最终的输出。输出的形状为 (batch_size, seq_len, vocab_size),表示每个位置上每个词元的预测概率。

截止到本篇文章GPT简单复现完成,下面将附完整代码,方便理解代码整体结构

import math
import torch
import random
import torch.nn as nnfrom tqdm import tqdm
from torch.utils.data import Dataset, DataLoader'''
仿 nn.TransformerDecoderLayer 实现
'''class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout):super().__init__()self.num_heads = num_headsself.d_k = d_model // num_headsself.q_project = nn.Linear(d_model, d_model)self.k_project = nn.Linear(d_model, d_model)self.v_project = nn.Linear(d_model, d_model)self.o_project = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, attn_mask=None):batch_size, seq_len, d_model = x.shapeQ = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)if attn_mask is not None:attn_mask = attn_mask.unsqueeze(1)atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)atten_scores = torch.softmax(atten_scores, dim=-1)out = atten_scores @ Vout = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)out = self.o_project(out)return self.dropout(out)class TransformerDecoderBlock(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.linear1 = nn.Linear(d_model, dff)self.activation = nn.GELU()# self.activation = nn.ReLU()self.dropout = nn .Dropout(dropout)self.linear2 = nn.Linear(dff, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)self.mha_block1 = MultiHeadAttention(d_model, num_heads, dropout)self.mha_block2 = MultiHeadAttention(d_model, num_heads, dropout)def forward(self, x, mask=None):x = self.norm1(x + self.dropout1(self.mha_block1(x, mask)))x = self.norm2(x + self.dropout2(self.mha_block2(x, mask)))x = self.norm3(self.linear2(self.dropout(self.activation(self.linear1(x)))))return xclass GPT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, dff, dropout):super().__init__()self.emb = nn.Embedding(vocab_size, d_model)self.pos = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerDecoderBlock(d_model, dff, dropout)for i in range(N_blocks)])self.fc = nn.Linear(d_model, vocab_size)def forward(self, x, attn_mask=None):emb = self.emb(x)pos = self.pos(torch.arange(x.shape[1]))x = emb + posfor layer in self.layers:x = layer(x, attn_mask)return self.fc(x)def read_data(file, num=1000):with open(file, "r", encoding="utf-8") as f:data = f.read().strip().split("\n")res = [line[:24] for line in data[:num]]return resdef tokenize(corpus):vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3, ",": 4, "。": 5, "?": 6}for line in corpus:for token in line:vocab.setdefault(token, len(vocab))idx2word = list(vocab)return vocab, idx2wordclass Tokenizer:def __init__(self, vocab, idx2word):self.vocab = vocabself.idx2word = idx2worddef encode(self, text):ids = [self.token2id(token) for token in text]return idsdef decode(self, ids):tokens = [self.id2token(id) for id in ids]return tokensdef id2token(self, id):token = self.idx2word[id]return tokendef token2id(self, token):id = self.vocab.get(token, self.vocab["[UNK]"])return idclass Poetry(Dataset):def __init__(self, poetries, tokenizer: Tokenizer):self.poetries = poetriesself.tokenizer = tokenizerself.pad_id = self.tokenizer.vocab["[PAD]"]self.bos_id = self.tokenizer.vocab["[BOS]"]self.eos_id = self.tokenizer.vocab["[EOS]"]def __len__(self):return len(self.poetries)def __getitem__(self, idx):poetry = self.poetries[idx]poetry_ids = self.tokenizer.encode(poetry)input_ids = torch.tensor([self.bos_id] + poetry_ids)input_msk = causal_mask(input_ids)label_ids = torch.tensor(poetry_ids + [self.eos_id])return {"input_ids": input_ids,"input_msk": input_msk,"label_ids": label_ids}def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return maskdef generate_poetry(method="greedy", top_k=5):model.eval()with torch.no_grad():input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)while input_ids.shape[1] < seq_len:output = model(input_ids, None)probabilities = torch.softmax(output[:, -1, :], dim=-1)if method == "greedy":next_token_id = torch.argmax(probabilities, dim=-1)elif method == "top_k":top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]if next_token_id == vocab["[EOS]"]:breakinput_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)return input_ids.squeeze()if __name__ == "__main__":file = "/Users/azen/Desktop/llm/LLM-FullTime/dataset/text-generation/poetry_data.txt"poetries = read_data(file, num=2000)vocab, idx2word = tokenize(poetries)tokenizer = Tokenizer(vocab, idx2word)trainset = Poetry(poetries, tokenizer)batch_size = 16trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)d_model = 512seq_len = 25 # 有特殊标记符num_heads = 8dropout = 0.1dff = 4*d_modelN_blocks = 2model = GPT(len(vocab), d_model, seq_len, N_blocks, dff, dropout)lr = 1e-4optim = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.CrossEntropyLoss()epochs = 100for epoch in range(epochs):for batch in tqdm(trainloader, desc="Training"):batch_input_ids = batch["input_ids"]batch_input_msk = batch["input_msk"]batch_label_ids = batch["label_ids"]output = model(batch_input_ids, batch_input_msk)loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))loss.backward()optim.step()optim.zero_grad()print("Epoch: {}, Loss: {}".format(epoch, loss))res = generate_poetry(method="top_k")text = tokenizer.decode(res)print("".join(text))pass

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75511.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于FPGA的六层电梯智能控制系统 矩阵键盘-数码管 上板仿真均验证通过

基于FPGA的六层电梯智能控制系统 前言一、整体方案二、软件设计总结 前言 本设计基于FPGA实现了一个完整的六层电梯智能控制系统&#xff0c;旨在解决传统电梯控制系统在别墅环境中存在的个性化控制不足、响应速度慢等问题。系统采用Verilog HDL语言编程&#xff0c;基于Cyclo…

车载通信系统中基于ISO26262的功能安全与抗辐照协同设计研究

摘要&#xff1a;随着智能网联汽车的快速发展&#xff0c;车载通信系统正面临着功能安全与抗辐照设计的双重挑战。在高可靠性要求的车载应用场景下&#xff0c;如何实现功能安全标准与抗辐照技术的协同优化&#xff0c;构建满足ISO26262安全完整性等级要求的可靠通信架构&#…

Node.js种cluster模块详解

Node.js 中 cluster 模块全部 API 详解 1. 模块属性 const cluster require(cluster);// 1. isMaster // 判断当前进程是否为主进程 console.log(是否为主进程:, cluster.isMaster);// 2. isWorker // 判断当前进程是否为工作进程 console.log(是否为工作进程:, cluster.isW…

融合动态权重与抗刷机制的网文评分系统——基于优书网、IMDB与Reddit的混合算法实践

✨ Yumuing 博客 &#x1f680; 探索技术的每一个角落&#xff0c;解码世界的每一种可能&#xff01; &#x1f48c; 如果你对 AI 充满好奇&#xff0c;欢迎关注博主&#xff0c;订阅专栏&#xff0c;让我们一起开启这段奇妙的旅程&#xff01; 以权威用户为核心&#xff0c;时…

使用Golang打包jar应用

文章目录 背景Go 的 go:embed 功能介绍与打包 JAR 文件示例1. go:embed 基础介绍基本特性基本语法 2. 嵌入 JAR 文件示例项目结构代码实现 3. 高级用法&#xff1a;嵌入多个文件或目录4. 使用注意事项5. 实际应用场景6. 完整示例&#xff1a;运行嵌入的JAR 背景 想把自己的一个…

前端大屏可视化项目 局部全屏(指定盒子全屏)

需求是这样的&#xff0c;我用的项目是vue admin 项目 现在需要在做大屏项目 不希望显示除了大屏的其他东西 于是想了这个办法 至于大屏适配问题 请看我文章 底部的代码直接复制就可以运行 vue2 px转rem 大屏适配方案 postcss-pxtorem-CSDN博客 <template><div …

《2025蓝桥杯C++B组:D:产值调整》

**作者的个人gitee**​​ 作者的算法讲解主页▶️ 每日一言&#xff1a;“泪眼问花花不语&#xff0c;乱红飞过秋千去&#x1f338;&#x1f338;” 题目 二.解题策略 本题比较简单&#xff0c;我的思路是写三个函数分别计算黄金白银铜一次新产值&#xff0c;通过k次循环即可获…

[VTK] 四元素实现旋转平移

VTK 实现旋转&#xff0c;有四元数的方案&#xff0c;也有 vtkTransform 的方案&#xff1b;主要示例代码如下&#xff1a; //构造旋转四元数vtkQuaterniond rotation;rotation.SetRotationAngleAndAxis(vtkMath::RadiansFromDegrees(90.0),0.0, 1.0, 0.0);//构造旋转点四元数v…

华为hcie证书的有效期怎么判断?

在ICT行业&#xff0c;华为HCIE证书堪称含金量极高的“敲门砖”&#xff0c;拥有它往往意味着在职场上更上一层楼。然而&#xff0c;很多人在辛苦考取HCIE证书后&#xff0c;却对其有效期相关事宜一知半解。今天&#xff0c;咱们就来好好唠唠华为HCIE证书的有效期怎么判断这个关…

【精品PPT】2025固态电池知识体系及最佳实践PPT合集(36份).zip

精品推荐&#xff0c;2025固态电池知识体系及最佳实践PPT合集&#xff0c;共36份。供大家学习参考。 1、中科院化学所郭玉国研究员&#xff1a;固态金属锂电池及其关键材料.pdf 2、中科院物理所-李泓固态电池.pdf 3、全固态电池技术研究进展.pdf 4、全固态电池生产工艺.pdf 5、…

MySQL 中为产品添加灵活的自定义属性(如 color/size)

方案 1&#xff1a;EAV 模型&#xff08;最灵活但较复杂&#xff09; 适合需要无限扩展自定义属性的场景 -- 产品表 CREATE TABLE products (id INT PRIMARY KEY AUTO_INCREMENT,name VARCHAR(100),price DECIMAL(10,2) );-- 属性名表 CREATE TABLE attributes (id INT PRIMA…

CSPM认证对项目论证的范式革新:从合规审查到价值创造的战略跃迁

引言 在数字化转型浪潮中&#xff0c;全球企业每年因项目论证缺陷导致的损失高达1.7万亿美元&#xff08;Gartner 2023&#xff09;。CSPM&#xff08;Certified Strategic Project Manager&#xff09;认证体系通过结构化方法论&#xff0c;将传统的项目可行性评估升级为战略…

CLIP中的Zero-Shot Learning原理

CLIP&#xff08;Contrastive Language-Image Pretraining&#xff09;是一种由OpenAI提出的多模态模型&#xff0c;它通过对比学习的方式同时学习图像和文本的表示&#xff0c;并且能在多种任务中进行零样本学习&#xff08;Zero-Shot Learning&#xff09;。CLIP模型的核心创…

spring mvc 中 RestTemplate 全面详解及示例

RestTemplate 全面详解及示例 1. RestTemplate 简介 定义&#xff1a;Spring 提供的同步 HTTP 客户端&#xff0c;支持多种 HTTP 方法&#xff08;GET/POST/PUT/DELETE 等&#xff09;&#xff0c;用于调用 RESTful API。核心特性&#xff1a; 支持请求头、请求体、URI 参数的…

北大:LLM在NL2SQL中任务分解

&#x1f4d6;标题&#xff1a;LearNAT: Learning NL2SQL with AST-guided Task Decomposition for Large Language Models &#x1f310;来源&#xff1a;arXiv, 2504.02327 &#x1f31f;摘要 &#x1f538;自然语言到SQL&#xff08;NL2SQL&#xff09;已成为实现与数据库…

STM32LL库编程系列第八讲——ADC模数转换

系列文章目录 往期文章 STM32LL库编程系列第一讲——Delay精准延时函数&#xff08;详细&#xff0c;适合新手&#xff09; STM32LL库编程系列第二讲——蓝牙USART串口通信&#xff08;步骤详细、原理清晰&#xff09; STM32LL库编程系列第三讲——USARTDMA通信 STM32LL库编程…

网络5 TCP/IP 虚拟机桥接模式、NAT、仅主机模式

TCP/IP模型 用于局域网和广域网&#xff1b;多个协议&#xff1b;每一层呼叫下一层&#xff1b;四层&#xff1b;通用标准 TCP/IP模型 OSI七层模型 应用层 应用层 表示层 会话层 传输层 传输层 网络层 网络层 链路层 数据链路层 物理层 链路层&#xff1a;传数据帧&#xff0…

【C语言】预处理(下)(C语言完结篇)

一、#和## 1、#运算符 这里的#是一个运算符&#xff0c;整个运算符会将宏的参数转换为字符串字面量&#xff0c;它仅可以出现在带参数的宏的替换列表中&#xff0c;我们可以将其理解为字符串化。 我们先看下面的一段代码&#xff1a; 第二个printf中是由两个字符串组成的&am…

【高性能缓存Redis_中间件】一、快速上手redis缓存中间件

一、铺垫 在当今的软件开发领域&#xff0c;消息队列扮演着至关重要的角色。它能够帮助我们实现系统的异步处理、流量削峰以及系统解耦等功能&#xff0c;从而提升系统的性能和可维护性。Redis 作为一款高性能的键值对数据库&#xff0c;不仅提供了丰富的数据结构&#xff0c;…

Java如何获取文件的编码格式?

Java获取文件的编码格式 在计算机中&#xff0c;文件编码是指将文件内容转换成二进制形式以便存储和传输的过程。常见的文件编码格式包括UTF-8、GBK等。不同的编码使用不同的字符集和字节序列&#xff0c;因此在读取文件时需要正确地确定文件的编码格式 Java提供了多种方式以获…