Transformer与ViT

news/2025/9/21 16:30:32/文章来源:https://www.cnblogs.com/zyh778/p/19103799

前言:

Transformer 结构非常重要,需要认真学习一遍
李沐老师课程
Transformer 论文
Transformer 代码
Transformer 自测题目
[Transformer 博客](Transformer/BERT/实战 | 冬于的博客 (ifwind.github.io))

一.Transformer :

Transformer基本结构

Transformer 摒弃了循环结构,并完全通过注意力机制完成对源语言序列和目标语言序列全局依赖的建模。在抽取每个单词的上下文特征时,Transformer 通过自注意力机制衡量上下文中每一个单词对当前单词的重要程度。在这个过程当中没有任何的循环单元参与计算。这种高度可并行化的编码过程使得模型的运行变得十分高效。
Transformer 的主要组件包括编码器(Encoder)、解码器(Decoder)和注意力层。其核心是利用多头自注意力机制(Multi-Head Self-Attention),使每个位置的表示不仅依赖于当前位置,还能够直接获取其他位置的表示。
从宏观角度来看,Transformer 的编码器是由多个相同的层叠加而成的,每个层都有两个子层(子层表示为 sublayer)。第⼀个子层是多头自注意力(multi-head self-attention)汇聚;第二个子层是基于位置的前馈网络(positionwise feed-forward network)。主要涉及到如下几个模块:

1. 位置编码:

image.png

本层主要是针对输入的文本序列,将文本中的单词分别转换为其对应的向量表示的方法,即对每一个单词创建一个向量表示。
值得注意的地方是,由于 Transformer 结构中没有循环结构,因此,序列中没有提示模型单词间相互位置和前后关系的信息。在送入编码器端建模其上下文语义之前,需要在嵌入表示层进行词嵌入时加入位置编码信息 Positional Encoding
具体来说,序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。为了得到不同位置对应的编码,Transformer 模型使用不同频率的正余弦函数如下所示:

\[PE_{(pos,2i)}=\sin(pos/10000^{2i/d})^{\cdots}\\PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d}) \]

首先,正余弦函数的范围是在 \([-1,+1]\)导出的位置编码与原词嵌入相加不会使得结果偏离过远而破坏原有单词的语义信息

其中,pos 表示单词所在的位置,\(2i\)\(2i+1\) 表示位置编码向量中的对应维度,\(d\) 则对应位置编码的总维度。依据三角函数的基本性质,可以得知第 \(pos+k\) 个位置的编码是第 \(pos\) 个位置的编码的线性组合,这就意味着位置编码中蕴含着单词之间的距离信息

1.1. 位置编码的线性组合

在位置编码的上下文中,线性组合意味着一个位置的编码可以通过另一个位置编码的线性变换来获得。这通常通过正弦和余弦函数实现,具体如下:

  1. 正弦编码: \(𝑃𝐸(𝑝𝑜𝑠,2𝑖)=sin⁡(𝑝𝑜𝑠/10000^{2𝑖/𝑑_{model}})\) 这里,\(𝑃𝐸(𝑝𝑜𝑠,2𝑖)\)是位置 𝑝𝑜𝑠在第 \(2𝑖\)维的编码。
  2. 余弦编码: \(𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)=cos⁡(𝑝𝑜𝑠/10000^{2𝑖/𝑑_{model}})\) 这里,\(𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)​\) 是位置𝑝𝑜𝑠在第 \(2𝑖+1\) 维的编码。
    线性组合的示例
    假设我们有一个模型的维度 \(d_{\mathrm{model}}=512\),我们想要计算位置 \(pos=3\) 的编码。我们可以使用以下步骤:计算缩放因子 \(\frac1{10000^{2i/512}}\),对于 \(i=0\), \(i=1\), \(i=2\) 等。
    2.计算正弦和余弦编码:
    对于 \(i= 0\) : $$PE_{( 3, 0) }= \sin ( 3/ 10000^{2* 0/ 512}) = \sin ( 3)$$$$PE_(3,1)=\cos(3/10000^{20/512})=\cos(3)$$
    对于 \(i= 1{: }\) $$PE_{( 3, 2) }= \sin ( 3/ 10000^{2
    1/ 512})$$ $$PE_{(3,3)}=\cos(3/10000^{2*1/512})$$
    如果我们想要计算位置 \(pos+k\) 的编码,比如 \(pos+1\) (即位置 4),我们可以通过将位置 3 的编码乘以相应的缩放因子来获得:

\[PE_{( 4, 0) }= \sin ( 4/ 10000^{2* 0/ 512}) = \sin ( 3+ 1) = \sin(3)\cos(1)+\cos(3)\sin(1)$$$$PE_{(4,1)}=\cos(4/10000^{2*0/512})=\cos(3+1)=\cos(3)\cos(1)-\sin(3)\sin(1) \]

  • 相似编码表示近位置:如果 𝑝𝑜𝑠和𝑝𝑜𝑠+𝑘 的编码非常相似,这可能意味着 𝑘 较小,即这两个位置在句子中的距离较近。
  • 不同编码表示远位置:如果 𝑝𝑜𝑠和𝑝𝑜𝑠+𝑘 的编码差异较大,这可能意味着 𝑘 较大,即这两个位置在句子中的距离较远。
    https://zhuanlan.zhihu.com/p/121126531
    https://zhuanlan.zhihu.com/p/105001610

2. 自注意力层:

2.1. 基本概念:

image.png
在 Transformer 中使用的注意力层为自注意层(Self-Attention),本质可以被精炼地概括为“动态的、内容感知的加权求和”。其数学化形式为缩放点积注意力。
具体的,给定由单词语义嵌入及其位置编码叠加得到的输入表示 \(\{x_{i} \in R^{d}\}_{i=1}^{t}\),为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 \(q_{i}(Query)\) \(k_{i}(Key)\) \(v_{i}(Value)\)。在编码输入序列中每一个单词的表示的过程中,这三个元素用于计算上下文单词所对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。

Query、Key、Value 分别表示不同的概念角色:

  • Q (Query - 查询): 代表当前 Token 为了更新自身表示,主动发出的“提问”。
  • K (Key - 键): 代表序列中所有 Token(包括自身)为了被“检索”而展示出的“标签”。
  • V (Value - 值): 代表序列中所有 Token 实际包含的“内容”或信息。
  • \(QK^T\) 就是给每个词(Q)去问‘和其他每个词(K)有多匹配,点积越大,说明这个“键”与当前“查询”越匹配 → 表示这个词对你当前要表达的意思越重要,最后,模型根据这些“匹配程度”决定从哪些词获取更多信息(Value)
  • ✅ 注意:Q 和 K 都从同一个词生成! 这就是“自”注意力(self-attention)——自己跟自己比。
  • \(\frac{...}{\sqrt{d_k}}\): 缩放因子。当键向量的维度 \(d_k\) 较大时,点积结果的方差也会增大,可能导致 softmax 函数梯度爆炸以及收敛效率差。除以\(\sqrt{d_k}\)可以将其方差稳定在 1 附近,是保证深度 Transformer 能够稳定训练的关键技巧
  • softmax(...): 将原始的相似度分数转化为一个概率分布(权重),表示在当前查询下,应该给予每个 Value 多大的关注度。
  • \(...V\): 根据计算出的权重,对所有的 Value 进行加权求和,来聚合希望关注的上下文信息,并最小化不相关信息的干扰。得到最终的输出。这个输出是一个融合了整个序列信息的、针对当前 Token 的全新表示。

\[Z=Attention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V \]

其中 \(Q \in R^{L\times d_{q}}\), \(K \in R^{L\times d_{k}}\), \(V \in R^{L\times d_{v}}\) 分别表示输入序列中的不同单词的 \(q,k,v\) 向量拼接组成的矩阵,\(L\) 表示序列长度, \(Z \in R^{L\times d_{v}}\) 表示自注意力操作的输出。
image.png

2.2. 操作过程:

对于输入的文本的表示矩阵 X,实际上,自注意力机制中涉及到的三个元素:查询 \(q_{i}(Query)\) \(k_{i}(Key)\) \(v_{i}(Value)\),是由同样的输入矩阵 X 线性变换而来的。我们可以简单理解成:\(Q=XW^Q,K=XW^K,V=XW^V\),其可以用图像表示为:

X 矩阵中的每一行都对应输入句子中的一个单词,W 矩阵为权重矩阵。在了解以上内容后计算自注意力。

2.2.1. 创建三个向量:

根据编码器的每个输入向量(在本例中为每个单词的嵌入 \(X_i\))创建三个向量 \(Q_i,K_i,V_i\)。即为每个单词创建一个查询向量、一个关键向量和一个值向量。请注意,这些新向量的维度比嵌入向量小。它们的维度为 64,而嵌入和编码器输入输出向量的维度为 512。
具体操作为:将 x1 与 \(W^Q\) 权重矩阵相乘,得出 q1,即与该词相关的 "查询 "向量。最终,我们为输入句子中的每个单词创建了一个 "查询"、一个 "关键 "和一个 "值 "投影。
image.png

2.2.2. 计算得分:

假设我们要计算本例中第一个单词 "Thinking "的自注意力。我们需要将输入句子中的每个单词与这个单词进行对比。分数决定了我们在对某个位置的单词进行编码时,对输入句子中其他部分的关注程度。
分数的计算方法是将查询向量与我们要评分的单词的关键向量进行点乘。比如,如果我们处理的是 1 号位置单词的自注意力,那么第一个分数就是 \(q_1\)\(k_1\) 的点积。第二个分数是 \(q_1\)\(k_2\)点积
image.png

2.2.3. 处理得分:

将得分按照公式除以放缩因子 \(\sqrt{d_k}\) 以稳定优化,一般 \(d_k\) 是键的维度。,进行缩放后使用 softmax 运算进行归一化处理:
image.png
这个 softmax 分数决定了每个词在这个位置的表达量。显然,这个位置上的单词将拥有最高的 softmax 分数,但有时也可以用于关注与当前单词相关的另一个单词。

2.2.4. 处理选择值向量:

将每个值向量乘以 softmax 分数(准备求和)。这里主要是保持我们想要关注的单词的值不变,而忽略无关的单词(例如得分为 0.001 这样的小数的单词)。
将加权值向量相加。这将产生自注意层在该位置(第一个单词)的输出。
image.png
https://aicarrier.feishu.cn/wiki/SyROwpjaziB2tfkDImhcNzgCnts?from=from_copylink
The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time. (jalammar.github.io)
注意力机制到底在做什么,Q/K/V怎么来的?一文读懂Attention注意力机制-腾讯云开发者社区-腾讯云 (tencent.com)

2.3. 多头自注意力层

单一的注意力机制可能会让模型只学会关注一种“关系模式”。多头注意力(Multi-Head Attention)模块通过将 \(Q, K, V\) 线性投影到多个独立的“表示子空间”,让模型能够并行地学习多种关系。这使得模型可以同时关注不同层面的信息:一个头可能在学习局部纹理关系,另一个头在学习全局的形状轮廓,还有一个头在学习语义上的相关性。

3. 前馈层:

image.png

前馈层接受自注意力子层的输出作为输入,并通过一个带有 Relu 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分重要的影响。

\[FFN(x)=Relu(xW_{1}+b_{1})W_{2}+b_{2} \]

其中 \(W_{1},b_{1},W_{2},b_{2}\) 表示前馈子层的参数。另外,以往的训练发现,增大前馈子层隐状态的维度有利于提升最终翻译结果的质量,因此,前馈子层隐状态的维度一般比自注意力子层要大。

4. 残差连接与归一化:

由 Transformer 结构组成的网络结构通常都是非常庞大。编码器和解码器均由很多层基本的 Transformer 块组 成,每一层当中都包含复杂的非线性映射,这就导致模型的训练比较困难。因此,研究者们在 Transformer 块中进一步引入了残差连接与层归一化技术以进一步提升训练的稳定性。具体来说,残差连接主要是指使用一条直连通道直接将对应子层的输入连接到输出上去(这里可以类比残差神经网络的残差块帮助理解),从而避免由于网络过深在优化过程中潜在的梯度消失问题:

\[x^{l+1}=f(x^l)+x^l \]

其中 \(x^l\) 表示第 \(l\) 层的输入,\(f(\cdot)\) 表示一个映射函数。此外,为了进一步使得每一层的输入输出范围稳定在一个合理的范围内,层归一化技术被进一步引入每个 Transformer 块的当中:

\[LN(x)=\alpha \cdot \frac{x-\mu}{\sigma} + b \]

其中 \(\mu\)\(\sigma\) 分别表示均值和方差,用于将数据平移缩放到均值为 0,方差为 1 的标准分布,\(a\)\(b\) 是可学习的参数。层归一化技术可以有效地缓解优化过程中潜在的不稳定、收敛速度慢等问题。
图解Transformer系列三:Batch Normalization & Layer Normalization (批量&层标准化) - 掘金 (juejin.cn)

