Supervised Contrastive 损失函数详解

在这里插入图片描述
有什么不对的及时指出,共同学习进步。(●’◡’●)

有监督对比学习将自监督批量对比方法扩展到完全监督设置,能够有效地利用标签信息。属于同一类的点簇在嵌入空间中被拉到一起,同时将来自不同类的样本簇推开。这种损失显示出对自然损坏很稳健,并且对优化器和数据增强等超参数设置更稳定。

有监督对比学习论文的贡献

  1. 提出了对比损失函数一种新的扩展,允许每个锚点都有多个正样本,使对比学习适应完全监督设置。
  2. 该损失为很多数据集的top-1的准确率带来了提升,对自然损坏有稳健性。
  3. 损失函数的梯度鼓励从硬正样本和硬的负样本中学习。(硬的正样本与锚点图像不相似的正样本,硬的负样本就是与锚点图像相似的负样本,都是难以学习的那种)
  4. 对比损失函数不如交叉熵损失函数对超参数敏感。

自监督对比学习损失
在这里插入图片描述
有监督对比学习损失
在这里插入图片描述
文中对交叉熵损失训练,自监督对比损失训练和有监督对比损失训练进行比较
在这里插入图片描述
推理模型中的参数个数始终保持不变,应该是推理的时候就是编码器+分类头都一样。
上图是训练的时候,交叉熵损失不必说。
自监督损失一般采用的是个体判别代理任务,正样本是自身经过数据增强后的图像(一般一个正样本),其他的都是负样本,训练编码器的时候让正样本和锚点图像经过编码器得到的特征尽可能接近,与负样本之间的特征尽可能拉远。
有监督对比学习,有标签信息,正样本除了自身数据增强后的之外还有这个类别中的其他样本(一般这个batch_size中)。
stage1就是训练编码器。
stage2是训练分类头,作者指出不需要训练线性分类器,并且先前的工作已经使用k -最近邻分类或原型分类来评估分类任务上的表示。线性分类器也可以与编码器联合训练,只要不将梯度传播回编码器即可,就是分类头和编码器之间训练要分开。
有监督对比学习损失代码
对比学习对比的是特征,所以损失函数的输入是特征,有监督对比学习损失还要输入标签信息。
损失函数就是模型的输出和标签(这里是mask)之间的差距,输出和标签差距越大,那么loss就越大。
输出这里是编码器的输出就是特征,标签就是类别标签。标签是如何起作用的呢?就是让损失函数区分这个batchsize中的正负样本,属于同一类就是正样本,其他都是负样本。
其中标签mask怎么获得,一个是通过label,另一个直接输入。label是每个数据的类别信息,label.view(1,-1)变成列向量然后再与它的转置进行torch.eq(),得到一个矩阵mask,mask(i,j)如果第i个数据和第j个数据类别相同那么这个位置是True,否则为False,float就变成0,1。后面乘了一个对角线元素为0,其他位置元素为1的矩阵,就是不让每个feature与自身对比。
我们看它self.contrast_mode="one"的时候只是比较feature中第0个特征(也就是平常的第一个特征),那么锚点特征就是所有数据的第0个特征;"all"就是所有的特征都要对比;锚点特征就是所有数据的所有特征。 torch.cat(torch.unbind(features, dim=1), dim=0)把feature按照第1维拆开,然后在第0维上cat,然后比较的feature的形式就是每一个数据的第1个特征|每个数据的第2个特征|…|每个数据的第n个特征,排列,这些特征是排在一起的在一个维度上。锚点特征要么是输入特征组的每个数据的第0个特征要么就是这些比较的特征。(不太理解为什么one的时候比较特征还是所有的)
锚点特征与比较特征的转置相乘,得到的就是batch_size*channel个相似矩阵,每两个数据在这个特征下的相似度。然后这个相似度矩阵要和我们得到的mask进行比较,就是上面的第二个式子。
下面是详细解释。

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_functionimport torch
import torch.nn as nnclass SupConLoss(nn.Module):"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.It also supports the unsupervised contrastive loss in SimCLR"""def __init__(self, temperature=0.07, contrast_mode='all',base_temperature=0.07):super(SupConLoss, self).__init__()self.temperature = temperatureself.contrast_mode = contrast_mode#设置对比的模式有one和all两种,代表对比一个channel还是所有,个人理解self.base_temperature = base_temperature #设置的温度def forward(self, features, labels=None, mask=None):"""Compute loss for model. If both `labels` and `mask` are None,it degenerates to SimCLR unsupervised loss:https://arxiv.org/pdf/2002.05709.pdfArgs:features: hidden vector of shape [bsz, n_views, ...].labels: ground truth of shape [bsz].mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample jhas the same class as sample i. Can be asymmetric.Returns:A loss scalar."""device = (torch.device('cuda')#设置设备if features.is_cudaelse torch.device('cpu'))if len(features.shape) < 3:raise ValueError('`features` needs to be [bsz, n_views, ...],''at least 3 dimensions are required')if len(features.shape) > 3:# batch_size, channel,H,W,平铺变成batch_size, channel, (H,W)features = features.view(features.shape[0], features.shape[1], -1)batch_size = features.shape[0]if labels is not None and mask is not None:#只能存在一个raise ValueError('Cannot define both `labels` and `mask`')elif labels is None and mask is None:#如果两个都没有就是无监督对比损失,mask就是一个单位阵mask = torch.eye(batch_size, dtype=torch.float32).to(device)elif labels is not None:#有标签,就把他变成masklabels = labels.contiguous().view(-1, 1)#contiguous深拷贝,与原来的labels没有关系,展开成一列,这样的话能够计算mask,否则labels一维的话labels.T是他本身捕获发生转置if labels.shape[0] != batch_size:raise ValueError('Num of labels does not match num of features')mask =  torch.eq(labels, labels.T).float().to(device)#label和label的转置比较,感觉应该是广播机制,让label和label.T都扩充了然后进行比较,相同的是1,不同是0.#这里就是由label形成mask,mask(i,j)代表第i个数据和第j个数据的关系,如果两个类别相同就是1, 不同就是0else:mask = mask.float().to(device)#有mask就直接用mask,mask也是代表两个数据之间的关系contrast_count = features.shape[1]#对比数是channel的个数contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)#把feature按照第1维拆开,然后在第0维上cat,(batch_size*channel,h*w..)#后面就是展开的feature的维度#这个操作就和后面mask.repeat对上了,这个操作是第一个数据的第一维特征+第二个数据的第一维特征+第三个数据的第一维特征这样排列的与mask对应if self.contrast_mode == 'one':#如果mode=one,比较feature中第1维中的0号元素(batch, h*w)anchor_feature = features[:, 0]anchor_count = 1elif self.contrast_mode == 'all':#all就(batch*channel, h*w)anchor_feature = contrast_featureanchor_count = contrast_countelse:raise ValueError('Unknown mode: {}'.format(self.contrast_mode))# compute logitsanchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),#两个相乘获得相似度矩阵,乘积值越大代表越相关self.temperature)# for numerical stabilitylogits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)#计算其中最大值logits = anchor_dot_contrast - logits_max.detach()#减去最大值,都是负的了,指数就小于等于1# tile maskmask = mask.repeat(anchor_count, contrast_count)#repeat它就是把mask复制很多份# mask-out self-contrast caseslogits_mask = torch.scatter(#生成一个mask形状的矩阵除了对角线上的元素是0,其他位置都是1, 不会对自身进行比较torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0)mask = mask * logits_mask# compute log_probexp_logits = torch.exp(logits) * logits_mask#定义其中的相似度log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))#softmax# compute mean of log-likelihood over positive# modified to handle edge cases when there is no positive pair# for an anchor point. # Edge case e.g.:- # features of shape: [4,1,...]# labels:            [0,1,1,2]# loss before mean:  [nan, ..., ..., nan] mask_pos_pairs = mask.sum(1)#mask的和mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)#满足返回1,不满足返回mask_pos_pairs.保证数值稳定mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs# lossloss = - (self.temperature / self.base_temperature) * mean_log_prob_pos#类似蒸馏temperature温度越高,分布曲线越平滑不易陷入局部最优解,温度低,分布陡峭loss = loss.view(anchor_count, batch_size).mean()#计算平均return loss

