注意力蒸馏技术

文章目录

  • 摘要
  • abstract
  • 论文摘要
  • 简介
  • 方法
    • 预备知识
    • 注意力蒸馏损失
    • 注意力引导采样
  • 实验
  • 结论
  • 总结
  • 参考文献

摘要

本周阅读了一篇25年二月份发表于CVPR 的论文《Attention Distillation: A Unified Approach to Visual Characteristics Transfer》,论文开发了Attention Distillation引导采样,这是一种改进的分类器引导方法,将注意力蒸馏损失整合到去噪过程中,大大加快了合成速度,并支持广泛的视觉特征迁移和合成应用。

abstract

This week I read a paper published in CVPR in February, "Attention Distillation: A Unified Approach to Visual Characteristics Transfer, this paper develops the Attention Distillation guided sampling, which is an improved classifier guided method to integrate the attention distillation loss into the denoising process. It greatly speeds up synthesis and supports a wide range of visual feature migration and synthesis applications.

下图中是给定参考图,文生图的示例:
在这里插入图片描述

论文摘要

最近扩散模型方面的进展显示了对图像风格和语义的内在理解。论文提出了一种新颖的注意力蒸馏损失,通过在潜在空间中反向传播来优化合成图像,同时改进了一个分类器引导,它将注意力蒸馏损失集成到去噪采样过程中,进一步加速合成过程。
在这里插入图片描述

简介

论文解决问题: 现有生成扩散模型在图像风格和语义理解方面虽然有进展,但在将参考图像的视觉特征转移到生成图像中时,使用即插即用注意力特征的方法存在局限性。

传统的方法通常将纹理定义为重复的局部模式,并通过从源图像中复制局部补丁来合成新的纹理。通常归结为以下三个原因导致的局限性:

  1. 域差距:当两幅图像存在显著差异时,目标Q(合成图像的查询)与参考图像的K,V之间的相似性较低且不可靠,导致错误的聚合结果(AdaIN和注意力能缓解这个问题)
  2. 误差积累:虽然扩散模型中的迭代采样过程可以改善目标Q和参考图中的K,V之间的巨大差异,但误差也可能积累。来自不同扩散模型层的特征集中于不同的信息,如语义和几何。不正确的匹配将会错误传播到马尔科夫链的后续层,并降低最终图像质量。
  3. 框架限制:在去噪网络的剩余分支内实现自注意力机制,参考图像中的自注意力特征可能对目标图像有潜在的影响,降低了合成的效力。
    为了解决上述局限性,本篇论文中引入一种新的注意力蒸馏损失AD loss,在此基础上,通过反向传播直接更新合成的图像。

提出方案: 首先,提出了一种新颖的注意力蒸馏损失,用于在理想和当前风格化结果之间计算损失,并在隐空间中通过反向传播优化合成图像。其次,开发了一种改进的分类器引导方法,即注意力蒸馏引导采样,将注意力蒸馏损失整合到去噪采样过程中。

方法

预备知识

