引言
在自然语言处理和序列生成任务中,自注意力机制(Self-Attention)是提升模型性能的关键技术。本文将通过一个自定义的PyTorch模型实现,展示如何构建一个结合多头注意力与前馈网络的序列生成模型(如文本或字符生成)。该模型通过创新的 MaxStateSuper
模块实现动态特征融合,适用于字体生成、文本预测等场景。
技术背景
1. 模型结构解析
核心组件
-
MaxStateSuper(自注意力模块)
- 功能:通过多头注意力机制提取序列中的关键特征,并结合累积最大值操作增强长期依赖建模。
- 实现亮点:
- 合并三个线性层为一个
combined
层,减少参数冗余。 - 使用
torch.cummax
实现动态状态积累,提升序列记忆能力。
- 合并三个线性层为一个
-
FeedForward(前馈网络)
- 结构:两层全连接网络,中间夹杂
ReLU
激活函数和门控机制(gate
)。 - 作用:非线性变换,增强模型表达能力。
- 结构:两层全连接网络,中间夹杂
-
DecoderLayer(解码器层)
- 创新点:
- 引入
alpha
参数平衡前馈网络输出与原始输入的权重,实现动态特征融合。 - 层归一化(
LayerNorm
)确保梯度稳定性。
- 引入
- 创新点:
-
SamOut(整体模型)
- 输入:字符或token的Embedding向量。
- 输出:预测的下一时刻token概率分布。
2. 关键技术
- 多头注意力机制:通过
heads
参数将特征空间划分为多个子空间,提升模型对不同模式的捕捉能力。 - 累积最大值操作:
out2 = torch.cummax(out2, dim=2)[0]
保留序列中的关键特征轨迹。 - 动态参数平衡:
alpha
参数通过梯度下降自动学习前馈网络与原始输入的权重比例。
代码实现
完整代码
import torch
import torch.nn as nn
import torch.optim as optimclass MaxStateSuper(nn.Module):def __init__(self, dim_size, heads):super().__init__()self.heads = headsassert dim_size % heads == 0, "Dimension size must be divisible by head size."self.combined = nn.Linear(dim_size, 3 * dim_size, bias=False) # 合并QKV线性层def forward(self, x):b, s, d = x.shape# 合并后的线性变换并分割为QKVqkv = self.combined(x).chunk(3, dim=-1)q, k, v = qkv# 调整形状并执行注意力计算# ...(此处省略具体注意力计算逻辑,参考标准多头注意力实现)...return out, stateclass FeedForward(nn.Module):def __init__(self, hidden_size):super().__init__()self.ffn1 = nn.Linear(hidden_size, hidden_size)self.ffn2 = nn.Linear(hidden_size, hidden_size)self.gate = nn.Linear(hidden_size, hidden_size)self.relu = nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))xx = x1 * x2return self.ffn2(xx)class DecoderLayer(nn.Module):def __init__(self, hidden_size, num_heads):super().__init__()self.self_attn = MaxStateSuper(hidden_size, num_heads)self.ffn = FeedForward(hidden_size)self.norm = nn.LayerNorm(hidden_size)self.alpha = nn.Parameter(torch.tensor(0.5)) # 动态平衡参数def forward(self, x):attn_out, _ = self.self_attn(x)ffn_out = self.ffn(attn_out)x = self.norm(self.alpha * ffn_out + (1 - self.alpha) * x)return xclass SamOut(nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super().__init__()self.embedding = nn.Embedding(voc_size, hidden_size, padding_idx=3)self.layers = nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = nn.Linear(hidden_size, voc_size, bias=False)def forward(self, x):x = self.embedding(x)for layer in self.layers:x = layer(x)return self.head(x)# 训练流程(简化版)
if __name__ == '__main__':voc_size = 10000 # 假设词汇表大小model = SamOut(voc_size, hidden_size=256, num_heads=8, num_layers=6)criterion = nn.CrossEntropyLoss(ignore_index=3)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):# 假设 input_tensor 和 target_tensor 已准备output = model(input_tensor)loss = criterion(output.view(-1, voc_size), target_tensor.view(-1))loss.backward()optimizer.step()
关键步骤解析
1. MaxStateSuper
模块的创新点
# 合并QKV层
qkv = self.combined(x<