Vlm-Transformer_demo

import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import random # ===================== 1. 准备数据(字符级语料) ===================== # 简单语料(自己构造,无需下载) #训练样本数: 89 | 词汇表字符: [' ', 'a', 'c', 'd', 'e', 'f', 'h', 'i', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'w', 'y'] corpus = [ "hello transformer", "transformer is a powerful model", "pytorch transformer demo", "transformer uses self attention", "attention is the core of transformer" ] # 收集所有唯一字符,建立字符→索引、索引→字符映射 all_chars = sorted(list(set("".join(corpus)))) # 去重+排序 char2idx = {char: idx for idx, char in enumerate(all_chars)} idx2char = {idx: char for char, idx in char2idx.items()} vocab_size = len(all_chars) # 词汇表大小(字符数) seq_len = 10 # 输入序列长度(取前10个字符,预测第11个) # 生成训练数据:输入是前seq_len个字符,目标是第seq_len+1个字符 def generate_data(corpus, seq_len): inputs = [] targets = [] for sentence in corpus: # 将句子转成字符索引序列 sentence_idx = [char2idx[c] for c in sentence] # 滑动窗口生成样本(确保长度足够) for i in range(len(sentence_idx) - seq_len): input_seq = sentence_idx[i:i+seq_len] target_char = sentence_idx[i+seq_len] inputs.append(input_seq) targets.append(target_char) # 转成Tensor return torch.tensor(inputs), torch.tensor(targets) # 生成训练集 train_inputs, train_targets = generate_data(corpus, seq_len) print(f"训练样本数: {len(train_inputs)} | 词汇表字符: {all_chars}") # ===================== 2. 定义位置编码(Transformer必需) ===================== class PositionalEncoding(nn.Module): def __init__(self, embedding_dim, max_len=5000): super().__init__() # 预计算位置编码 position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embedding_dim)) pe = torch.zeros(max_len, 1, embedding_dim) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) # 不参与训练的参数 def forward(self, x): # x形状: [seq_len, batch_size, embedding_dim] x = x + self.pe[:x.size(0)] return x # ===================== 3. 定义Transformer模型 ===================== class TransformerLM(nn.Module): def __init__(self, vocab_size, embedding_dim=16, nhead=2, num_layers=2): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) # 字符嵌入 self.pos_encoder = PositionalEncoding(embedding_dim) # 位置编码 # Transformer编码器(这里用编码器做语言模型,也可以用解码器) encoder_layers = nn.TransformerEncoderLayer( d_model=embedding_dim, # 输入维度(和嵌入维度一致) nhead=nhead, # 多头注意力的头数 dim_feedforward=64 # 前馈网络的隐藏层维度 ) self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers) self.fc = nn.Linear(embedding_dim, vocab_size) # 输出到词汇表 def forward(self, x): # x形状: [batch_size, seq_len] → 转成Transformer要求的[seq_len, batch_size] x = x.transpose(0, 1) # 嵌入+位置编码: [seq_len, batch_size, embedding_dim] x = self.embedding(x) x = self.pos_encoder(x) # Transformer编码: [seq_len, batch_size, embedding_dim] x = self.transformer_encoder(x) # 取最后一个时间步的输出(预测下一个字符): [batch_size, embedding_dim] x = x[-1, :, :] # 输出分类: [batch_size, vocab_size] x = self.fc(x) return x # ===================== 4. 训练模型 ===================== def train(): # 超参数 embedding_dim = 16 #字符嵌入维度(每个字符转成 16 维向量) nhead = 2 #多头 num_layers = 2 #编码器层数 batch_size = 4 #每次训练的样本数 epochs = 50 #训练轮数 lr = 0.001 #学习率(参数更新的步长) # 初始化模型、损失函数、优化器 #实例化我们定义的 Transformer 字符模型,把超参数传入(比如 embedding_dim=16) #此时模型的参数(QKV 权重、嵌入层权重等)都是随机初始化的,还没学到任何东西。 model = TransformerLM(vocab_size, embedding_dim, nhead, num_layers) #交叉熵 #损失函数 #选择交叉熵作为损失函数 #交叉熵损失是分类任务的 “标配” criterion = nn.CrossEntropyLoss() # 分类损失(预测字符) #Adam是SGD的升级版,自带自适应学习 #优化器要更新的是模型的所有可训练参数(QKV、嵌入层、全连接层等) optimizer = optim.Adam(model.parameters(), lr=lr) # 训练循环 #切换训练模式 model.train() for epoch in range(epochs): total_loss = 0.0 #初始化 “本轮 Epoch 的总损失”,用于统计整个 Epoch 的平均损失(损失越小→模型预测越准) # 随机打乱数据(按batch处理) indices = torch.randperm(len(train_inputs)) for i in range(0, len(train_inputs), batch_size):#取一个步长=batch 例如4 88个样本 22 # 取一个batch batch_idx = indices[i:i+batch_size]#取当前批次的 例如 4 那就是4-7 batch_inputs = train_inputs[batch_idx]#取当前批次的输入序列 batch_targets = train_targets[batch_idx]#取当前批次的目标字符 # 前向传播 outputs = model(batch_inputs)#输出 “预测结果” loss = criterion(outputs, batch_targets)#用交叉熵损失函数,计算 “模型预测得分” 和 “真实目标字符” 的差距 # 反向传播+优化 optimizer.zero_grad()#清空模型所有参数的梯度 loss.backward()#反向传播计算梯度 optimizer.step()#用梯度更新参数 total_loss += loss.item() # 每5轮打印一次损失 if (epoch + 1) % 5 == 0: print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_inputs):.4f}") print("训练完成!") return model # ===================== 5. 测试生成(输入前缀,生成后续字符) ===================== def generate_text(model, prefix, max_len=20): model.eval() # 将前缀转成字符索引 input_seq = [char2idx[c] for c in prefix] with torch.no_grad(): for _ in range(max_len): # 取最后seq_len个字符作为输入(不足则补0) current_input = torch.tensor([input_seq[-seq_len:] if len(input_seq)>=seq_len else [0]*(seq_len-len(input_seq)) + input_seq]) # 预测下一个字符 output = model(current_input) next_char_idx = output.argmax(dim=1).item() input_seq.append(next_char_idx) # 如果生成空格或结束符(这里用空格当结束),提前停止 if idx2char[next_char_idx] == " ": break # 转成字符 return "".join([idx2char[idx] for idx in input_seq]) # ===================== 主函数(直接运行) ===================== if __name__ == "__main__": # 训练模型 trained_model = train() # 测试生成(输入不同前缀) prefixes = ["trans", "att", "pyt"] for prefix in prefixes: generated = generate_text(trained_model, prefix) print(f"\n输入前缀: '{prefix}' → 生成结果: '{generated}'")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1150774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务分布式SpringBoot+Vue+Springcloud四川自驾游攻略管理系统

