完整教程:视觉Transformer实战——Vision Transformer(ViT)详解与实现

news/2025/12/1 18:56:22/文章来源:https://www.cnblogs.com/yangykaifa/p/19294363

视觉Transformer实战——Vision Transformer

    • 0. 前言
    • 1. ViT 技术原理
      • 1.1 核心思想
      • 1.2 使用 Transformer 处理图像数据
    • 2. ViT 关键组件
      • 2.1 图像分块
      • 2.2 patch 嵌入
      • 2.3 位置编码
      • 2.4 分类 token
    • 3. 使用 PyTorch 实现 ViT
      • 3.1 模型构建
      • 3.2 模型训练

0. 前言

在计算机视觉领域,卷积神经网络 (Convolutional Neural Network, CNN) 长期以来一直是处理图像任务的主流架构。然而,随着 Transformer 在自然语言处理领域的巨大成功,研究人员开始探索将这种基于自注意力机制的架构应用于视觉任务。Vision Transformer (ViT) 是这一探索的重要里程碑,它首次证明了纯 Transformer 架构在图像分类任务上可以超越最先进的 CNN 模型。本文将详细介绍 ViT 的技术原理,并使用 PyTorch 从零开始构建 ViT 模型用于图像分类任务。

1. ViT 技术原理

1.1 核心思想

Vision Transformer (ViT) 的核心思想是将图像分割成固定大小的小块 (patch),将这些 patch 线性嵌入后加上位置编码,然后像自然语言处理 (Natuarl Language Processing, NLP) 中的词元 (token) 一样将这些 patch 序列输入标准的 Transformer 编码器中进行处理。

1.2 使用 Transformer 处理图像数据

Transformer 非常擅长处理时间序列数据,图像在某种程度上也可以视为时间序列。例如,将图像分解成大小为 16 x 16 的小块,如果我们按顺序将这些图像块依次输入模型,那么这些块也具有序列格式。这与卷积神经网络非常相似,在卷积神经网络 (Convolutional Neural Network, CNN) 中,我们也将图像视为多个小块,并在块上应用卷积核(即创建一个卷积核并在图像上移动)。Transformer 会在在此基础上,增加一个基于全连接层的嵌入 (embedding) 层,这将使得每个块的大小不再是 16 x 16,而是该图像部分的密集表示,此外,还需要添加位置嵌入 (positional embedding)。
这些模型也可以仅包含编码器。例如,可以在每个操作的开头添加一个额外的词元,以创建整个图像的表示。在分类过程中,我们可以使用该词元将整个图像分类为给定的类别。ViT 架构如下图所示:

ViT架构

架构的其余部分与 Transformer 编码器块相同。ViT 架构的主要思想是分块并在图像块上应用位置嵌入。

2. ViT 关键组件

ViT 的成功依赖于几个精心设计的核心组件,这些组件共同实现了将 Transformer 架构有效应用于图像数据的创新方法。接下来,我们将深入剖析每个关键组件的设计原理和实现细节。

2.1 图像分块

Transformer 原本是为序列数据设计的,而图像是 2D 结构,图像分块 (Image Patching) 是将 2D 图像转换为 1D 序列的最直接方法,每个块 (patch) 相当于 NLP 中的一个 token。假设输入图像尺寸为 H × W × C (高度×宽度×通道),patch 大小为 P × P (通常 16×16),那么分块数量为 N=HW/P2N=HW/P^2N=HW/P2。可以通过使用卷积实现高效分块:

self.proj = nn.Conv2d(in_channels, embed_dim,
kernel_size=patch_size,
stride=patch_size)

较大的 patch 会丢失局部细节但计算效率高,较小的 patch 保留更多细节但增加序列长度。

2.2 patch 嵌入

patch 嵌入 (patch Embedding) 将每个 patch 展平并通过线性投影映射到 D 维空间,类似于 NLP 中的词嵌入,包括展平 patch (P×P×C→P2CP×P×C → P²CP×P×CP2C 维向量)和线性投影( P2C→DP²C → DP2CD,通常 D=768),在 PyTorch 中可以使用以下代码实现:

x = x.flatten(2).transpose(1,2)  # [B, N, P²C]
self.proj = nn.Linear(P²C, D)

除此之外,也可以直接使用卷积层实现。

2.3 位置编码