使用的化就是下面这段:

loss = criterion(features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/641108.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux C | 进程】进程终止、等待 | exit、_exit、wait、waitpid

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

【Kafka】开发实战和Springboot集成kafka

目录 消息的发送与接收生产者消费者 SpringBoot 集成kafka服务端参数配置 消息的发送与接收 生产者 生产者主要的对象有&#xff1a; KafkaProducer &#xff0c; ProducerRecord 。 其中 KafkaProducer 是用于发送消息的类&#xff0c; ProducerRecord 类用于封装Kafka的消息…

仅使用K-M法+Cox比例风险模型就能发二区文章 | SEER公共数据库周报(1.17)

欢迎各位参加本周中山大学著名卫生统计学家方积乾教授公益直播讲座&#xff01; 就在本周三晚&#xff0c;主题为“真实世界研究与RCT研究”&#xff0c;欢迎各位预约参加&#xff01; SEER&#xff08;The Surveillance, Epidemiology, and End Results&#xff09;数据库是由…

回溯算法篇-01:全排列

力扣46&#xff1a;全排列 题目分析 这道题属于上一篇——“回溯算法解题框架与思路”中的 “元素不重复不可复用” 那一类中的 排列类问题。 我们来回顾一下当时是怎么说的&#xff1a; 排列和组合的区别在于&#xff0c;排列对“顺序”有要求。比如 [1,2] 和 [2,1] 是两个不…

柔性数组和C语言内存划分

柔性数组和C语言内存划分 1. 柔性数组1.1 柔性数组的特点&#xff1a;1.2 柔性数组的使用1.3 柔性数组的优势 2. 总结C/C中程序内存区域划分 1. 柔性数组 也许你从来没有听说过柔性数组&#xff08;flexible array)这个概念&#xff0c;但是它确实是存在的。 C99 中&#xff…

力扣740. 删除并获得点数

动态规划 思路&#xff1a; 选择元素 x&#xff0c;获得其点数&#xff0c;删除 x 1 和 x - 1&#xff0c;则其他的 x 的点数也会被获得&#xff1b;可以将数组转换成一个有序 map&#xff0c;key 为 x&#xff0c; value 为对应所有 x 的和&#xff1b;则问题转换成了不能同…

Postman基本使用、测试环境(Environment)配置

文章目录 准备测试项目DemoController测试代码Interceptor模拟拦截配置 Postman模块简单介绍Postman通用环境配置新建环境(Environment)配置环境(Environment)设置域名变量引用域名变量查看请求结果打印 Postman脚本设置变量登录成功后设置全局Auth-Token脚本编写脚本查看conso…

即插即用篇 | UniRepLKNet:用于音频、视频、点云、时间序列和图像识别的通用感知大卷积神经网络 | DRepConv

大卷积神经网络(ConvNets)近来受到了广泛研究关注,但存在两个未解决且需要进一步研究的关键问题。1)现有大卷积神经网络的架构主要遵循传统ConvNets或变压器的设计原则,而针对大卷积神经网络的架构设计仍未得到解决。2)随着变压器在多个领域的主导地位,有待研究ConvNets…

