面向图像分类的自监督/对比学习辅助的知识蒸馏-类别对比蒸馏(Category Contrastive Distillation, CCD) - 详解

news/2026/1/19 9:01:11/文章来源:https://www.cnblogs.com/ljbguanli/p/19499983

面向图像分类的自监督/对比学习辅助的知识蒸馏-类别对比蒸馏(Category Contrastive Distillation, CCD) - 详解

1 源码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

class CategoryContrastiveDistillation(nn.Module):
"""
简化的类别对比蒸馏(CCD)损失模块。
参考自:Chen et al., Category contrastive distillation with self-supervised classification[citation:1]
和 StAlK 中利用均值教师进行特征对齐的思想[citation:6]。
"""
def __init__(self, num_classes, feat_dim, temperature=0.1, momentum=0.999):
super().__init__()
self.num_classes = num_classes
self.temperature = temperature
self.momentum = momentum

# 注册缓冲区,用于存储教师和学生的类别原型(记忆库)
self.register_buffer("teacher_prototype", torch.zeros(num_classes, feat_dim))
self.register_buffer("student_prototype", torch.zeros(num_classes, feat_dim))

# 初始化原型为随机值(实践中可用第一批数据初始化)
nn.init.normal_(self.teacher_prototype, mean=0, std=0.01)
nn.init.normal_(self.student_prototype, mean=0, std=0.01)

@torch.no_grad()
def _update_prototype(self, features, labels, prototype_bank):
"""动量更新类别记忆库"""
for idx in range(self.num_classes):
mask = (labels == idx)
if mask.any():
# 计算当前批次中该类所有样本特征的均值
class_feat_mean = features[mask].mean(dim=0)
# 动量更新:新原型 = momentum * 旧原型 + (1 - momentum) * 当前均值
prototype_bank[idx] = self.momentum * prototype_bank[idx] + (1 - self.momentum) * class_feat_mean
# 可选:对原型进行L2归一化,方便计算余弦相似度
prototype_bank[idx] = F.normalize(prototype_bank[idx].unsqueeze(0), dim=1).squeeze(0)

def forward(self, student_feat, teacher_feat, labels):
"""
计算类别对比蒸馏损失。
Args:
student_feat: 学生模型特征,形状 [batch_size, feat_dim]
teacher_feat: 教师模型特征,形状 [batch_size, feat_dim]
labels: 样本真实标签,形状 [batch_size]
Returns:
ccd_loss: 类别对比蒸馏损失
"""
batch_size = student_feat.shape[0]

# 1. 动量更新教师和学生的类别原型记忆库
self._update_prototype(teacher_feat.detach(), labels, self.teacher_prototype)
self._update_prototype(student_feat.detach(), labels, self.student_prototype)

# 2. 计算学生特征与所有教师原型的相似度(作为“软目标”)
# 相似度矩阵: [batch_size, num_classes]
# 这里使用点积相似度,假设特征和原型都已L2归一化
sim_student_to_teacher_proto = torch.mm(F.normalize(student_feat, dim=1),
F.normalize(self.teacher_prototype, dim=1).t())

# 3. 计算教师特征与所有教师原型的相似度(作为“软标签”)
sim_teacher_to_own_proto = torch.mm(F.normalize(teacher_feat.detach(), dim=1),
F.normalize(self.teacher_prototype, dim=1).t())

# 4. 应用温度缩放,将相似度转换为概率分布
student_dist = F.log_softmax(sim_student_to_teacher_proto / self.temperature, dim=1)
teacher_dist = F.softmax(sim_teacher_to_own_proto / self.temperature, dim=1)

# 5. 计算KL散度损失,让学生特征与原型的相似度分布接近教师特征与原型的分布
ccd_loss = F.kl_div(student_dist, teacher_dist, reduction='batchmean') * (self.temperature ** 2)

# 6. (可选) 引入一个辅助的“学生原型-教师原型”对齐损失
# 直接最小化学生原型和教师原型之间的距离,进一步稳定训练
proto_alignment_loss = F.mse_loss(self.student_prototype, self.teacher_prototype.detach())

# 总CCD损失是两项的加权和
total_ccd_loss = ccd_loss + 0.5 * proto_alignment_loss

return total_ccd_loss

# ============ 如何在主训练循环中使用 ============
# 假设已有:student_model, teacher_model, optimizer, dataloader
# num_classes = 100, feature_dim = 256
# ccd_criterion = CategoryContrastiveDistillation(num_classes=100, feat_dim=256, temperature=0.1)