目录微服务分布式SpringBootVueSpringCloud四川自驾游攻略管理系统摘要开发技术源码文档获取/同行可拿货,招校园代理 :文章底部获取博主联系方式!微服务分布式SpringBootVueSpringCloud四川自驾游攻略管理系统摘要 该系统基于微服务分布式架构&#xff…

微服务分布式SpringBoot+Vue+Springcloud微信小程序的宠物美容预约系统设计与实现

目录摘要开发技术源码文档获取/同行可拿货,招校园代理 :文章底部获取博主联系方式!摘要 随着宠物经济的快速发展,宠物美容服务的需求日益增长。传统的线下预约方式存在效率低、信息不对称等问题。基于此,设计并实现了一套基于微服…

USB转串口驱动安装步骤通俗解释

电脑没串口?一文搞懂USB转串口驱动安装与芯片选型 你有没有遇到过这种情况:手握一块开发板,连上USB线准备调试,打开设备管理器却发现“未知设备”或者根本找不到COM口?明明线插好了,灯也亮了,就…

Java SpringBoot+Vue3+MyBatis 网站系统源码|前后端分离+MySQL数据库

摘要 随着互联网技术的快速发展,现代Web应用对高性能、模块化和可扩展性的需求日益增长。传统的单体架构在应对复杂业务逻辑和高并发场景时逐渐显现出局限性,前后端分离架构因其灵活性、开发效率高和易于维护等特点成为主流解决方案。基于此背景&#xf…

易连说-如何寻找具备 Drummond Group AS2 国际认证的EDI 产品?

在数字化供应链重构的浪潮中,电子数据交换(EDI)已从“可选配置”升级为企业对接全球贸易伙伴的“必备能力”。作为 EDI 数据传输的主流协议——AS2 协议凭借安全加密、可靠传输的特性,成为企业间数据交换的核心选择,选…

AD画PCB中HDMI高速通道设计项目应用详解

如何在Altium Designer中搞定HDMI高速通道设计?一文讲透实战要点你有没有遇到过这样的情况:板子打回来了,HDMI接口连上去却黑屏、闪屏,甚至压根不识别显示器?明明原理图画得没错,元器件也焊上了&#xff0c…

小白指南:USB接口各引脚功能详解入门篇

从零开始搞懂USB:别再被那几根线难住了!你有没有试过自己焊一条USB线,结果接上电脑没反应,甚至烧了接口?或者想给开发板单独供电,却不知道哪根线是电源、哪根是地?又或者好奇为什么有些安卓手机…

大数据分布式事务:CAP定理视角下的解决方案对比

