第9.1讲、Tiny Encoder Transformer:极简文本分类与注意力可视化实战

项目简介

本项目实现了一个极简版的 Transformer Encoder 文本分类器,并通过 Streamlit 提供了交互式可视化界面。用户可以输入任意文本,实时查看模型的分类结果及注意力权重热力图,直观理解 Transformer 的内部机制。项目采用 HuggingFace 的多语言 BERT 分词器,支持中英文等多种语言输入,适合教学、演示和轻量级 NLP 应用开发。


主要功能

  • 多语言支持:集成 HuggingFace bert-base-multilingual-cased 分词器,支持 100+ 语言。
  • 极简 Transformer 结构:自定义实现位置编码、单层/多层 Transformer Encoder、分类头,结构清晰,便于学习和扩展。
  • 注意力可视化:可实时展示输入文本的注意力热力图和每个 token 被关注的占比,帮助理解模型关注机制。
  • 高效演示:训练时仅用 AG News 数据集的前 200 条数据,并只训练 10 个 batch,保证页面加载和交互速度。

代码结构与核心实现

1. 数据加载与预处理

使用 HuggingFace datasets 库加载 AG News 数据集,并用 BERT 分词器对文本进行编码:

from datasets import load_dataset
from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200))  # 只用前200条数据def encode(example):tokens = tokenizer(example["text"],padding="max_length",truncation=True,max_length=64,return_tensors="pt")return {"input_ids": tokens["input_ids"].squeeze(0),"label": example["label"]}encoded_train = dataset["train"].map(encode)

2. Tiny Encoder 模型结构

模型包含词嵌入层、位置编码、若干 Transformer Encoder 层和分类头,支持输出每层的注意力权重:

import torch.nn as nnclass PositionalEncoding(nn.Module):# ... 位置编码实现,见下文详细代码 ...class TransformerEncoderLayerWithTrace(nn.Module):# ... 支持 trace 的单层 Transformer Encoder,见下文详细代码 ...class TinyEncoderClassifier(nn.Module):# ... 嵌入、位置编码、编码器堆叠、分类头,见下文详细代码 ...

3. 训练流程

采用交叉熵损失和 Adam 优化器,仅训练 10 个 batch,极大提升演示速度:

import torch.optim as optim
from torch.utils.data import DataLoadertrain_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)
model = TinyEncoderClassifier(...)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)model.train()
for i, batch in enumerate(train_loader):if i >= 10:  # 只训练10个batchbreakinput_ids = batch["input_ids"]labels = batch["label"]logits, _ = model(input_ids)loss = criterion(logits, labels)optimizer.zero_grad()loss.backward()optimizer.step()

4. Streamlit 可视化界面

  • 提供文本输入框,用户可输入任意文本。
  • 实时推理并展示分类结果。
  • 可视化 Transformer 第一层各个注意力头的权重热力图和每个 token 被关注的占比(条形图)。
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as pltuser_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:# ... 推理与注意力可视化代码,见下文详细代码 ...

训练与推理流程详解

  1. 数据加载与预处理

    • 加载 AG News 数据集,仅取前 200 条样本。
    • 用多语言 BERT 分词器编码文本,填充/截断到 64 长度。
  2. 模型结构

    • 词嵌入层将 token id 映射为向量。
    • 位置编码为每个 token 添加可区分的位置信息。
    • 堆叠若干 Transformer Encoder 层,支持输出注意力权重。
    • 分类头对第一个 token 的输出做分类(类似 BERT 的 [CLS])。
  3. 训练流程

    • 损失函数为交叉熵,优化器为 Adam。
    • 只训练 1 个 epoch,且只训练 10 个 batch,保证演示速度。
  4. 推理与可视化

    • 用户输入文本,模型输出预测类别编号。
    • 可视化注意力热力图和每个 token 被关注的占比,直观展示模型关注点。

适用场景

  • Transformer 原理教学与可视化演示
  • 注意力机制理解与分析
  • 多语言文本分类任务的快速原型开发
  • NLP 课程、讲座、实验室演示

完整案例说明:


Tiny Encoder

1. 代码主要功能

该脚本实现了一个基于 Transformer Encoder 的文本分类模型,并通过 Streamlit 提供了可视化界面,
支持输入一句话并展示模型的分类结果及注意力权重热力图。