# for images, labels in dataloader:
# images, labels = images.cuda(), labels.cuda()
#
# # 1. 前向传播,获取特征和logits
# # 假设模型返回一个元组 (logits, feature)
# student_logits, student_feat = student_model(images)
# with torch.no_grad(): # 教师模型不计算梯度
# # 教师模型通常使用学生模型的EMA参数[citation:6]
# teacher_logits, teacher_feat = teacher_model(images)
#
# # 2. 计算各项损失
# # a. 标准交叉熵损失
# ce_loss = F.cross_entropy(student_logits, labels)
#
# # b. 传统知识蒸馏损失 (软化标签)
# temp = 4.0
# kd_loss = F.kl_div(F.log_softmax(student_logits / temp, dim=1),
# F.softmax(teacher_logits / temp, dim=1),
# reduction='batchmean') * (temp * temp)
#
# # c. 类别对比蒸馏损失
# ccd_loss = ccd_criterion(student_feat, teacher_feat, labels)
#
# # 3. 组合总损失 (权重需要调参)
# lambda_kd = 0.5
# lambda_ccd = 1.0
# total_loss = ce_loss + lambda_kd * kd_loss + lambda_ccd * ccd_loss
#
# # 4. 反向传播与优化
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()
#
# # 5. (关键) 更新教师模型为学生的EMA
# # tau为EMA动量,例如0.999
# tau = 0.999
# for param_s, param_t in zip(student_model.parameters(), teacher_model.parameters()):
# param_t.data = tau * param_t.data + (1 - tau) * param_s.data

2 流程图与解析

流程关键点解读

  1. 双分支输入:同一张图像经过不同的数据增强(如强增强和弱增强),分别输入学生和教师模型。

  2. 教师模型更新:教师模型的参数通常采用学生模型参数的指数移动平均(EMA)获得,这是一个稳定知识源的关键技巧-6。

  3. 记忆库:教师和学生模型分别维护一个动态的类别记忆库,用于存储和更新每个类别的特征原型(通常使用当前批次的特征滑动平均更新)-1。

  4. 损失函数:总损失是多项损失的加权和,核心包括:

    • 传统知识蒸馏损失(L_kd):让学生模型的 softened logits 去匹配教师模型的。

    • 自监督对比损失(L_ssl):例如,让学生模型对同一图像不同增强视图的特征表示尽可能接近(正样本),而与其他图像的特征表示远离(负样本)。

    • 类别对比蒸馏损失(L_ccd):这是核心创新点。它计算学生特征与其所有教师类别原型的相似度分布,并与one-hot标签或教师预测分布进行对比损失计算,从而让学生特征向正确的教师类别原型靠近,并远离其他类别原型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1182189.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

柚坛工具箱 NT 实战手册:如何高效解决 Android 开发痛点

柚坛工具箱 NT 实战手册:如何高效解决 Android 开发痛点 【免费下载链接】UotanToolboxNT A Modern Toolbox for Android Developers 项目地址: https://gitcode.com/gh_mirrors/uo/UotanToolboxNT 在 Android 开发过程中,设备管理、刷机调试、应…

艾尔登法环存档安全迁移完全指南:5分钟掌握零风险备份技巧

艾尔登法环存档安全迁移完全指南:5分钟掌握零风险备份技巧 【免费下载链接】EldenRingSaveCopier 项目地址: https://gitcode.com/gh_mirrors/el/EldenRingSaveCopier 还在为艾尔登法环存档管理而困扰吗?游戏版本更新导致存档丢失?设…

喜马拉雅音频资源本地化终极指南:打造永不丢失的私人听书馆

喜马拉雅音频资源本地化终极指南:打造永不丢失的私人听书馆 【免费下载链接】xmly-downloader-qt5 喜马拉雅FM专辑下载器. 支持VIP与付费专辑. 使用GoQt5编写(Not Qt Binding). 项目地址: https://gitcode.com/gh_mirrors/xm/xmly-downloader-qt5 想要永久保…

WorkshopDL终极指南:3分钟学会免费下载Steam创意工坊模组

WorkshopDL终极指南:3分钟学会免费下载Steam创意工坊模组 【免费下载链接】WorkshopDL WorkshopDL - The Best Steam Workshop Downloader 项目地址: https://gitcode.com/gh_mirrors/wo/WorkshopDL 还在为无法访问Steam创意工坊而烦恼?WorkshopD…

比较好的盐城网站定制服务怎么联系?2026年专业指南 - 品牌宣传支持者

开篇:盐城网站定制行业背景与市场趋势随着数字化转型浪潮席卷全球,盐城作为江苏省重要的沿海中心城市,其企业对于专业网站定制服务的需求正呈现爆发式增长。2025年数据显示,盐城地区中小企业网站建设渗透率已达78.…

跨平台部署TTS有多简单?Supertonic镜像一键启动教程

跨平台部署TTS有多简单?Supertonic镜像一键启动教程 1. 引言:为什么需要设备端TTS解决方案? 在当前AI语音技术快速发展的背景下,文本转语音(Text-to-Speech, TTS)系统已广泛应用于智能助手、无障碍阅读、…

如何快速掌握国家自然科学基金LaTeX模板:面向科研新手的完整指南

如何快速掌握国家自然科学基金LaTeX模板:面向科研新手的完整指南 【免费下载链接】NSFC-application-template-latex 国家自然科学基金申请书正文(面上项目)LaTeX 模板(非官方) 项目地址: https://gitcode.com/GitHu…