隐空间扩散模型(LDM),如Stable Diffusion,由于其对复杂数据分布的强大建模能力,在图像生成方面达到了最先进的性能。在LDM中,首先使用预训练的VAE 将图像x压缩到一个学习到的隐空间中。随后,基于UNet的去噪网络被训练用于在扩散过程中预测噪声,通过最小化预测噪声与实际添加噪声之间的均方误差来实现。
L L D M = E z ∼ E ( x ) , y , ϵ ∼ N ( 0 , 1 ) , t [ ∥ ϵ θ ( z t , t , y ) − ϵ ∥ 2 2 ] \mathcal{L}_{\mathrm{LDM}}=\mathbb{E}_{z\sim\mathcal{E}(x),y,\epsilon\sim\mathcal{N}(0,1),t}\left[\|\epsilon_\theta(z_t,t,y)-\epsilon\|_2^2\right] LLDM=EzE(x),y,ϵN(0,1),t[ϵθ(zt,t,y)ϵ22]
其中 y 表示条件, 表示时间步长。去噪 UNet 通常由一系列卷积块和自注意力/交叉注力模块组成,所有这些都集成在残差架构的预测分支中。
KV注入在图像编辑、风格迁移和纹理合成中被广泛使用。它建立在自注意力机制之上,并将扩散模型中的自注意力特征用作即插即用的属性。自注意力机制的公式为:
S e l f − A t t n ( Q , K , V ) = s o f t m a x ( Q K T d ) V \mathrm{Self-Attn}(Q,K,V)=\mathrm{softmax}(\frac{QK^{T}}{\sqrt{d}})V SelfAttn(Q,K,V)=softmax(d QKT)V
在注意力机制的核心,是基于查询Q和键K之间的相似性计算权重矩阵,该矩阵用于对值V进行加权聚合。KV注入通过在不同的合成分支之间复制或共享KV特征来扩展这一机制。其关键假设是KV特征代表图像的视觉外观。在采样过程中,将合成分支中的KV特征替换为示例的相应时间步长的KV特征,可以实现从源图像到合成目标的外观转移。

注意力蒸馏损失

尽管KV注入取得了显著的效果,但由于残差机制的影响,它在保留参考的风格或纹理细节方面表现不足;例如,下图(a)中。KV注入仅作用于残差,这意味着信息流(红色箭头)随后受到恒等连接的影响,导致信息传递不完整。因此,采样输出无法完全再现所需的视觉细节。
在这里插入图片描述
本论文通过在自注意力机制中重新聚合特征来提取视觉元素。利用预训练的T2I扩散模型SD的UNet,从自注意力模块中提取图像特征。
在这里插入图片描述
上图中,首先根据目标分支的Q,从参考分支重新聚合KV特征(Ks和Vs)的视觉信息,这与KV注入相同。
将此注意力输出视为理想的风格化。然后,我们计算目标分支的注意力输出,并计算相对于理想注意力输出的L1损失,这定义了AD损失:
L A D = ∥ S e l f − A t t n ( Q , K , V ) − S e l f − A t t n ( Q , K s , V s ) ∥ 1 \mathcal{L}_{\mathrm{AD}}=\|\mathrm{Self-Attn}(Q,K,V)-\mathrm{Self-Attn}(Q,K_{s},V_{s})\|_{1} LAD=SelfAttn(Q,K,V)SelfAttn(Q,Ks,Vs)1
可以使用提出的AD损失通过梯度下降来优化随机隐空间噪声,从而在输出中实现生动的纹理或风格再现;例如,参见上图(b)。这归因于优化中的反向传播,它不仅允许信息在(残差)自注意力模块中流动,还通过恒等连接流动。通过持续优化,Q和Ks之间的差距逐渐缩小,使得注意力越来越准确,最终特征被正确聚合以产生所需的视觉细节。

注意力引导采样

将注意力蒸馏损失以改进的分类器引导方式纳入扩散模型的采样过程中。
分类器引导在去噪过程中改变去噪方向,从而生成来自p(zt|c)的样本,其公式可以表示为:
ϵ ^ θ = ϵ θ ( z t , t , y ) − α σ t ∇ z t log ⁡ p ( c ∣ z t ) \hat{\epsilon}_\theta=\epsilon_\theta(z_t,t,y)-\alpha\sigma_t\nabla_{z_t}\log p(c|z_t) ϵ^θ=ϵθ(zt,t,y)ασtztlogp(czt)
其中,t是时间步长,y表示提示, ϵ θ \epsilon_\theta ϵθ z t \ {z_t}  zt分别指去噪网络和LDM中的隐空间变量。 α \alpha α控制引导强度。使用基于注意力蒸馏损失的能量函数来引导扩散采样过程。

实验