2. 主要模块说明

  • Tokenizer 初始化
    • 使用 HuggingFace 的多语言 BERT Tokenizer 对输入文本进行分词和编码。
  • 模型结构
    • 包含词嵌入层、位置编码、若干 Transformer Encoder 层(带注意力权重 trace)、分类器。
  • 数据处理与训练
    • 加载 AG News 数据集,编码文本,训练模型并保存。
    • 若已存在训练好的模型则直接加载。
  • Streamlit 可视化
    • 提供文本输入框,实时推理并展示分类结果。
    • 可视化 Transformer 第一层各个注意力头的权重热力图。

3. 数据流向说明

  1. 输入
    • 用户在 Streamlit 网页输入一句英文(或多语言)文本。
  2. 分词与编码
    • Tokenizer 将文本转为固定长度的 token id 序列(input_ids)。
  3. 模型推理
    • input_ids 输入 TinyEncoderClassifier,经过嵌入、位置编码、若干 Transformer 层,输出 logits(分类结果)和注意力权重(trace)。
  4. 分类输出
    • 取 logits 最大值作为类别预测,显示在网页上。
  5. 注意力可视化
    • 取第一层注意力权重,分别绘制每个 head 的热力图,帮助理解模型关注的 token 关系。

4. 适用场景

  • 适合教学、演示 Transformer 注意力机制和文本分类原理。
  • 可扩展用于多语言文本分类任务。