Transformer 本身是排列不变的,因此必须注入空间位置信息,不同于 Transformer 的固定编码,ViT 使用可学习的位置编码 (position Embedding),形状为 N+1 × D (Npatche + 1 个分类 token),在 PyTorch 中可以使用以下代码实现:

self.pos_embed = nn.Parameter(torch.zeros(1, N+1, D))
nn.init.trunc_normal_(self.pos_embed, std=0.02)

2.4 分类 token

分类 token (Class Token) 类似 BERT[CLS] token,用于分类任务,作为整个图像的表征,通过自注意力聚合全局信息,在 PyTorch 中可以使用以下代码添加分类 token

self.cls_token = nn.Parameter(torch.zeros(1, 1, D))

3. 使用 PyTorch 实现 ViT

接下来,下面我们将从零开始实现 ViT 模型,并使用 CIFAR-10 数据集训练模型。ViT 工作流程如下:

  • 输入图像 H×W×C
  • 分割为 NP×P×Cpatch (N=HW/P2N = HW/P²N=HW/P2)
  • 每个 patch 展平为 P2CP²CP2C 维向量
  • 通过线性投影映射到 D 维 (Patch Embedding)
  • 添加位置编码和分类 token
  • 输入 L 层的 Transformer 编码器
  • 使用分类 token 对应的输出进行分类

3.1 模型构建

(1) 首先,导入所需库:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm

(2) 将图像分割为 patch 并线性嵌入到 D 维空间:

class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# 使用卷积层实现patch分割和嵌入
self.proj = nn.Conv2d(
in_channels=in_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
# 输入x形状: [batch_size, in_channels, img_size, img_size]
# 输出形状: [batch_size, n_patches, embed_dim]
x = self.proj(x)  # [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]
x = x.flatten(2)  # [batch_size, embed_dim, n_patches]
x = x.transpose(1, 2)  # [batch_size, n_patches, embed_dim]
return x

(3) 实现位置编码:

class PositionEmbedding(nn.Module):
def __init__(self, n_patches, embed_dim, dropout=0.1):
super().__init__()
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # +1 for class token
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# x形状: [batch_size, n_patches+1, embed_dim]
x = x + self.pos_embed # 添加位置编码
x = self.dropout(x)
return x

(4) 实现多头注意力机制:

class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"
self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # 同时计算Q,K,V
self.attn_dropout = nn.Dropout(dropout)
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
def forward(self, x):
batch_size, n_patches, embed_dim = x.shape
# 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]
qkv = self.qkv(x).reshape(batch_size, n_patches, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_dropout(attn)
# 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]
out = attn @ v
out = out.transpose(1, 2).reshape(batch_size, n_patches, embed_dim)
# 线性投影和dropout
out = self.proj(out)
out = self.proj_dropout(out)
return out

(5) 实现多层感知机 (Multilayer Perceptron, MLP) 模块,自注意力机制后进行非线性特征变换和维度扩展/收缩:

class MLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x

(6) 实现 Transformer 编码器模块 TransformerBlock

class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(
in_features=embed_dim,
hidden_features=embed_dim * mlp_ratio,
out_features=embed_dim,
dropout=dropout
)
def forward(self, x):
# 残差连接和层归一化
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x

(7) 实现 ViT 模型:

class VisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
n_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
dropout=0.1
):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
n_patches = self.patch_embed.n_patches
# 分类token和位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = PositionEmbedding(n_patches, embed_dim, dropout)
# Transformer编码器
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
# 分类头
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, n_classes)
# 初始化权重
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
batch_size = x.shape[0]
# 生成patch嵌入
x = self.patch_embed(x)  # [batch_size, n_patches, embed_dim]
# 添加class token
cls_token = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_token, x], dim=1)  # [batch_size, n_patches+1, embed_dim]
# 添加位置编码
x = self.pos_embed(x)
# 通过Transformer编码器
for block in self.blocks:
x = block(x)
# 分类
x = self.norm(x)
cls_token_final = x[:, 0]  # 只取class token对应的输出
x = self.head(cls_token_final)
return x

3.2 模型训练

(1) 实现模型训练与评估函数:

def train_epoch(model, dataloader, criterion, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in tqdm(dataloader, desc="Training"):
images, labels = images.to(device), labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 统计信息
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / len(dataloader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc
def evaluate(model, dataloader, criterion, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in tqdm(dataloader, desc="Evaluating"):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / len(dataloader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc

(2) 定义模型超参数:

img_size = 224
patch_size = 16
batch_size = 32
num_epochs = 20
learning_rate = 0.0001
num_classes = 10  # CIFAR-10有10个类别
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(3) 加载 CIFAR-10 数据集,并进行数据预处理:

transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

(4) 初始化模型、损失函数和优化器:

model = VisionTransformer(
img_size=img_size,
patch_size=patch_size,
n_classes=num_classes,
embed_dim=768,
depth=6,  # 减少深度以加快训练
num_heads=8
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

(5) 训练模型 20epoch

train_losses, train_accs = [], []
test_losses, test_accs = [], []
for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
# 训练
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
train_losses.append(train_loss)
train_accs.append(train_acc)
# 评估
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
test_losses.append(test_loss)
test_accs.append(test_acc)
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
print()

(6) 绘制模型训练过程中损失值和分类性能变化曲线:

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy')
plt.show()

模型性能
可以看到,从零开始训练的 ViTCIFAR-10 数据集上的准确率大约在 67% 左右,在小规模数据集上从头训练时,ViT 的表现通常不如 CNN,这是由于ViT 的核心是全局自注意力机制,它需要足够多的数据来学习长距离依赖关系,在小规模数据集(如 CIFAR-10,仅 5 万张 32×32 图像)上,ViT 容易过拟合,无法有效学习有意义的特征映射。而使用在 ImageNet 上预训练的 ViT 进行微调,在 CIFAR-10 上可达到 98.5% 的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/983314.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北京知名家事律所排名:专注婚姻家庭法律服务机构推荐

随着社会经济的发展,婚姻家庭领域的法律需求日益多元化,涉及离婚、财产分割、继承等复杂问题时,选择专业的家事律所成为许多当事人的重要考量。本文结合北京地区家事法律服务机构的专业能力、服务口碑等因素,整理了…

东城区离婚律师事务所推荐:本地婚姻家事法律服务机构参考

东城区作为北京核心区域,婚姻家事法律需求近年来呈现多样化趋势,涉及财产分割、子女抚养权、情感调解等多方面问题。选择专业的律师事务所提供支持,有助于在法律框架内妥善处理纠纷,维护当事人合法权益。以下结合机…

2025年市面上耐用的乳胶床垫厂商推荐几家

在选择乳胶床垫时,了解各个厂家和品牌的特点是至关重要的。这些厂家在生产过程中会严格把控材料质量,以确保床垫的耐用性及舒适性。选择知名的实力厂家,通常能获得更可靠的产品体验。此外,不同厂家在售后服务上的表…

朝阳区离婚律师事务所推荐:区域内专业机构参考

在社会发展过程中,婚姻家事领域的法律咨询需求逐渐受到关注。朝阳区作为北京的核心区域之一,聚集了多家专注于婚姻家事法律服务的专业机构。为帮助有相关需求的人士了解区域内的服务资源,以下从团队配置、业务经验等…

2025 年 12 月红木家具权威推荐榜:匠心实木与雅致软装,甄选传世家居臻品

2025 年 12 月红木家具权威推荐榜:匠心实木与雅致软装,甄选传世家居臻品 在追求生活品质与空间美学的当下,红木家具以其独特的材质、精湛的工艺与深厚的文化底蕴,成为高端家居市场不可或缺的组成部分。它不仅是一件…

朝阳区婚姻律师事务所推荐:聚焦家事法律服务的专业参考

在家庭关系的维护与纠纷解决中,婚姻家事法律服务的专业性与可靠性备受关注。朝阳区作为法律服务资源较为集中的区域,众多律师事务所在婚姻家事领域积累了丰富的实践经验,为有需求的当事人提供专业支持。一、推荐榜单…

北京十佳婚姻家事律师事务所综合实力解析

婚姻家事法律事务涉及个人情感与财产安全,选择专业的律师事务所至关重要。北京作为法律服务资源集中地,涌现出多家专注于婚姻家事领域的机构,为当事人提供从咨询到诉讼的全方位支持,其专业能力与服务质量成为公众关…

海淀区离婚律师事务所推荐:聚焦婚姻家事法律服务的机构参考

在海淀区,面对婚姻家事相关的法律需求时,选择专业的法律服务机构是许多人关注的重点。随着社会对婚姻家庭法律问题的重视,专注于离婚、财产分割、继承等领域的律师事务所逐渐成为关注焦点,其专业能力和服务质量直接…

海淀区婚姻律师事务所推荐:专注家事法律服务的机构盘点

在社会生活中,婚姻家事法律问题涉及个人情感与财产权益,选择专业的法律服务机构至关重要。海淀区作为北京的核心区域,聚集了众多专注于婚姻家事领域的律师事务所,为有需求的群体提供专业支持。以下结合服务特点与行…

SQLBot 达梦数据库访问配置手册

介绍 先快速了解一下 SQLBot。 SQLBot 是一款由飞致云 DataEase 开源团队出品、基于大语言模型(LLM)和 RAG(检索增强生成)技术的智能问数系统。它的核心价值在于,用户可以通过自然语言的方式直接向数据库“提问”…

2025年建筑加固技术权威推荐榜:碳纤维加固、粘钢加固,专业施工与持久安全双重保障

2025年建筑加固技术权威推荐榜:碳纤维加固、粘钢加固,专业施工与持久安全双重保障 随着我国城市化进程进入存量提质阶段,大量既有建筑因设计标准提升、功能改造、自然灾害或材料老化等原因,面临结构安全与性能提升…

西城区离婚律师事务所推荐:婚姻家事法律服务机构盘点

在现代社会,婚姻家事法律问题涉及情感、财产、子女等多方面复杂因素,选择专业的法律服务机构至关重要。西城区作为北京核心区域,聚集了多家专注于婚姻家事领域的律师事务所,它们凭借专业的法律知识和实践经验,为有…

上海十大留学中介

上海十大留学中介一、如何找留学中介作为一名拥有12年经验的国际教育规划师,我每天都会收到大量关于上海留学中介的咨询。家长们最常纠结的问题包括:上海留学中介哪家更靠谱?服务质量和专业度到底怎么看?网上众说纷…

2025年12月武汉废旧金属回收厂家权威推荐榜:不锈钢/钛钢,模具钢,废铁/废铜/废铝/铝合金/旧电缆/废旧物资回收与厂房拆除实力解析

2025年12月武汉废旧金属回收厂家权威推荐榜:不锈钢/钛钢,模具钢,废铁/废铜/废铝/铝合金/旧电缆/废旧物资回收与厂房拆除实力解析 随着全球循环经济战略的深入实施与“双碳”目标的持续推进,再生资源回收利用产业已从…

上海十大留学机构排名

上海十大留学机构排名一、上海留学机构怎么选?这些疑问你有吗?作为从事国际教育规划工作超过十年的专业人士,我经常遇到上海学生和家长咨询留学机构的选择问题。在准备这篇文章时,我参考了2025年最新发布的《中国留…

最高法--当事人基于合同解除的法律规定,可选择将合同中已约定“抵销”/抵消(即冲抵)的债务恢复原状

最高法--当事人基于合同解除的法律规定,可选择将合同中已约定“抵销”/抵消(即冲抵)的债务恢复原状2025-12-01 18:43 wwx的个人博客 阅读(0) 评论(0) 收藏 举报1. (2021)最高法民终675号 北京中集宏达房地产…

20232310 2025-2026-1 《网络与系统攻防技术》实验八实验报告

1.实验内容及要求 (1)Web前端HTML 能正常安装、启停Apache。理解HTML,理解表单,理解GET与POST方法,编写一个含有表单的HTML。 (2)Web前端javascipt 理解JavaScript的基本功能,理解DOM。 在(1)的基础上,编写Java…

香港比较靠谱的留学中介

香港比较靠谱的留学中介一、香港留学中介怎么选?这五大问题帮你避坑作为从事国际教育规划师工作已有十年的专业人士,我经常被学生和家长问及如何挑选香港的留学中介。今天,基于2025年11月20日的最新行业数据,我来为…

香港比较好的留学机构

香港比较好的留学机构一、如何挑选香港留学机构?五大疑问帮你理清思路作为一名从事国际教育规划工作超过15年的专业人士,我经常遇到学生和家长提出这样的困惑:香港留学机构究竟该怎么选?哪家中介更适合我的背景?申…

北京陪诊机构排名揭晓 守嘉陪诊以专业实力领跑行业

在“健康中国2030”规划纲要深入实施的背景下,陪诊服务作为优化医疗服务供给、改善群众就医体验的重要民生业态,其发展质量备受社会关注。为精准呈现北京地区陪诊机构的服务水平,为公众就医选择提供权威依据,近日,…