大数据分布式事务:CAP定理视角下的解决方案对比关键词:大数据、分布式事务、CAP定理、解决方案对比摘要:本文主要从CAP定理的视角出发,深入探讨大数据分布式事务的多种解决方案。首先介绍了大数据分布式事务的背景知识和CAP定理的…

企业级大创管理系统管理系统源码|SpringBoot+Vue+MyBatis架构+MySQL数据库【完整版】

摘要 随着高等教育改革的不断深化,大学生创新创业训练计划(大创)已成为培养创新型人才的重要途径。传统的大创项目管理多依赖手工操作或简易电子表格,存在信息分散、流程不透明、统计效率低下等问题。高校亟需一套标准化、数字化的…

微服务分布式SpringBoot+Vue+Springcloud万里学院摄影作品活动报名商城系统社团管理系统

目录摘要开发技术源码文档获取/同行可拿货,招校园代理 :文章底部获取博主联系方式!摘要 该系统基于微服务分布式架构,采用SpringBoot、Vue.js和SpringCloud技术栈,为万里学院设计了一套集摄影作品展示、活动报名、商城交易及社团…

Win11升级后Multisim数据库异常?核心要点解析

Win11升级后Multisim打不开元件库?一文讲透数据库异常的底层真相与实战修复你有没有遇到过这种情况:辛辛苦苦把电脑从Win10升级到Win11,结果一打开熟悉的Multisim——满屏报错,“multisim数据库无法访问”几个大字赫然在目&#x…

嘉立创EDA原理图注释与标注操作指南:提升图纸可读性

嘉立创EDA原理图注释与标注实战:让电路图“会说话”你有没有遇到过这样的情况?打开一张几个月前自己画的原理图,满屏飞线交错、元件编号跳跃混乱,连电源线都找不到从哪来、到哪去。更别提团队协作时,同事指着某个引脚问…

深度解析|当 Prometheus 遇见大模型:解密下一代智能监控体系

导读在云原生时代,Prometheus Alertmanager 虽然解决了“看得见”的问题,却无法解决“看得懂”和“看得早”的难题。运维团队往往陷入“故障发生->收到告警->紧急救火”的被动循环。 本文将探讨如何利用 AI 大模型技术赋能现有监控体系&#xff0…

全加器晶体管级实现指南:手把手构建CMOS电路

从逻辑门到晶体管:手把手设计一个高性能CMOS全加器你有没有想过,当你在Verilog里写下assign S A ^ B ^ Cin;的时候,背后到底发生了什么?那行看似简单的代码,最终会变成芯片上几十个微小的MOS晶体管,它们协…

从零搭建日志分析系统:es数据库手把手教程

从零搭建日志分析系统:Elasticsearch 实战手记当你的服务开始“失联”,你靠什么找回真相?想象一下这样的场景:凌晨两点,告警突然响起。线上 API 响应时间飙升,用户请求大面积超时。你登录服务器&#xff0c…

工业控制面板中LCD1602的布局与驱动技巧

工业控制面板中的LCD1602:从电路设计到驱动优化的实战指南在自动化设备遍布车间的今天,你是否曾注意到——那些看似“过时”的黑白字符屏,依然稳稳地嵌在一台台控制柜的前面板上?它们没有炫彩动画,也不支持触控滑动&am…

SpringBoot+Vue 图书进销存管理系统平台完整项目源码+SQL脚本+接口文档【Java Web毕设】

摘要 随着信息技术的快速发展,传统图书管理方式已无法满足现代企业的需求。纸质记录和手工操作效率低下,容易出错,且难以实现数据的实时共享与分析。图书进销存管理系统通过数字化手段优化图书采购、销售、库存管理等核心业务流程&#xff0c…

有源与无源蜂鸣器电路对比:一文说清核心差异与应用场景

有源与无源蜂鸣器电路对比:一文讲透设计本质与实战选型你有没有遇到过这样的情况?项目快收尾了,突然发现报警提示音“嘀——”一声单调得像老式微波炉;或者想让设备播放一段简单的“do re mi”,结果接上蜂鸣器后只发出…

【AI】光速理解YOLO框架

1.要点解析 我们前面学的PyTorch是用来搭建神经网络模型的脚手架,即利用一些算子搭建网络结构,并且支持评估推理等全套API。 区别于PyTorch,YOLO包含了丰富的计算机视觉模型库。有了YOLO,就不需要自己从0开始搭建模型了。YOLO内置…

全面讲解Windows下USB Serial驱动下载步骤

一次搞定!Windows下USB转串口驱动安装全攻略 你有没有遇到过这样的场景:手握一块开发板,满心期待地插上USB线,打开设备管理器却发现——“未知设备”、“COM端口没出来”?调试日志收不到,固件也刷不进去&a…