5. 解码器结构:

从网络架构来看,编码器的实现相对直接,而解码器的结构则更为复杂。这主要源于解码器在处理序列生成任务时所承担的自回归(Auto-regressive)角色。

5.1. 解码器核心组件

解码器的复杂性主要体现在其独特的注意力机制上,具体包括以下两个关键部分:

5.1.1. 掩码多头注意力 (Masked Multi-Head Attention)

解码器每个 Transformer 块的第一个自注意力子层增加了注意力掩码(Attention Mask)。

  • 机制: 在翻译等生成任务中,解码过程是自回归的。对于任意单词的生成,模型只能观测到该单词之前的序列信息。
  • 目的: 掩码的核心作用是遮蔽当前位置之后的文本信息,防止模型在训练时“看到”未来的答案,从而确保模型能够得到有效训练。

5.1.2. 交叉注意力 (Cross-Attention)

解码器在掩码多头注意力层之后,额外增加了一个用于实现交叉注意力的多头注意力模块。

  • 输入: 该模块接收两部分输入:
    • 查询 (Query): 来自解码器前一个掩码注意力层的输出。
    • 键 (Key) 与值 (Value): 来自编码器最终的输出表示。
  • 目的: 允许解码器在生成目标语言序列的每一步时,都能够“关注”到源语言序列的全部信息,从而生成与源文本内容相关且准确的目标序列。

5.1.3. 解码器输入 (Decoder Input)

与编码器类似,解码器也需要将输入序列(即目前已生成的目标语言序列)转换为嵌入向量,并添加位置编码。

  • 目标序列嵌入: 解码器的输入是目标语言序列的词嵌入。在训练阶段,这是真实的(或部分真实的)目标序列;在推理(生成)阶段,这是模型已经生成的部分序列。
  • 位置编码: 为了保留序列中词语的顺序信息,与编码器一样,位置编码也被添加到目标序列的嵌入向量中。
  • “向右偏移”的输入: 特别地,在训练和生成过程中,解码器的输入是“向右偏移”(shifted right)的。这意味着在预测当前词时,模型只能看到它之前的词,而不能看到当前词或之后的词。这通常通过在序列开头添加一个特殊的起始符(如 <s>)并在训练时将目标序列向右移动一位来实现,例如,如果目标是 "I am a student",解码器在预测 "I" 时,输入是 <s>;在预测 "am" 时,输入是 <s> I,以此类推。

5.2. 工作流程总结

  1. 编码阶段: 待翻译的源语言文本首先通过编码器,经过层层抽象,最终为每个源语言单词生成上下文相关的表示。
  2. 解码阶段: 解码器以自回归方式生成目标语言文本。具体而言:
    • 初始输入: 在生成第一个目标语言单词时,解码器接收一个特殊的起始符(如 <s>)作为输入。
    • 迭代生成: 在每个时间步 \(t\),模型会将前 \(t-1\) 个时刻已生成的目标语言单词(包括起始符)作为新的输入序列。
    • 内部处理: 这些输入首先经过掩码多头注意力层处理,确保模型只能关注到自身序列中当前位置之前的信息。
    • 结合源信息: 随后,其输出与编码器输出的源语言表示进行交叉注意力计算,从而允许解码器“关注”到源文本的全部上下文信息。
    • 预测输出: 经过多层解码器块处理后,最终的输出会送入一个线性层Softmax 层,预测当前时刻的目标语言单词。
    • 终止条件: 这个过程重复进行,直到模型生成一个结束符(如 <eos>)或达到预设的最大序列长度。

6. 输出层

解码器堆栈的输出是一个浮点数向量,它代表了模型对下一个词的预测。为了将这个向量转换为实际的单词,需要经过两个最终的层:

  1. 线性层 (Linear Layer):
    • 这是一个简单的全连接神经网络,它将解码器堆栈产生的向量,投影到一个非常非常大的向量上,这个向量被称为“logits 向量”
    • 为了更好地理解,我们假设模型从训练数据中学习了 10,000 个独立的单词(即模型的“输出词汇表”)。那么,这个 logits 向量的维度就会是 10,000。向量中的每一个单元格都对应着一个独立单词的得分(score)。通过这种方式,我们就解读了线性层之后模型的输出。
  2. Softmax 层 (Softmax Layer):
    • Softmax 层会将线性层输出的这些得分转换成概率。
    • 转换后的概率值都是正数,并且加起来总和为 1.0。每个概率值代表了词汇表中对应单词作为下一个词的可能性。
    • 最终,模型会选择概率最高的那个单元格,并输出其对应的单词,作为这个时间步的最终结果。
      image.png