由于补丁来源有限,使用传统方法合成超高分辨率纹理非常困难。在此,将注意力蒸馏引导的采样应用于MultiDiffusion模型,使纹理扩展到任意分辨率。尽管SD-1.5是在尺寸为512×512的图像上训练的,但令人惊讶的是,当结合注意力蒸馏时,它在大尺寸纹理合成中表现出了强大的能力。下图展示了将纹理扩展到512×1536的尺寸与GCD和GPDM的比较。
在这里插入图片描述
损失函数代码

def ad_loss(q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
):loss = 0attn_mask = Nonefor q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):if source_mask is not None and target_mask is not None:w = h = int(np.sqrt(q.shape[2]))mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)attn_mask=attn_mask.to(q.device)target_out = F.scaled_dot_product_attention(q * scale,torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),attn_mask=attn_mask)loss += loss_fn(self_out, target_out.detach())return lossdef q_loss(q_list, qc_list):loss = 0for q, qc in zip(q_list, qc_list):loss += loss_fn(q, qc.detach())return loss# weight = 200
def qk_loss(q_list, k_list, qc_list, kc_list):loss = 0for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):scale_factor = 1 / math.sqrt(q.size(-1))self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)loss += loss_fn(self_map, target_map.detach())return loss# weight = 1
def qkv_loss(q_list, k_list, vc_list, c_out_list):loss = 0for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):self_out = F.scaled_dot_product_attention(q, k, vc)loss += loss_fn(self_out, target_out.detach())return loss

下面这段代码主要通过自适应特征提取和优化,将内容图像的潜变量 (latents) 调整为具有风格图像特征的潜变量,实现风格迁移(Style Transfer)或风格控制 (Style-Adaptive Denoising, AD)。
1.使用了一种基于 AdaIN (Adaptive Instance Normalization) 的方法对 latents 进行风格调整:

if self.adain:noise = torch.randn_like(self.style_latent)style_latent = self.scheduler.add_noise(self.style_latent, noise, t)latents = utils.adain(latents, style_latent)

2.提取风格和内容特征:

qs_list, ks_list, vs_list, s_out_list = self.extract_feature(self.style_latent,t,self.null_embeds_for_style,add_noise=True,
)
if self.content_latent is not None:qc_list, kc_list, vc_list, c_out_list = self.extract_feature(self.content_latent,t,self.null_embeds,add_noise=True,)

3.优化 latents 使其匹配风格和内容特征:

optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
optimizer = self.accelerator.prepare(optimizer)

在 iters 轮优化中,计算损失 (style_loss 和 content_loss),并进行反向传播:

for j in range(iters):style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)if self.content_latent is not None:content_loss = q_loss(q_list, qc_list)loss = style_loss + content_loss * weightself.accelerator.backward(loss)optimizer.step()

结论

这篇论文提出了一种统一的方法来处理各种视觉特征转移任务,包括风格/外观转移、特定风格的图像生成和纹理合成。该方法的关键是一种新颖的注意力蒸馏损失,它计算理想风格化与当前风格化之间的差异,并逐步修改合成。

总结

这篇论文提出了一种基于注意力蒸馏(Attention Distillation, AD)的新方法,用于改进扩散模型在视觉特征迁移任务中的表现。作者引入注意力蒸馏损失(AD Loss),通过反向传播优化合成图像,使其更好地匹配目标风格。此外,论文提出注意力蒸馏引导采样,将AD Loss整合到去噪过程中,加快图像合成速度,并提升细节保真度。实验表明,该方法在风格迁移、特定风格图像生成和纹理合成等任务中均优于现有技术,特别是在高分辨率纹理生成方面表现突出。该方法通过改进查询-键-值(Q-K-V)特征聚合,有效缓解域差距、误差积累和框架限制问题。

参考文献

[1] Attention Distillation: A Unified Approach to Visual Characteristics Transfer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/76146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flutter android端抓包工具