如何将闲置电视盒子改造为专业Linux服务器:Armbian系统完整指南

如何将闲置电视盒子改造为专业Linux服务器:Armbian系统完整指南 【免费下载链接】amlogic-s9xxx-armbian amlogic-s9xxx-armbian: 该项目提供了为Amlogic、Rockchip和Allwinner盒子构建的Armbian系统镜像,支持多种设备,允许用户将安卓TV系统更…

2026年靠谱商品房装修公司排行榜,新测评精选欧式风格商品房装修推荐品牌 - 工业品牌热点

为帮业主高效锁定适配自身需求的商品房装修合作伙伴,避免选型走弯路,我们从设计落地能力(如风格还原度、功能实用性)、施工工艺水准(含标准化流程、质量管控)、全周期服务质量(覆盖前期设计到售后质保)、真实客…

Qwen3-4B-Instruct-2507隐私保护实施方案

Qwen3-4B-Instruct-2507隐私保护实施方案 1. 背景与挑战 随着大语言模型在企业服务、智能客服、内容生成等场景的广泛应用,数据隐私和安全合规问题日益突出。Qwen3-4B-Instruct-2507作为阿里开源的文本生成大模型,在提升通用能力的同时,也面…

AI工程学习路径:纸质与数字资源的最优配置方案

AI工程学习路径:纸质与数字资源的最优配置方案 【免费下载链接】aie-book [WIP] Resources for AI engineers. Also contains supporting materials for the book AI Engineering (Chip Huyen, 2025) 项目地址: https://gitcode.com/GitHub_Trending/ai/aie-book …

Lucy-Edit-Dev:文本指令轻松实现视频精准编辑

Lucy-Edit-Dev:文本指令轻松实现视频精准编辑 【免费下载链接】Lucy-Edit-Dev 项目地址: https://ai.gitcode.com/hf_mirrors/decart-ai/Lucy-Edit-Dev 导语:DecartAI团队发布开源视频编辑模型Lucy-Edit-Dev,首次实现纯文本指令驱动的…

USB通信中HID请求处理流程系统学习

深入理解HID请求处理:从USB枚举到报告交互的完整链路 你有没有遇到过这样的情况? 一个精心设计的自定义HID设备插上电脑后,系统却提示“未知USB设备”;或者报告描述符明明写好了,主机只读取了一半;又或者…

UI-TARS终极使用指南:零基础实现桌面自动化革命

UI-TARS终极使用指南:零基础实现桌面自动化革命 【免费下载链接】UI-TARS 项目地址: https://gitcode.com/GitHub_Trending/ui/UI-TARS 每天面对电脑重复点击相同的按钮、填写格式固定的表格、执行千篇一律的操作流程,你是否曾想过:这…

Midscene.js自动化测试实战:5大核心技术原理深度解析

Midscene.js自动化测试实战:5大核心技术原理深度解析 【免费下载链接】midscene Let AI be your browser operator. 项目地址: https://gitcode.com/GitHub_Trending/mid/midscene 你是否曾经为跨平台自动化测试的复杂性而头疼?Midscene.js作为一…

Qwen3-4B-Instruct-2507性能基准:吞吐量与延迟测试

Qwen3-4B-Instruct-2507性能基准:吞吐量与延迟测试 1. 引言 随着大模型在实际业务场景中的广泛应用,推理服务的性能表现成为决定用户体验和系统效率的关键因素。Qwen3-4B-Instruct-2507作为通义千问系列中面向高效部署场景的轻量级指令模型&#xff0c…

N_m3u8DL-RE完全指南:从零开始掌握流媒体下载

N_m3u8DL-RE完全指南:从零开始掌握流媒体下载 【免费下载链接】N_m3u8DL-RE 跨平台、现代且功能强大的流媒体下载器,支持MPD/M3U8/ISM格式。支持英语、简体中文和繁体中文。 项目地址: https://gitcode.com/GitHub_Trending/nm3/N_m3u8DL-RE 想要…

Qwen2.5-0.5B公共安全:应急问答系统

Qwen2.5-0.5B公共安全:应急问答系统 在公共安全领域,信息响应的及时性与准确性直接关系到应急处置效率。传统人工问答系统受限于人力和知识覆盖范围,难以满足突发场景下的高并发、多语言、结构化输出需求。随着轻量级大模型技术的发展&#…

终极图像差异检测工具odiff:快速发现像素级视觉差异

终极图像差异检测工具odiff:快速发现像素级视觉差异 【免费下载链接】odiff The fastest pixel-by-pixel image visual difference tool in the world. 项目地址: https://gitcode.com/gh_mirrors/od/odiff 在现代软件开发流程中,图像对比和视觉回…

2026年EPS泡沫优质厂家推荐,看哪家产品性价比高? - 工业品牌热点

2026年包装行业持续升级,EPS泡沫制品作为物流运输、电子防护的核心材料,其品质、成本与服务效率直接影响企业供应链稳定性与运营成本。无论是精密电子器件的缓冲防护、生鲜货物的保温运输,还是大宗货物的成本优化,…