二. ViT

ViT 的整体架构可以分为三大模块,核心思想是将图像分割成多个固定大小的小块(patches),并将这些小块视为序列输入到 Transformer Encoder 中。同时借鉴类似 BERT 的结构,ViT 在序列的前面增加了一个用于分类的 CLS token,最后在分类任务上可以使用 CLS token 学习到的语义特征通过一个 MLP Head 进行分类。

ViT 的三大模块包括:

  • Linear Projection of Flattened Patches: Embedding 层,包含 Patch Embedding、Position Embedding 和 CLS token
  • Transformer Encoder: 提取图像特征,由 N 个 Transformer Encoder Block 堆叠而成。
  • MLP Head: 分类头,使用 CLS token 作为输入,输出分类结果。

ViT 的工作原理图如下:
image.png

1. Linear Projection of Flattened Patches

image.png
image.png

1.1. Patch Embedding:

Transformer 本身的输入要求是二维矩阵 [num_token, token_dim],其中 \(num\_token\) 是序列长度,\(token\_dim\) 是序列中每个 token 的向量维度。图像的形状是三维矩阵 [H, W, C],表示长、宽和通道数。为了将图像转换为 Transformer 所需的格式,ViT 首先对图像进行了分块,将输入图像 [H, W, C] 按固定 Patch size [P, P] 划分为 \(H \times W / P^2\) 个不重叠的patches。
其中Patch size 直接决定模型的 token 数量 \(N\),从而影响计算量、内存需求以及表达能力。

Patch Size 类型 优势 劣势
较小的 Patch(如 8×8 或更小) 每个 token 覆盖更小的图像区域,实现更细粒度的特征提取,更容易捕捉边缘、纹理等局部细节。
增强表达能力,能更好地处理复杂图像结构。
- token 数 \(N\) 增多,使自注意力复杂度 \(O(N^2)\) 显著增加。
- 占用更多显存与计算资源,训练慢且消耗大。
较大的 Patch(如 16×16、32×32 或更大) token 数较少,计算更高效,训练资源消耗低。
对全局结构建模更直接,速度快。
- 粗粒度代表图像内容,可能遗漏精细特征(考虑到图像信息中常存在大量冗余,有时亦可选择较大的 Patch Size)。
- 某些任务性能下降,识别小物体、纹理等能力受限。
每个图像块会被展平(Flatten)为一个向量,然后通过一个线性层投影成固定维度的“补丁嵌入(Patch Embeddings)”。这些嵌入向量就相当于 Transformer 处理的“词向量”(Tokens)。
假设图像尺寸为 [12,12,1],使用 [4,4] 的 patch 大小进行划分,总共被划分为 \(12×12/4^2=9\) 个 patches,每个 patch 的形状为 [4, 4, 1]。得到分块操作后的 patches,每个 patch 在进入线性映射层之前会被展平(flatten),由三维向量的 patch 转换为一维向量 token,即形状从 [4, 4, 1] 变为 [16]。在实际代码中,这一步通常通过一个卷积层实现:使用 4×4 卷积核、步长 4 输出通道数 24。输入 [12, 12, 1] 经过该卷积层后,输出为 [3, 3,24],然后将其展平为顺序的 tokens,得到形状为 [9, 24] 的 token 序列,作为 Transformer 的输入。

1.2. Class Token

CLS Token (Classification Token): 在所有补丁嵌入序列的前面,我们通常会额外添加一个可学习的特殊 TOKEN,即 CLS Token。它不对应图像的任何部分,但其在 Transformer Encoder 的最后一层输出的对应向量,通常被用来代表整个图像的全局特征,用于下游的分类任务。
ViT 参考 BERT,在嵌入后的 9 个 patch token 前添加一个可训练的 [CLS] token,其维度与 patch token 相同(24 维),拼接后形成长度为 10 的 token 序列([10, 24])。该序列通过 Transformer Encoder 的双向自注意力机制交互融合,最终 [CLS] token 的输出聚合了整个序列的信息,可以代表整个序列的特征,也就是整个图像的特征,该表征被送入一个线性分类器,实现端到端的图像分类预测。