import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt# ============================
# 位置编码模块
# ============================
class PositionalEncoding(nn.Module):"""位置编码模块:为输入的 token 序列添加可区分位置信息。使用正弦和余弦函数生成不同频率的编码。"""def __init__(self, d_model, max_len=512):super().__init__()# 创建一个 (max_len, d_model) 的全零张量,用于存储位置编码pe = torch.zeros(max_len, d_model)# 生成位置索引 (max_len, 1)position = torch.arange(0, max_len).unsqueeze(1)# 计算每个维度对应的分母项(不同频率)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))# 偶数位置用 sin,奇数位置用 cospe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 增加 batch 维度,形状变为 (1, max_len, d_model)pe = pe.unsqueeze(0)# 注册为 buffer,模型保存时一同保存,但不是参数self.register_buffer('pe', pe)def forward(self, x):"""输入:x,形状为 (batch, seq_len, d_model)输出:加上位置编码后的张量,形状同输入"""return x + self.pe[:, :x.size(1)]# ============================
# 单层 Transformer Encoder,支持输出注意力权重
# ============================
class TransformerEncoderLayerWithTrace(nn.Module):"""单层 Transformer Encoder,支持输出注意力权重。包含多头自注意力、前馈网络、残差连接和层归一化。"""def __init__(self, d_model, nhead, dim_feedforward):super().__init__()# 多头自注意力层self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)# 前馈网络第一层self.linear1 = nn.Linear(d_model, dim_feedforward)self.dropout = nn.Dropout(0.1)# 前馈网络第二层self.linear2 = nn.Linear(dim_feedforward, d_model)# 层归一化self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)# Dropout 层self.dropout1 = nn.Dropout(0.1)self.dropout2 = nn.Dropout(0.1)def forward(self, src, trace=False):"""前向传播。参数:src: 输入序列,形状为 (batch, seq_len, d_model)trace: 是否返回注意力权重返回:src: 输出序列attn_weights: 注意力权重(如果 trace=True)"""# 多头自注意力,attn_weights 形状为 (batch, nhead, seq_len, seq_len)attn_output, attn_weights = self.self_attn(src, src, src, need_weights=trace)# 残差连接 + 层归一化src2 = self.dropout1(attn_output)src = self.norm1(src + src2)# 前馈网络src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))# 残差连接 + 层归一化src = self.norm2(src + self.dropout2(src2))# 返回输出和注意力权重(可选)return src, attn_weights if trace else None# ============================
# Tiny Transformer 分类模型
# ============================
class TinyEncoderClassifier(nn.Module):"""Tiny Transformer 分类模型:包含嵌入层、位置编码、若干 Transformer 编码器层和分类头。支持输出每层的注意力权重。"""def __init__(self, vocab_size, d_model, n_heads, d_ff, num_layers, max_len, num_classes):super().__init__()# 词嵌入层,将 token id 映射为向量self.embedding = nn.Embedding(vocab_size, d_model)# 位置编码模块self.pos_encoder = PositionalEncoding(d_model, max_len)# 堆叠多个 Transformer 编码器层self.layers = nn.ModuleList([TransformerEncoderLayerWithTrace(d_model, n_heads, d_ff) for _ in range(num_layers)])# 分类头,对第一个 token 的输出做分类self.classifier = nn.Linear(d_model, num_classes)def forward(self, input_ids, trace=False):"""前向传播。参数:input_ids: 输入 token id,形状为 (batch, seq_len)trace: 是否输出注意力权重返回:logits: 分类输出 (batch, num_classes)traces: 每层的注意力权重(可选)"""# 词嵌入x = self.embedding(input_ids)# 加位置编码x = self.pos_encoder(x)traces = []# 依次通过每一层 Transformer 编码器for layer in self.layers:x, attn = layer(x, trace=trace)if trace:traces.append({"attn_map": attn})# 只取第一个 token 的输出做分类(类似 BERT 的 [CLS])logits = self.classifier(x[:, 0])return logits, traces if trace else None# ============================
# 模型构建与训练函数,显式使用CPU
# ============================
@st.cache_resource(show_spinner=False)
def build_and_train_model(d_model, n_heads, d_ff, num_layers):device = torch.device('cpu')  # 显式指定使用CPUtokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")dataset = load_dataset("ag_news")dataset["train"] = dataset["train"].select(range(200))  # 只用前200条数据MAX_LEN = 64def encode(example):tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt")return {"input_ids": tokens["input_ids"].squeeze(0), "label": example["label"]}encoded_train = dataset["train"].map(encode)encoded_train.set_format(type="torch")train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)model = TinyEncoderClassifier(vocab_size=tokenizer.vocab_size,d_model=d_model,n_heads=n_heads,d_ff=d_ff,num_layers=num_layers,max_len=MAX_LEN,num_classes=4).to(device)  # 模型放到CPUcriterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)model.train()for epoch in range(1):  # 训练1个epochfor i, batch in enumerate(train_loader):if i >= 10:  # 只训练10个batchbreakinput_ids = batch["input_ids"].to(device)  # 输入转到CPUlabels = batch["label"].to(device)logits, _ = model(input_ids)loss = criterion(logits, labels)optimizer.zero_grad()loss.backward()optimizer.step()return model, tokenizer# ============================
# Streamlit 页面设置
# ============================
st.set_page_config(page_title="TinyEncoder")
st.title("🌍 Tiny Encoder Transformer")# 固定模型参数
# d_model: 隐藏层维度,
# n_heads: 注意力头数,
# d_ff: 前馈层维度,
# num_layers: Transformer 层数
d_model = 64
n_heads = 2
d_ff = 128
num_layers = 1# 构建并训练模型
with st.spinner("模型构建中..."):model, tokenizer = build_and_train_model(d_model, n_heads, d_ff, num_layers)# ============================
# 推理与注意力权重可视化
# ============================
model.eval()
device = torch.device('cpu')
model.to(device)user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:tokens = tokenizer(user_input, return_tensors="pt", max_length=64, padding="max_length", truncation=True)input_ids = tokens["input_ids"].to(device)  # 放CPUwith torch.no_grad():logits, traces = model(input_ids, trace=True)pred_class = torch.argmax(logits, dim=-1).item()st.markdown(f"### 🔍 预测类别编号: `{pred_class}`")if traces:attn_map = traces[0]["attn_map"]if attn_map is not None:seq_len = input_ids.shape[1]token_list = tokenizer.convert_ids_to_tokens(input_ids[0])if '[PAD]' in token_list:valid_len = token_list.index('[PAD]')else:valid_len = seq_lentoken_list = token_list[:valid_len]if attn_map.dim() == 4:# [batch, heads, seq_len, seq_len]heads = attn_map.size(1)fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))if heads == 1:axes = [axes]for i in range(heads):matrix = attn_map[0, i][:valid_len, :valid_len].cpu().detach().numpy()sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)axes[i].set_title(f"Head {i}")axes[i].tick_params(labelsize=6)# 显示每个 token 被关注的占比attn_sum = matrix.sum(axis=0)attn_ratio = attn_sum / attn_sum.sum()fig2, ax2 = plt.subplots(figsize=(5, 2))ax2.bar(range(valid_len), attn_ratio)ax2.set_xticks(range(valid_len))ax2.set_xticklabels(token_list, rotation=90, fontsize=6)ax2.set_title(f"Head {i} Token Attention Ratio")st.pyplot(fig2)st.pyplot(fig)elif attn_map.dim() == 3:# [heads, seq_len, seq_len]heads = attn_map.size(0)fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))if heads == 1:axes = [axes]for i in range(heads):matrix = attn_map[i][:valid_len, :valid_len].cpu().detach().numpy()sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)axes[i].set_title(f"Head {i}")axes[i].tick_params(labelsize=6)# 显示每个 token 被关注的占比attn_sum = matrix.sum(axis=0)attn_ratio = attn_sum / attn_sum.sum()fig2, ax2 = plt.subplots(figsize=(5, 2))ax2.bar(range(valid_len), attn_ratio)ax2.set_xticks(range(valid_len))ax2.set_xticklabels(token_list, rotation=90, fontsize=6)ax2.set_title(f"Head {i} Token Attention Ratio")st.pyplot(fig2)st.pyplot(fig)elif attn_map.dim() == 2:# [seq_len, seq_len]fig, ax = plt.subplots(figsize=(5, 3))sns.heatmap(attn_map[:valid_len, :valid_len].cpu().detach().numpy(), ax=ax, cbar=False, xticklabels=token_list, yticklabels=token_list)ax.set_title("Attention Map")ax.tick_params(labelsize=6)st.pyplot(fig)# 显示每个 token 被关注的占比matrix = attn_map[:valid_len, :valid_len].cpu().detach().numpy()attn_sum = matrix.sum(axis=0)attn_ratio = attn_sum / attn_sum.sum()fig2, ax2 = plt.subplots(figsize=(5, 2))ax2.bar(range(valid_len), attn_ratio)ax2.set_xticks(range(valid_len))ax2.set_xticklabels(token_list, rotation=90, fontsize=6)ax2.set_title("Token Attention Ratio")st.pyplot(fig2)else:st.warning("注意力权重维度异常,无法可视化。")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/84000.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java】泛型在 Java 中是怎样实现的?