flutter做的android app,使用fiddler抓不了包,现介绍一款能支持flutter的抓包工具Reqable,使用方法如下: 1、下载电脑端安装包 下载地址为【https://reqable.com/zh-CN/download/】 2、还是在上述地址下载 android 端apk&#xf…

PyTorch单机多卡训练(DataParallel)

PyTorch单机多卡训练 nn.DataParallel 是 PyTorch 中用于多GPU并行训练的一个模块,它的主要作用是将一个模型自动拆分到多个GPU上,并行处理输入数据,从而加速训练过程。以下是它的核心功能和工作原理: 1、主要作用 数据并行&am…

PyTorch中的Tensor

PyTorch中的Tensor‌ 是核心数据结构,类似于 NumPy 的多维数组,但具备 GPU 加速和自动求导等深度学习特性。 一、基本概念 ‌核心数据结构‌ Tensor 是存储和操作数据的基础单元,支持标量(0D)、向量(1D&am…

基于Python的图书馆信息管理系统研发

标题:基于Python的图书馆信息管理系统研发 内容:1.摘要 在数字化信息快速发展的背景下,传统图书馆管理方式效率低下,难以满足日益增长的信息管理需求。本研究旨在研发一款基于Python的图书馆信息管理系统,以提高图书馆信息管理的效率和准确性…

RCE复现

1.过滤flag <?php error_reporting(0); if(isset($_GET[c])){$c $_GET[c];if(!preg_match("/flag/i", $c)){eval($c);}}else{highlight_file(__FILE__);代码审计过滤了"flag"关键词&#xff0c;但限制较弱&#xff0c;容易绕过 ?csystem("ls&…

FPGA_YOLO(四) 部署yolo HLS和Verilog 分别干什么

首先,YOLO作为深度学习模型,主要包括卷积层、池化层、全连接层等。其中,卷积层占据了大部分计算量,尤其适合在FPGA上进行并行加速。而像激活函数(如ReLU)和池化层相对简单,可能更容易用HLS实现。FPGA的优势在于并行处理和定制化硬件加速,因此在处理这些计算密集型任务时…

自动化发布工具CI/CD实践Jenkins介绍!

1. 认识Jenkins 1.1 Jenkins是什么&#xff1f; Jenkins 是一个开源的自动化服务器&#xff0c;主要用于持续集成和持续部署&#xff08;CI/CD&#xff09;。 它由Java编写&#xff0c;因此它可以在Windows、Linux和macOS等大多数操作系统上运行。 Jenkins 提供了一个易于使用…

【愚公系列】《高效使用DeepSeek》039-政务工作辅助

🌟【技术大咖愚公搬代码:全栈专家的成长之路,你关注的宝藏博主在这里!】🌟 📣开发者圈持续输出高质量干货的"愚公精神"践行者——全网百万开发者都在追更的顶级技术博主! 👉 江湖人称"愚公搬代码",用七年如一日的精神深耕技术领域,以"…

深度学习篇---模型训练评估参数

文章目录 前言一、Precision&#xff08;精确率&#xff09;1.1定义1.2意义1.3数值接近11.4数值再0.5左右1.5数值接近0 二、Recall&#xff08;召回率&#xff09;2.1定义2.2意义2.3数值接近12.4数值在0.5左右2.5数值接近0 三、Accuracy&#xff08;准确率&#xff09;3.1定义3…

Windows 图形显示驱动开发-WDDM 2.4功能-GPU 半虚拟化(十一)

注册表设置 GPU虚拟化标志 GpuVirtualizationFlags 注册表项用于设置半虚拟化 GPU 的行为。 密钥位于&#xff1a; DWORD HKLM\System\CurrentControlSet\Control\GraphicsDrivers\GpuVirtualizationFlags 定义了以下位&#xff1a; 位描述0x1 ​ 为所有硬件适配器强制设置…

Vue 的 nextTick 是如何实现的?

参考答案&#xff1a; nextTick 的本质将回调函数包装为一个微任务放入到微任务队列&#xff0c;这样浏览器在完成渲染任务后会优先执行微任务。 nextTick 在 Vue2 和 Vue3 里的实现有一些不同&#xff1a; 1. Vue2 为了兼容旧浏览器&#xff0c;会根据不同的环境选择不同包装策…

安卓开发之LiveData与DataBinding

LiveData——生命周期感知 LiveData 是 Android Jetpack 提供的一个生命周期感知的数据持有者类&#xff0c;它可以用于持有数据并在数据发生变化时通知观察者。LiveData 常与 ViewModel 配合使用&#xff0c;帮助简化 UI 层和数据层之间的交互&#xff0c;确保 UI 在合适的生…

TCP协议与wireshark抓包分析

一、tcp协议格式 1. 源端口号 &#xff1a; 发送方使用的端口号 2. 目的端口号 &#xff1a; 接收方使用的端口号 3. 序号: 数据包编号 &#xff0c; tcp 协议为每个数据都设置编号,用于确认是否接收到相应的包 4. 确认序列号 : 使用 tcp 协议接收到数据包&#xff0c…

《HelloGitHub》第 108 期

兴趣是最好的老师&#xff0c;HelloGitHub 让你对开源感兴趣&#xff01; 简介 HelloGitHub 分享 GitHub 上有趣、入门级的开源项目。 github.com/521xueweihan/HelloGitHub 这里有实战项目、入门教程、黑科技、开源书籍、大厂开源项目等&#xff0c;涵盖多种编程语言 Python、…

VITA 模型解读,实时交互式多模态大模型的 pioneering 之作

写在前面:实时交互llm 今天回顾一下多模态模型VITA,当时的背景是OpenAI 的 GPT-4o 惊艳亮相,然而,当我们将目光投向开源社区时,却发现能与之匹敌的模型寥寥无几。当时开源多模态大模型(MLLM),大多在以下一个或多个方面存在局限: 模态支持不全:大多聚焦于文本和图像,…

VLAN的高级特性

前言&#xff1a; 1&#xff1a;华为VLAN聚合通过逻辑分层设计&#xff0c;将广播域隔离与子网共享结合&#xff0c;既解决了IP地址浪费问题&#xff0c;又实现了灵活的网络管理 2&#xff1a;MUX VLAN&#xff08;Multiplex VLAN&#xff09;提供了一种通过VLAN进行网络资源控…

制作cass高程点块定义——cad c#二次开发——待调试

public class Demo{[CommandMethod("xx")]public void Demo1(){using var tr1 new DBTrans();var doc Application.DocumentManager.MdiActiveDocument; var db doc.Database;var ed doc.Editor;var 圆心 new Point3d(0, 0, 0); var 半径 10.0;using (var tr …

pod几种常用状态

在 Kubernetes 中&#xff0c;Pod 是最小的可部署单元&#xff0c;Pod 的状态反映了其当前的运行状况。以下是几种常见的 Pod 状态&#xff1a; 1. Pending 描述: Pod 已被 Kubernetes API Server 接收并创建&#xff0c;但还没有开始运行在任何节点上。原因: Pod 资源不足&a…

04 单目标定实战示例

看文本文,您将获得以下技能: 1:使用opencv进行相机单目标定实战 2:标定结果参数含义和数值分析 3:Python绘制各标定板姿态,查看图像采集多样性 4:如果相机画幅旋转90,标定输入参数该如何设置? 5:图像尺寸缩放,标定结果输出有何影响? 6:单目标定结果应用类别…

DevEco Studio编辑器的使用-代码code Linter检查

Code Linter代码检查 Code Linter针对ArkTS/TS代码进行最佳实践/编程规范方面的检查。检查规则支持配置&#xff0c;配置方式请参考配置代码检查规则。 开发者可根据扫描结果中告警提示手工修复代码缺陷&#xff0c;或者执行一键式自动修复&#xff0c;在代码开发阶段&#x…