1.3. Positional Embeddings

由于 Transformer 不具备处理序列顺序的能力,ViT需要为每个补丁嵌入添加位置编码,以告知模型这些补丁在原始图像中的相对位置。
Vision Transformer不使用任何数学公式(如 sin/cos),而是直接将每个位置(0, 1, 2, ..., N)映射为一个可学习的向量。即采用可训练的 1D 绝对位置编码,与 Patch embedding 维度一致并逐元素相加,以此为 Transformer 提供空间位置先验。在本例子中,由于输入图像为 12×12,划分为 3×3=94×4 的 patch,加上一个 [CLS] token,序列总长度为 10,嵌入维度为 24,因此位置编码被设计为 shape=[10, 24] 的可学习参数矩阵。该设计成功保留了 patch 在原始图像中的空间顺序(沿行优先扫描),且无需复杂插值或 2D 编码,轻量高效,适合小规模视觉任务。
位置编码在模型初始化时随机生成,并在整个训练过程中通过反向传播学习。它被直接叠加(element-wise addition)到 token embedding 上:

\[ Final Input=Patch Embedding+Position Encoding \]

这种做法使得 Transformer Encoder 的自注意力机制能够感知不同 patch 在原始图像中的空间相对位置(如左上、中心、右下等),从而避免了图像结构信息在序列化过程中的丢失。
注意:位置编码是 Transformer 模型的一个参数(Parameter),在初始时给出的是随机值,它和注意力权重、MLP 权重、Embedding 矩阵一样,在训练过程中共同更新,目标是最小化任务损失。

编码方式 ViT 为什么不采用?
Sin/Cos 位置编码(如原始 Transformer) 图像 patch 的空间关系不是周期性函数,无法用 sin/cos 表达。直接学习更灵活、效果更好
2D 相对位置编码(如 Swin Transformer) 2D 位置编码更复杂,需要设计 2D 查询机制,增加计算负担,不适合作为 baseline。
无位置编码 实验结果显著变差 → 说明空间信息至关重要。

2. Transformer Encoder:

image.png
Transformer Encoder 主要用于特征提取,由多个 Encoder Block 堆叠而成,在 ViT-Base 中是 12 个 Block。每个 Encoder Block 包含两个部分:LayerNorm + MHA + 残差,以及 LayerNorm + MLP + 残差

  1. Layer Norm: 对每个 token 进行归一化处理,保持数值稳定。
  2. 多头自注意力(Multi-Head Self-Attention,MHA)。
    • 自注意力机制: 这是 ViT 的核心。它允许模型计算图像中任意两个补丁之间的相互依赖关系,无论它们在图像中的物理距离有多远。通过查询(Query)、键(Key)、值(Value)的计算,实现全局信息的聚合。详见 [[#2. 自注意力层:]]
    • 前馈网络: 对自注意力输出的每个 token 进行独立处理,增加模型的非线性表达能力。
  3. 残差连接: 残差连接提供梯度回传的高速通道,确保即使在很深的网络中,梯度也能良好地从监督信号回传到输入,从而避免优化难度。
  4. MLP Block: 由两个全连接层、GELU 激活函数和 Dropout 组成。第一个全连接层将输入维度扩大;第二个全连接层将其还原

3. MLP Head

image.png

Transformer Encoder 输出一个包含 [CLS] token 和所有 patch tokens 的序列表示,其形状为 \(B×L×D\),其中 \(L=10\)(1个 [CLS] + 9个 patch)、\(D=24\)。MLP Head 用其作为输入,输出分类的 logits,映射到最终的分类结果。MLP Head 的具体结构如下:

  1. 训练 ImageNet21K 时: Linear + tanh 激活函数 + Linear。
  2. 迁移到 ImageNet1K 或其他数据集时: 仅使用一个 Linear 层。

3.1. 基于 [CLS] Token 的分类头

采用标准 ViT 设计,仅取序列中第 0 位的 [CLS] token 作为全局图像表示:

\[\mathbf{v}_{\text{cls}} = \text{Encoder}(\mathbf{X})[0] \in \mathbb{R}^{24} \]

随后通过一个单层线性分类器(Linear Layer)映射至类别空间:

\[\mathbf{z} = \text{Linear}_{\text{cls}}(\mathbf{v}_{\text{cls}}) \in \mathbb{R}^{C} \]

其中 \(C\) 为类别数,最终输出经 Softmax 得到分类概率。

优势:结构极简、与 ViT 原始设计一致、参数极少(仅 \(24 \times C\))、便于部署。
⚠️ 局限:依赖 [CLS] token 在训练过程中是否有效“聚合”了视觉信息——该过程完全由 Transformer 注意力机制隐式完成。


3.2. 全局平均池化(Global Average Pooling, GAP)

放弃 [CLS] token,对全部 \(L-1 = 9\) 个 patch tokens 的最终输出进行逐维度平均池化

\[\mathbf{v}_{\text{gap}} = \frac{1}{9} \sum_{i=1}^{9} \text{Encoder}(\mathbf{X})[i] \in \mathbb{R}^{24} \]

再通过 LayerNorm + 单层 Linear 完成分类:

\[\mathbf{v}' = \text{LayerNorm}(\mathbf{v}_{\text{gap}}), \quad \mathbf{z} = \text{Linear}_{\text{gap}}(\mathbf{v}') \]

优势

  • 具有平移不变性,对 patch 位置轻微扰动更具鲁棒性;
  • 不依赖 [CLS] 的初始化和注意力建模质量,更具“数据驱动”特性。
  • patch tokens 均值聚合在分类任务上“常胜”于 [CLS] token。
    劣势
    忽略 [CLS] 潜在的高层语义聚合信息,仅依赖 pixel-level 特征分布,可能降低对全局结构的理解能力。

3.2.1. 多头注意力池化(Multi-head Attention Pooling, MAP)

引入一个可学习的查询向量 \(\mathbf{q}_{\text{map}} \in \mathbb{R}^{D}\),通过单头(或单层)多头注意力机制从 patch tokens 中动态聚合信息:

\[\mathbf{v}_{\text{map}} = \text{MultiheadAttn}(\mathbf{q}_{\text{map}}, \{\mathbf{p}_1, ..., \mathbf{p}_9\}, \{\mathbf{p}_1, ..., \mathbf{p}_9\}) \in \mathbb{R}^{24} \]

随后接 LayerNorm + Linear 分类器。

优势

  • 引入自适应注意力机制,能根据输入内容动态聚焦“关键 patch”,理论上优于简单平均。
    劣势
  • 增加一个长度为 \(D=24\) 的可学习参数,并引入 QKV 投影与注意力计算,**参数量与计算开销增加;
  • 训练不稳定,收敛较慢。

3.3. ViT 与传统 CNN 在特征提取上的根本性不同:

  • 感受野特性:
    • CNN: 具有局部感受野。通过多层卷积堆叠逐渐扩大感受野,但捕获远距离依赖关系效率较低,需要很深的网络才能看到全局。它的平移不变性(Translation Invariance)是内置的归纳偏置(Inductive Bias)。
    • ViT: 通过自注意力机制在第一层就能建立图像中所有补丁之间的关系,天然具备全局感受野。它没有内置的局部性和平移不变性偏置,而是通过在海量数据上学习来获得这些能力。
  • 架构范式:
    • CNN: 是一种分层、局部、逐级抽象的架构。特征图尺寸层层缩小,通道数层层增加。
    • ViT: 是一种扁平化序列处理的架构。图像被扁平化为一维序列,然后通过 Transformer 的并行处理进行全局特征提取。
  • 数据需求和可扩展性:
    • CNN: 在中小型数据集上表现卓越,其归纳偏置能够抵抗过拟合。
    • ViT: 对数据量有极高的要求,通常需要大规模预训练(如 JFT-300M, LAION-5B)才能发挥出优势。但在获得足够数据后,ViT 展现出比 CNN 更强的可扩展性和建模复杂关系的能力,特别是在处理需要全局上下文理解的复杂视觉任务时。”
  • 表达能力与泛化性
    • ViT 优势
      • 擅长捕捉全局关联与远程上下文,对遮挡、变形等更鲁棒。
      • 在大规模数据监督或自监督训练下的效果优异。
    • CNN 强项
      • 局部细节提取更精准,适合边缘检测与纹理模式的学习。
      • 适用于密集预测任务

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/908881.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

comfUI背后的技术——VAE - 实践

comfUI背后的技术——VAE - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco&qu…

实用指南:Maven、Spring Boot、Spring Cloud以及它们的相互关系

实用指南:Maven、Spring Boot、Spring Cloud以及它们的相互关系2025-09-21 16:28 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !imp…

【57页PPT】智慧高效的方案智慧医院信息化整体规划设计方案(附下载方式)

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

WordPress开放嵌入自动发现功能中的XSS漏洞分析

本文详细分析了WordPress中通过开放嵌入自动发现功能存在的XSS漏洞,包括postMessage()机制的安全问题、Safari浏览器的特殊行为以及完整的漏洞复现步骤,揭示了广泛使用的平台仍可能存在安全风险。WordPress开放嵌入自…

第二次软工作业

第二次软工作业软件工程第二次作业_个人项目 Github连接: mocheen/se_homework: homework ](https://github.com/mocheen/se_homework) 这个作业属于哪个课程 https://edu.cnblogs.com/campus/gdgy/Class12Grade23Comp…

20250921 模拟赛 T4 题解

Description https://zhengruioi.com/problem/3343?cid=1976 Solution 容易发现区间 LIS 满足四边形不等式,所以最终的答案关于划分段数是凸的。 设 \(d_i=f_i-f_{i-1}\)。那么由于 \(\sum d_i=n\) 且 \(d_i\) 不增,…

1.3 课前问题列表

1.什么样的方法应该用static修饰?不用static修饰的方法往往具有什么特性?Student的getName应该用static修饰吗? 1.通常是工具类方法、单例模式中获取单例对象的方法等应该用static修饰 2.不用static修饰的方法特性:…

NOIP 模拟赛十一

贪心+打表+数据结构+DPA. 倒序贪心即可。点击查看#include <bits/stdc++.h> #define lep(i, a, b) for (int i = a; i <= b; ++i) #define rep(i, a, b) for (int i = a; i >= b; --i) #define il inline …

Proxy 库解析(四)

test一切伟大的行动和思想,都有一个微不足道的开始。 There is a negligible beginning in all great action and thought.

warm-flow 监听器对象获取问题

初次使用warm-flow 实现了 Listener 接口,配置名字和路径也有写对,但监听器一直没启动,查看底层代码Listener listener = (Listener) FrameInvoker.getBean(clazz);在要执行监听器时,一直获取不到对象,很疑惑,打…

Hexo Butterfly 5.4 分页问题 YAML 错误 解决方法总结

Hexo Butterfly 5.4 分页问题 & YAML 错误 解决方法总结 本次问题核心是 “首页分页显示不全(仅 1、2…11)” 与 “hexo clean 报错 YAML 重复键”,最终通过配置文件修正 + 主题模板调整解决,具体步骤如下: 一…

FPGA硬件设计6 ZYNQ外围-HDMI、PCIE、SFP、SATA、FMC - 教程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

js逆向:某Q音乐平台请求数据模拟生成

@目录1. 加密原理2. 参考代码内容仅供学习使用,不能用于商业活动,且不能在该网站高用户访问时频繁访问,以免对对应服务器造成影响。1. 加密原理 该音乐平台加密数据为如下图片这个:所加密的数据data和这篇文章里的…

第十一届中国大学生程序设计竞赛网络预选赛(CCPC Online 2025)

Preface最近因为队友要准备预推免,很久没有一起训练过了;我个人也是把大部分精力都放在科研方面,算是挺久没写代码了 同时因为这场撞了本校预推免的原因,导致学校很多队伍被迫重组,但好在我们队没受影响堪堪凑齐了…

完整教程:数据结构 栈和队列、树

完整教程:数据结构 栈和队列、树pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco&q…

深入解析:【ubuntu】ubuntu中找不到串口设备问题排查

深入解析:【ubuntu】ubuntu中找不到串口设备问题排查pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas&qu…

酵母双杂交技术:高通量筛选的突破与不可忽视的三大局限性

在后基因组时代,解析蛋白质相互作用网络已成为理解生命活动机制、挖掘疾病靶点的核心任务。酵母双杂交技术通过不断革新,已从 “一对一” 的简单互作验证,升级为 “组学水平” 的高通量筛选工具 —— 不仅能覆盖全基…

ubuntu20.04测试cuda

import torch# 1. 检查 PyTorch 版本 print("PyTorch 版本:", torch.__version__) # 应为 2.4.0# 2. 检查 CUDA 是否可用 print("CUDA 可用:", torch.cuda.is_available()) # 应为 True# 3. 检查…

Python lambda

Python lambda 漫思