先说结论 , Java 的泛型是伪泛型 , 在运行期间不存在泛型的概念 , 泛型在 Java 中是 编译检查 运行强转 实现的 泛型是指 允许在定义类 , 接口和方法时使用的类型参数 , 使得代码可以在不指定具体类型的情况下操作不同的数据类型 , 从而实现类型安全的代码复用 的语言机制 . …

linux如何查找软连接的实际地址

在Linux系统中,查找软连接(符号链接,即symbolic link)的实际地址可以通过多种方法实现。软连接是一个特殊的文件类型,它包含了一个指向另一个文件或目录的引用。要找到软连接所指向的实际文件或目录,可以使…

Token类型与用途详解:数字身份的安全载体图谱

在现代数字身份体系中,Token如同"数字DNA",以不同形态流转于各类应用场景。根据Okta的最新研究报告,平均每个企业应用使用2.7种不同类型的Token实现身份验证和授权。本文将系统梳理主流Token类型及其应用场景,通过行业典…

火山 RTC 引擎9 ----集成 appkey

一、集成 appkey 1、网易RTC 初始化过程 1)、添加头文件 实现互动直播 - 互动直播 2.0网易云信互动直播产品的基本功能包括音视频通话和连麦直播,当您成功初始化 SDK 之后,您可以简单体验本产品的基本业务流程,例如主播加入房间…

详细介绍Qwen3技术报告中提到的模型架构技术

详细介绍Qwen3技术报告中提到的一些主流模型架构技术,并为核心流程配上相关的LaTeX公式。 这些技术都是当前大型语言模型(LLM)领域为了提升模型性能、训练效率、推理速度或稳定性而采用的关键组件。 1. Grouped Query Attention (GQA) - 分组…

光电效应理论与实验 | 从爱因斯坦光量子假说到普朗克常量测定

注:本文为“光电效应”相关文章合辑。 英文引文,机翻未校。 中文引文,略作重排,未整理去重。 图片清晰度受引文原图所限。 如有内容异常,请看原文。 Photoelectric Effect 光电效应 Discussion dilemma Under the…

Visual Studio 2019/2022:当前不会命中断点,还没有为该文档加载任何符号。

1、打开调试的模块窗口,该窗口一定要在调试状态下才会显示。 vs2019打开调试的模块窗口 2、Visual Studio 2019提示未使用调试信息生成二进制文件 未使用调试信息生成二进制文件 3、然后到debug目录下看下确实未生成CoreCms.Net.Web.WebApi.pdb文件。 那下面的…

打破性能瓶颈:用DBB重参数化模块优化YOLOv8检测头

文章目录 引言DBB 重参数化模块简介DBB 的优势 YOLOv8 检测头的结构分析使用 DBB 模块魔改检测头替换策略代码实现改进后的效果预期 实验与验证总结与展望 引言 在目标检测领域,YOLO 系列算法一直以其高效的检测速度和不错的检测精度受到广泛关注。随着版本的不断更…

