深度学习中的数值稳定性处理详解:以SimCLR损失为例

文章目录

    • 1. 问题背景
      • SimCLR的原始公式
    • 2. 数值溢出问题
      • 为什么会出现数值溢出?
      • 浮点数的表示范围
    • 3. 数值稳定性处理方法
      • 核心思想
      • 数学推导
    • 4. 代码实现分解
      • 代码与公式的对应关系
    • 5. 具体数值示例
      • 示例:相似度矩阵
      • 方法1:直接计算exp(x)
      • 方法2:减去最大值后计算
      • 验证结果等价性
    • 6. 为什么减去最大值有效?
      • 关键原理
    • 7. 实际应用场景
    • 8. 实现建议
    • 总结

在深度学习实现中,特别是涉及指数和对数运算的损失函数计算过程中,数值稳定性是一个核心问题。本文以SimCLR对比学习损失为例,详细解析数值稳定性处理的原理、实现和重要性。

1. 问题背景

SimCLR是一种自监督学习方法,其核心是InfoNCE损失函数。这个损失函数的计算涉及大量指数运算,容易导致数值溢出或下溢问题。

SimCLR的原始公式

SimCLR的核心损失函数(InfoNCE损失)公式为:

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ ) ⋅ 1 k ≠ i L_i = -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} Li=logk=12Nexp(sim(zi,zk)/τ)1k=iexp(sim(zi,zj)/τ)

其中:

  • z i z_i zi是锚点特征
  • z j z_j zj是与 z i z_i zi对应的正样本特征
  • τ \tau τ是温度参数
  • s i m ( ) sim() sim()是相似度函数(通常是点积)
  • 1 k ≠ i \mathbf{1}_{k \neq i} 1k=i表示排除自身对比的指示函数

2. 数值溢出问题

为什么会出现数值溢出?

当我们计算 exp ⁡ ( x ) \exp(x) exp(x)时:

  • 如果 x x x很大(如 x = 100 x = 100 x=100), exp ⁡ ( 100 ) ≈ 2.7 × 1 0 43 \exp(100) \approx 2.7 \times 10^{43} exp(100)2.7×1043,可能超出浮点数表示范围
  • 如果 x x x是很小的负数(如 x = − 100 x = -100 x=100), exp ⁡ ( − 100 ) ≈ 3.7 × 1 0 − 44 \exp(-100) \approx 3.7 \times 10^{-44} exp(100)3.7×1044,可能导致下溢为0

在SimCLR中, s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi,zk)/τ可能很大,特别是当:

  • 特征向量高度相似( s i m sim sim接近1)
  • 温度参数 τ \tau τ很小(如0.07)

浮点数的表示范围

浮点数的表示范围是有限的:

  • 单精度浮点数(32位):约 ± 3.4 × 1 0 38 \pm 3.4 \times 10^{38} ±3.4×1038
  • 双精度浮点数(64位):约 ± 1.8 × 1 0 308 \pm 1.8 \times 10^{308} ±1.8×10308

3. 数值稳定性处理方法

SimCLR实现中使用了一种简单而有效的数值稳定性处理技术,代码如下:

# 数值稳定性处理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

核心思想

这种处理的核心思想是:

  1. 找出每行相似度的最大值
  2. 将每行的所有值减去这个最大值
  3. 然后再进行指数计算

数学推导

这种操作是数学等价的。对原始公式进行变换:

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ ) ⋅ 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} \\ \end{align} Li=logk=12Nexp(sim(zi,zk)/τ)1k=iexp(sim(zi,zj)/τ)

引入最大值 M i = max ⁡ k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi=maxk(sim(zi,zk)/τ)

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ − M i + M i ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i + M i ) ⋅ 1 k ≠ i = − log ⁡ exp ⁡ ( M i ) ⋅ exp ⁡ ( s i m ( z i , z j ) / τ − M i ) exp ⁡ ( M i ) ⋅ ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ − M i ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i + M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i + M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(M_i) \cdot \exp(sim(z_i, z_j)/\tau - M_i)}{\exp(M_i) \cdot \sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \end{align} Li=logk=12Nexp(sim(zi,zk)/τMi+Mi)1k=iexp(sim(zi,zj)/τMi+Mi)=logexp(Mi)k=12Nexp(sim(zi,zk)/τMi)1k=iexp(Mi)exp(sim(zi,zj)/τMi)=logk=12Nexp(sim(zi,zk)/τMi)1k=iexp(sim(zi,zj)/τMi)

因为分子和分母中的 exp ⁡ ( M i ) \exp(M_i) exp(Mi)相互抵消,所以最终结果不变。

4. 代码实现分解

完整的SimCLR损失计算代码(包含数值稳定性处理):

# 计算相似度矩阵并除以温度系数
anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),self.temperature)# 数值稳定性处理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()# 创建和应用掩码
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = torch.scatter(torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0
)
mask = mask * logits_mask# 计算损失
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()