软件设计师——项目管理(一)

&#x1f4d1;前言 本文主要是【项目管理】——软件设计师——项目管理的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304…

Databend 开源周报第 129 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 支持标准流 标…

如何在 Ubuntu 22.04 上安装 Apache Web 服务器

前些天发现了一个人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;最重要的屌图甚多&#xff0c;忍不住分享一下给大家。点击跳转到网站。 如何在 Ubuntu 22.04 上安装 Apache Web 服务器 介绍 Apache HTTP 服务器是世界上使用最广泛的 Web 服务器。它…

模拟器单窗口ip有问题?试试关闭IPV6来解决

目前应该不止雷电9有这个问题了&#xff0c;最早是看到无忧群里在说有这个问题&#xff0c;后面发现很多其他的ip软件也有同样的问题&#xff0c;很多人都遇到&#xff0c;所以做个图文教程在这里&#xff0c;没出问题的也可以设置一下&#xff0c;目前ipv6也还没普及&#xff…

x-cmd pkg | hurl - HTTP 请求处理工具

目录 简介首次用户功能特点竞品和相关作品进一步探索 简介 Hurl 是 HTTP 请求处理工具&#xff0c;支持使用简单的纯文本格式定义的 HTTP 请求。它的用途非常广泛&#xff0c;既可以用于获取数据&#xff0c;也可以用于测试HTTP会话。 它可以链式处理请求&#xff0c;捕获数值…