如何成为更好的自己?

成为更好的自己是一个持续成长的过程,需要结合自我认知、目标规划和行动力。以下是一些具体建议,帮助你逐步提升: 1. 自我觉察:认识自己 反思与复盘:每天花10分钟记录当天的决策、情绪和行为,分析哪些做得…

免费使用GPU的探索笔记

多种有免费时长的平台 https://www.cnblogs.com/java-note/p/18760386 Kaggle免费使用GPU的探索 https://www.kaggle.com/ 注册Kaggle账号 访问Kaggle官网,使用邮箱注册账号。 发现gpu都是灰色的 返回home,右上角的头像点开 验证手机号 再次code-you…

CSS- 4.2 相对定位(position: relative)

本系列可作为前端学习系列的笔记,代码的运行环境是在HBuilder中,小编会将代码复制下来,大家复制下来就可以练习了,方便大家学习。 HTML系列文章 已经收录在前端专栏,有需要的宝宝们可以点击前端专栏查看! 点…

如何使用Antv X6使用拖拽布局?

拖拽效果图 拖拽后 布局预览 官方: X6 图编辑引擎 | AntV 安装依赖 # npm npm install antv/x6 --save npm install antv/x6-plugin-dnd --save npm install antv/x6-plugin-export --save需要引入的代码 import { Graph, Shape } from antv/x6; import { Dnd } …

数据库健康监测器(BHM)实战:如何通过 HTML 报告识别潜在问题

在数据库运维中,健康监测是保障系统稳定性与性能的关键环节。通过 HTML 报告,开发者可以直观查看数据库的运行状态、资源使用情况与潜在风险。 本文将围绕 数据库健康监测器(Database Health Monitor, BHM) 的核心功能展开分析,结合 Prometheus + Grafana + MySQL Export…

PCB设计实践(二十四)PCB设计时如何避免EMI

PCB设计中避免电磁干扰(EMI)是一项涉及电路架构、布局布线、材料选择及制造工艺的系统工程。本文从设计原理到工程实践,系统阐述EMI产生机制及综合抑制策略,覆盖高频信号控制、接地优化、屏蔽技术等核心维度,为高密度、…

嵌入式硬件篇---陀螺仪|PID

文章目录 前言1. 硬件准备主控芯片陀螺仪模块电机驱动电源其他2. 硬件连接3. 软件实现步骤(1) MPU6050初始化与数据读取(2) 姿态解算(互补滤波或DMP)(3) PID控制器设计(4) 麦克纳姆轮协同控制4. 主程序逻辑5. 关键优化与调试技巧(1) 传感器校准(2) PID参数整定先调P再调D最后…

【Linux基础I/O】文件调用接口、文件描述符、重定向和缓冲区

【Linux基础I/O一】文件描述符和重定向 1.C语言的文件调用接口2.操作系统的文件调用接口2.1open接口2.2close接口2.3write接口2.4read接口 3.文件描述符fd的本质4.标准输入、输出、错误5.重定向5.1什么是重定向5.2输入重定向和输出重定向5.3系统调用的重定向dup2 6.缓冲区 1.C语…

鸿蒙HarmonyOS 【ArkTS组件】通用属性-背景设置

📑往期推文全新看点(附带最新鸿蒙全栈学习笔记) 嵌入式开发适不适合做鸿蒙南向开发?看完这篇你就了解了~ 鸿蒙岗位需求突增!移动端、PC端、IoT到底该怎么选? 分享一场鸿蒙开发面试经验记录(三面…

【76. 最小覆盖子串】

Leetcode算法练习 笔记记录 76. 最小覆盖子串 76. 最小覆盖子串 滑动窗口的hard题目,思路先找到第一个覆盖的窗口,不断缩小左边界,找到更小的窗口并记录。 思路很简单,写起来就不是一会事了,看题解看了几个h&#xff0…

Spring事务简单操作

什么是事务? 事务是一组操作的集合,是一个不可分割的操作 事务会把所有的操作作为⼀个整体, ⼀起向数据库提交或者是撤销操作请求. 所以这组操作要么同时 成功, 要么同时失败. 事务的操作 分为三步: 1. 开启事start transaction/ begin …

Rust 学习笔记:关于错误处理的练习题

Rust 学习笔记:关于错误处理的练习题 Rust 学习笔记:关于错误处理的练习题想看到回溯,需要把哪个环境变量设置为 1?以下哪一项不是使用 panic 的好理由?以下哪一项最能描述为什么 File::open 返回的是 Result 而不是 O…