代码与公式的对应关系

  1. anchor_dot_contrast s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi,zk)/τ
  2. logits_max M i = max ⁡ k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi=maxk(sim(zi,zk)/τ)
  3. logits s i m ( z i , z k ) / τ − M i sim(z_i, z_k)/\tau - M_i sim(zi,zk)/τMi
  4. exp_logits exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i} exp(sim(zi,zk)/τMi)1k=i
  5. log_prob log ⁡ exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ∑ k exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \log \frac{\exp(sim(z_i, z_k)/\tau - M_i)}{\sum_{k} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} logkexp(sim(zi,zk)/τMi)1k=iexp(sim(zi,zk)/τMi)

5. 具体数值示例

为了直观理解,我们用一个简化的例子来说明为什么减去最大值能防止数值溢出。

示例:相似度矩阵

假设有一个计算得到的相似度矩阵(已除以温度τ=0.07):

sim(z_i, z_k)/τ = [[80, 50, 60, 70, 40],[60, 90, 70, 80, 50],[70, 60, 85, 75, 55],[50, 40, 60, 75, 45]
]

方法1:直接计算exp(x)

直接计算exp(sim(z_i, z_k)/τ)

exp(sim(z_i, z_k)/τ) ≈ [[5.54e+34, 5.18e+21, 1.14e+26, 2.51e+30, 2.35e+17],[1.14e+26, 1.22e+39, 2.51e+30, 5.54e+34, 5.18e+21],[2.51e+30, 1.14e+26, 5.91e+36, 3.58e+32, 1.14e+24],[5.18e+21, 2.35e+17, 1.14e+26, 3.58e+32, 3.49e+19]
]

这些值极其巨大,相加时很容易溢出。例如第一行的和约为5.54e+34,已经接近单精度浮点数的上限。

方法2:减去最大值后计算

找出每行的最大值:

max_values = [80, 90, 85, 75]

减去最大值:

adjusted_logits = [[0, -30, -20, -10, -40],[-30, 0, -20, -10, -40],[-15, -25, 0, -10, -30],[-25, -35, -15, 0, -30]
]

计算exp(adjusted_logits)

exp(adjusted_logits) ≈ [[1.0, 9.36e-14, 2.06e-9, 4.54e-5, 4.25e-18],[9.36e-14, 1.0, 2.06e-9, 4.54e-5, 4.25e-18],[3.06e-7, 1.39e-11, 1.0, 4.54e-5, 9.36e-14],[1.39e-11, 6.31e-16, 3.06e-7, 1.0, 9.36e-14]
]

这些值都在[0,1]范围内,完全避免了溢出问题。同时,正样本对和负样本对之间的相对比例关系保持不变。

验证结果等价性

例如,对于第一行计算最终的归一化概率:

原始方法:

P(z_0 -> z_0) = exp(80) / sum(exp(row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(50) / sum(exp(row_0)) ≈ 9.35e-14
...

减去最大值后:

P(z_0 -> z_0) = exp(0) / sum(exp(adjusted_row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(-30) / sum(exp(adjusted_row_0)) ≈ 9.35e-14
...

两种计算方法得到的概率分布是相同的,但后者避免了数值溢出风险。

6. 为什么减去最大值有效?

关键原理

减去最大值的处理之所以有效,是因为:

  1. 将范围控制在安全区间

    • 减去最大值后,所有值都≤0
    • 因此所有exp(x)的结果都≤1,避免了上溢
    • 同时最大值对应的exp(0)=1,避免了整体下溢为0
  2. 保持相对比例关系

    • 对每行减去相同的常数不改变值之间的相对大小
    • 对于exp()函数来说,这等价于同时除以一个常数因子
    • 在计算Softmax或对数概率时,这个常数因子在分子和分母中抵消
  3. 数学等价性

    • exp(a-b) = exp(a)/exp(b)的性质保证了结果的正确性
    • 这相当于将原始公式的分子和分母同时除以exp(max_value)

7. 实际应用场景

这种数值稳定性技术不仅适用于SimCLR,还广泛应用于:

  1. Softmax计算:几乎所有需要计算Softmax的地方都需要
  2. 交叉熵损失:分类任务中常用
  3. 注意力机制:Transformer中的attention计算
  4. 所有对比学习方法:MoCo、BYOL、CLIP等

8. 实现建议

在实现涉及指数计算的函数时,建议:

  1. 始终使用数值稳定性处理
  2. 对每个batch/样本独立进行处理(找到每行/每个样本的最大值)
  3. 使用.detach()阻止梯度通过最大值操作传播
  4. 注意掩码操作,确保不包括自身对比或特定的负样本

总结

数值稳定性处理是深度学习实现中一个看似简单但至关重要的技术。通过简单地减去每行的最大值,我们可以有效防止数值溢出/下溢问题,同时保持计算结果的数学等价性。这种技术尤其重要,因为随着模型和批量大小的增加,数值问题更容易出现,而且往往难以诊断。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/79019.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL(9):创建数据库,表,简单

1、创建数据库,一句SQL语句搞定 CREATE DATDBASE 数据库名 CREATE DATABASE my_db;2、创建表 CREATE TABLE 表名(字段名 类型) CREATE TABLE Persons ( PersonID int, LastName varchar(255), FirstName varchar(255), Address varchar(255), City varchar(255)…

QT Sqlite数据库-教程002 查询数据-下

【1】数据库查询的优化:prepare prepare语句是一种在执行之前将SQL语句编译为字节码的机制,可以提高执行效率并防止SQL注入攻击。 【2】使用prepare查询一张表 QString myTable "myTable" ; QString cmd QString("SELECT * FROM %1…

cline 提示词工程指南-架构篇

cline 提示词工程指南-架构篇 本篇是 cline 提示词工程指南的学习和扩展,可以参阅: https://docs.cline.bot/improving-your-prompting-skills/prompting 前言 cline 是 vscode 的插件,用来在 vscode 里实现 ai 编程。 它使得你可以接入…

算法---子序列[动态规划解决](最长递增子序列)

最长递增子序列 子序列包含子数组&#xff01; 说白了&#xff0c;要用到双层循环&#xff01; 用双层循环中的dp[i]和dp[j]把所有子序列情况考虑到位 class Solution { public:int lengthOfLIS(vector<int>& nums) {vector<int> dp(nums.size(),1);for(int i …

kubectl命令补全以及oc命令补全

kubectl命令补全 1.安装bash-completion 如果你用的是Bash(默认情况下是)&#xff0c;先安装补全功能支持包 sudo apt update sudo apt install bash-completion -y2.为kubectl 启用补全功能 会话中临时&#xff1a; source <(kubectl completion bash)持久化配置&#x…

48、Spring Boot 详细讲义(五)

3、集成MyBatis 3.1 MyBatis 概述 3.1.1 核心功能和优势 MyBatis 是一个 Java 持久层框架,它通过 XML 或注解配置 SQL 语句,将 Java 方法与 SQL 语句映射起来,消除了大量的 JDBC 代码,简化了数据库操作。MyBatis 的核心功能和优势包括: ORM(对象关系映射):通过 XML …

BERT - Bert模型框架复现

本节将实现一个基于Transformer架构的BERT模型。 1. MultiHeadAttention 类 这个类实现了多头自注意力机制&#xff08;Multi-Head Self-Attention&#xff09;&#xff0c;是Transformer架构的核心部分。 在前几篇文章中均有讲解&#xff0c;直接上代码 class MultiHeadAtt…

解决 Spring Boot 启动报错:数据源配置引发的启动失败

启动项目时&#xff0c;控制台输出了如下错误信息&#xff1a; Error starting ApplicationContext. To display the condition evaluation report re-run your application with debug enabled. 2025-04-14 21:13:33.005 [main] ERROR o.s.b.d.LoggingFailureAnalysisReporte…

履带小车+六轴机械臂(2)

本次介绍原理图部分 开发板部分&#xff0c;电源供电部分&#xff0c;六路舵机&#xff0c;PS2手柄接收器&#xff0c;HC-05蓝牙模块&#xff0c;蜂鸣器&#xff0c;串口&#xff0c;TB6612电机驱动模块&#xff0c;LDO线性稳压电路&#xff0c;按键部分 1、开发板部分 需要注…

【开发记录】服务外包大赛记录

参加服务外包大赛的A07赛道中&#xff0c;最近因为频繁的DEBUG&#xff0c;心态爆炸 记录错误 以防止再次出现错误浪费时间。。。 2025.4.13 项目在上传图片之后 会自动刷新 没有等待后端返回 Network中的fetch /upload显示canceled. 然而这是使用了VS的live Server插件才这样&…

基于FreeRTOS和LVGL的多功能低功耗智能手表(硬件篇)

目录 一、简介 二、板子构成 三、核心板 3.1 MCU最小系统板电路 3.2 电源电路 3.3 LCD电路 3.4 EEPROM电路 3.5 硬件看门狗电路 四、背板 4.1 传感器电路 4.2 充电盘 4.3 蓝牙模块电路 五、总结 一、简介 本篇开始介绍这个项目的硬件部分&#xff0c;从最小电路设…

为 Kubernetes 提供智能的 LLM 推理路由:Gateway API Inference Extension 深度解析

现代生成式 AI 和大语言模型&#xff08;LLM&#xff09;服务给 Kubernetes 带来了独特的流量路由挑战。与典型的短时、无状态 Web 请求不同&#xff0c;LLM 推理会话通常是长时运行、资源密集且部分有状态的。例如&#xff0c;一个基于 GPU 的模型服务器可能同时维护多个活跃的…

MacOs下解决远程终端内容复制并到本地粘贴板

常常需要在服务器上捣鼓东西&#xff0c;同时需要将内容复制到本地的需求。 1-内容是在远程终端用vim打开&#xff0c;如何用vim的类似指令达到快速复制到本地呢&#xff1f; 假设待复制的内容&#xff1a; #include <iostream> #include <cstring> using names…

STM32 vs ESP32:如何选择最适合你的单片机?

引言 在嵌入式开发中&#xff0c;STM32 和 ESP32 是两种最热门的微控制器方案。但许多开发者面对项目选型时仍会感到困惑&#xff1a;到底是选择功能强大的 STM32&#xff0c;还是集成无线的 ESP32&#xff1f; 本文将通过 硬件资源、开发场景、成本分析 等多维度对比&#xf…

【blender小技巧】Blender导出带贴图的FBX模型,并在unity中提取材质模型使用

前言 这其实是我之前做过的操作&#xff0c;我只是单独提取出来了而已。感兴趣可以去看看&#xff1a;【blender小技巧】使用Blender将VRM或者其他模型转化为FBX模型&#xff0c;并在unity使用&#xff0c;导出带贴图的FBX模型&#xff0c;贴图材质问题修复 一、导出带贴图的…

如何保证本地缓存和redis的一致性

1. Cache Aside Pattern&#xff08;旁路缓存模式&#xff09;​​ ​核心思想​&#xff1a;应用代码直接管理缓存与数据的同步&#xff0c;分为读写两个流程&#xff1a; ​读取数据​&#xff1a; 先查本地缓存&#xff08;如 Guava Cache&#xff09;。若本地未命中&…

k8s通过service标签实现蓝绿发布

k8s通过service标签实现蓝绿发布 通过k8s service label标签实现蓝绿发布方法1:使用kubelet完成蓝绿切换1. 创建绿色版本1.1 创建绿色版本 Deployment1.2 创建绿色版本 Service 2. 创建蓝色版本2.1 创建蓝色版本 Deployment2.2 创建蓝色版本 Service 3. 创建蓝绿切换SVC (用于外…

智慧酒店企业站官网-前端静态网站模板【前端练习项目】

最近又写了一个静态网站&#xff0c;智慧酒店宣传官网。 使用的技术 html css js 。 特别适合编程学习者进行网页制作和前端开发的实践。 项目包含七个核心模块&#xff1a;首页、整体解决方案、优势、全国案例、行业观点、合作加盟、关于我们。 通过该项目&#xff0c;小伙伴们…

Epplus 8+ 许可证设置

Epplus 8 之后非商业许可证的设置变了如果还用普通的方法会报错 Unhandled exception. OfficeOpenXml.LicenseContextPropertyObsoleteException: Please use the static ‘ExcelPackage.License’ property to set the required license information from EPPlus 8 and later …

CST1016.基于Spring Boot+Vue高校竞赛管理系统

计算机/JAVA毕业设计 【CST1016.基于Spring BootVue高校竞赛管理系统】 【项目介绍】 高校竞赛管理系统&#xff0c;基于 DeepSeek Spring AI Spring Boot Vue 实现&#xff0c;功能丰富、界面精美 【业务模块】 系统共有两类用户&#xff0c;分别是学生用户和管理员用户&a…