ORA-01033: ORACLE initialization or shutdown in progress---惜分飞

客户反馈数据库使用plsql dev登录报ORA-01033: ORACLE initialization or shutdown in progress的错误 出现该错误一般是由于数据库没有正常open成功,查看oracle 告警日志发现 Mon Jan 22 16:55:50 2024 Database mounted in Exclusive Mode Lost write protection disabled …

Unity SRP 管线【第五讲:URP烘培光照】

本节&#xff0c;我们将跟随数据流向讲解UEP管线中的烘培光照。 文章目录 一、URP烘培光照1. 搭建场景2. 烘培光照参数设置MixedLight光照设置&#xff1a;直观感受 Lightmapping Settings参数设置&#xff1a; 3. 我们如何记录次表面光源颜色首先我们提取出相关URP代码&#…

企业数字档案馆的构成要素

企业数字档案馆的构成要素包括以下几个方面&#xff1a; 1. 系统平台&#xff1a;企业数字档案馆需要有一个稳定的系统平台&#xff0c;用于存储、管理和检索档案信息。这个平台可以是基于云计算、数据库或其他技术的&#xff0c;能够支持大容量的数据存储和快速的检索功能。 2…

设计模式二(工厂模式)

本质&#xff1a;实例化对象不用new&#xff0c;用工厂代替&#xff0c;实现了创建者和调用者分离 满足&#xff1a; 开闭原则&#xff1a;对拓展开放&#xff0c;对修改关闭 依赖倒置原则&#xff1a;要针对接口编程 迪米特原则&#xff1a;最少了解原则&#xff0c;只与自己直…

Unity—配置lua环境变量+VSCode 搭建 Lua 开发环境

每日一句&#xff1a;保持须臾的浪漫&#xff0c;理想的喧嚣&#xff0c;平等的热情 Windows 11下配置lua环境变量 一、lua-5.4.4版本安装到本地电脑 链接&#xff1a;https://pan.baidu.com/s/14pAlOjhzz2_jmvpRZf9u6Q?pwdhd4s 提取码&#xff1a;hd4s 二、高级系统设置 此电…

P9232 [蓝桥杯 2023 省 A] 更小的数

[蓝桥杯 2023 省 A] 更小的数 终于本弱一次通关了一道研究生组别的题了[普及/提高−] 一道较为简单的双指针题,但一定有更好的解法. 题目描述 小蓝有一个长度均为 n n n 且仅由数字字符 0 ∼ 9 0 \sim 9 0∼9 组成的字符串&#xff0c;下标从 0 0 0 到 n − 1 n-1 n−1&a…

C++ //练习 2.35 判断下列定义推断出的类型是什么,然后编写程序进行验证。

C Primer&#xff08;第5版&#xff09; 练习 2.35 练习 2.35 判断下列定义推断出的类型是什么&#xff0c;然后编写程序进行验证。 const int i 42; auto j i; const auto &k i; auto *p &i; const auto j2 i, &k2 i;环境&#xff1a;Linux Ubuntu&#x…