【NLP 50、损失函数 KL散度】

目录

一、定义与公式

1.核心定义

2.数学公式

3.KL散度与交叉熵的关系

二、使用场景

1.生成模型与变分推断

2.知识蒸馏

3.模型评估与优化

4.信息论与编码优化

三、原理与特性

1.信息论视角

​2.优化目标

3.​局限性

四、代码示例

代码运行流程

核心代码解析


抵达梦想靠的不是狂热的想象,而是谦卑的务实,甚至你自己都看不起的可怜的隐忍

                                                                                                                        —— 25.3.27

一、定义与公式

1.核心定义

        KL散度(相对熵)是衡量两个概率分布 P 和 Q 之间差异的非对称性指标。它量化了当用分布 Q 近似真实分布 P 时的信息损失

非对称性:,即P和Q的顺序不能交换

非负性:,当且仅当P = Q时取等号

2.数学公式

离散形式:

连续形式:

其中,P是真实分布,Q是近似分布

3.KL散度与交叉熵的关系

KL散度可以分解为交叉熵H(P,Q与P的熵H(P):

交叉熵常用于分类任务,而KL散度更关注分布间的信息差异


二、使用场景

1.生成模型与变分推断

变分自编码器(VAE)​:通过最小化,使编码器输出的隐变量分布Q(z|x)逼近先验分布P(z)

生成对抗网络(GAN)​:辅助衡量生成分布与真实分布的差异

2.知识蒸馏

        将复杂教师模型的输出概率(软标签)作为监督信号,指导学生模型学习,损失函数中常包含KL散度项

3.模型评估与优化

多模态分布对齐:在推荐系统中对齐用户行为分布与模型预测分布​

异常检测:通过KL散度衡量测试数据分布与正常数据分布的偏离程度

4.信息论与编码优化

最小化编码长度:KL散度表示用 Q 编码 P 时所需的额外比特数


三、原理与特性

1.信息论视角

信息增益:KL散度表示从 Q 中获取 P 的信息时需要增加的“惊讶度”(Surprisal)。

凸性:KL散度是凸函数,可通过梯度下降法优化。

​2.优化目标

前向KL散度 DKL​(P∥Q):要求 Q 覆盖 P 的主要模式,避免 Q 的“零概率陷阱”(即 Q(x)=0 但 P(x)>0 会导致无穷大)​反向KL散度 DKL​(Q∥P):鼓励 Q 聚焦于 P 的单一主峰,适用于稀疏分布近似。

3.​局限性

非对称性:需根据任务选择方向(如VAE使用前向KL,部分GAN变体使用反向KL)

数值稳定性:需避免 Q(x)=0 或极端概率值,可通过平滑或温度参数(Temperature Scaling)调整。


四、代码示例

代码运行流程

KL散度计算流程
├── 1. 输入预处理
│   ├── a. 获取学生/教师模型原始输出
│   │   ├─ student_logits: 形状(batch=32, classes=10)
│   │   └─ teacher_logits: 同左[1,3](@ref)
│   └── b. 温度参数初始化
│       └─ temperature=5.0 (默认值)
├── 2. 概率变换
│   ├── a. 温度缩放
│   │   ├─ student_logits → student_logits / 5.0
│   │   └─ teacher_logits → teacher_logits / 5.0
│   ├── b. 概率归一化
│   │   ├─ student_probs = log_softmax(...)  # 对数空间
│   │   └─ teacher_probs = softmax(...)      # 线性空间
├── 3. 损失计算
│   ├── a. 初始化KLDivLoss
│   │   └─ reduction='batchmean' (符合数学期望)
│   ├── b. 执行KL散度计算
│   │   └─ KL(student_probs || teacher_probs)
│   └── c. 梯度补偿
│       └─ 乘以temperature²=25 恢复梯度幅值
└── 4. 结果输出└── 打印损失值 (标量Tensor转float)

student_logits:学生模型的原始输出(未归一化),形状为 (batch_size, num_classes),表示每个样本的预测得分

teacher_logits:教师模型的原始输出(未归一化),作为知识蒸馏的监督信号,形状同student_logits

temperature:温度缩放参数,软化概率分布(值越大分布越平滑,值越小越接近原始分布)

student_probs:学生模型经温度缩放后的对数概率

teacher_probs:教师模型经温度缩放后的概率

loss:KL散度损失的计算结果,表示学生模型输出分布与教师模型输出分布之间的差异程度。该值是一个标量(Scalar),用于指导反向传播优化学生模型的参数

batch_size:表示 ​单次输入模型的样本数量,即一次前向传播和反向传播处理32个样本。

nums_classes: 表示 ​分类任务的类别总数,即模型需区分的不同标签种类数。

F.log_softmax():将输入张量通过Softmax函数归一化为概率分布后,再对每个元素取自然对数,常用于分类任务的损失计算(如交叉熵损失)。

参数名类型说明默认值
​**input**Tensor输入张量必填
​**dim**int指定归一化的维度(如dim=1表示按行计算)必填

F.softmax():将输入张量通过指数函数归一化为概率分布,输出值范围为(0,1)且和为1。

参数名类型说明默认值
​**input**Tensor输入张量必填
​**dim**int归一化维度(如dim=0按列归一化)必填

nn.KLDivLoss():计算两个概率分布之间的Kullback-Leibler散度(KL散度),用于衡量分布差异。

参数名类型说明可选值默认值
​**reduction**str损失聚合方式'none''mean''sum''batchmean''mean'

torch.randn():生成服从标准正态分布(均值为0,标准差为1)的随机数张量,常用于初始化权重或生成噪声数据。

参数名类型说明默认值
​***size**int或tuple张量形状(如(3,4)生成3行4列矩阵)必填
​**dtype**torch.dtype数据类型(如torch.float32None(自动推断)
​**device**torch.device设备(如'cuda'CPU
​**requires_grad**bool是否需要梯度跟踪False

item():PyTorch中torch.Tensor类的方法,用于从单元素张量中提取Python标量值(如intfloat等)

核心代码解析

loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ​** 2)

nn.KLDivLoss(reduction='batchmean')计算学生模型输出 (student_probs) 与教师模型输出 (teacher_probs) 之间的 ​KL散度,衡量两者的概率分布差异

        ​参数 reduction='batchmean'将每个样本的KL散度求和后除以批量大小 (batch_size),确保损失值符合KL散度的数学定义

   mean对所有元素取平均(总和除以元素总数)。

   sum直接求和。

   none保留每个样本的独立损失值。

(student_probs, teacher_probs):输入参数student_probs 和 teacher_probs

* (temperature ​** 2):温度缩放与梯度补偿

        温度的作用:软化概率分布:高温值会使教师模型的概率分布更平滑,避免过度关注高置信度类别

        ​为何乘以 temperature²① 梯度补偿:温度缩放会缩小梯度的幅值,乘以 temperature² 可恢复原始梯度量级,确保优化方向正确​ ② 数学推导:KL散度计算中,温度参数会引入缩放因子 T1​,反向传播时梯度需乘以 T2 以抵消缩放效应。

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义KL散度损失函数(带温度参数)
def kl_div_loss_with_temperature(student_logits, teacher_logits, temperature=5.0):# 对logits应用温度缩放student_probs = F.log_softmax(student_logits / temperature, dim=-1)teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)# 计算KL散度loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ​** 2)return loss# 模拟输入数据
batch_size, num_classes = 32, 10
student_logits = torch.randn(batch_size, num_classes)  # 学生模型输出(未归一化)
teacher_logits = torch.randn(batch_size, num_classes)  # 教师模型输出(未归一化)# 计算损失
loss = kl_div_loss_with_temperature(student_logits, teacher_logits)
print(f"KL散度损失: {loss.item()}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/73742.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用QT画带有透明效果的图

分辨率&#xff1a;24X24 最大圆 代码: #include <QApplication> #include <QImage> #include <QPainter>int main(int argc, char *argv[]) {QImage image(QSize(24,24),QImage::Format_ARGB32);image.fill(QColor(0,0,0,0));QPainter paint(&image);…

【Unity网络编程知识】使用Socket实现简单TCP通讯

1、Socket的常用属性和方法 创建Socket TCP流套接字 Socket socketTcp new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); 1.1 常用属性 1&#xff09;套接字的连接状态 socketTcp.Connected 2&#xff09;获取套接字的类型 socketTcp.So…

青少年编程与数学 02-013 初中数学知识点 02课题、概要

青少年编程与数学 02-013 初中数学知识点 02课题、概要 一、数与代数二、图形与几何三、统计与概率四、综合与实践五、课程理念与目标 根据2022年版义务教育数学课程标准&#xff0c;初中数学知识点可以总结为以下四大领域。 一、数与代数 数与式 有理数与实数&#xff1a;理解…

深入探索 libarchive

深入探索 libarchive&#xff1a;跨平台归档处理的终极解决方案 一、背景与历史沿革 1.1 归档处理的演进之路 从1979年tar格式的诞生到现代云存储时代&#xff0c;归档技术经历了四个关键阶段&#xff1a; Unix时代&#xff1a;tar/cpio主导系统备份互联网黎明期&#xff1…

2025最新“科研创新与智能化转型“暨AI智能体开发与大语言模型的本地化部署、优化技术实践

第一章、智能体(Agent)入门 1、智能体&#xff08;Agent&#xff09;概述&#xff08;什么是智能体&#xff1f;智能体的类型和应用场景、典型的智能体应用&#xff0c;如&#xff1a;Google Data Science Agent等&#xff09; 2、智能体&#xff08;Agent&#xff09;与大语…

Yolo_v8的安装测试

前言 如何安装Python版本的Yolo&#xff0c;有一段时间不用了&#xff0c;Yolo的版本也在不断地发展&#xff0c;所以重新安装了运行了一下&#xff0c;记录了下来&#xff0c;供参考。 一、搭建环境 1.1、创建Pycharm工程 首先创建好一个空白的工程&#xff0c;如下图&…

时尚界正在试图用AI,创造更多冲击力

数字艺术正以深度融合的方式&#xff0c;在时尚、游戏、影视等行业实现跨界合作&#xff0c;催生了多样化的商业模式&#xff0c;为创作者和品牌带来更多机会&#xff0c;数字艺术更是突破了传统艺术的限制&#xff0c;以趣味触达用户&#xff0c;尤其吸引了年轻一代的消费群体…

蓝桥杯省模拟赛 01串个数

问题描述 请问有多少个长度为 24 的 01 串&#xff0c;满足任意 5 个连续的位置中不超过 3 个位置的值为 1。 所有长度为24的01串组合有2*24种 思路&#xff1a;遍历所有长度为24的01串组合&#xff0c;选择出符合题意的 #include<iostream> #include<cmath> us…

【软考备考】系统架构设计论文完整范文示例

本文由AI辅助创造 题目:基于微服务与云原生的智慧政务平台架构设计与实践 摘要(约300字) 本文以某省级智慧政务平台建设项目为背景,针对传统政务系统存在的"信息孤岛"、扩展性差、维护成本高等问题,提出了一套基于微服务与云原生技术的解决方案。通过领域驱动…

数据库原理及应用mysql版陈业斌实验二

&#x1f3dd;️专栏&#xff1a;Mysql_猫咪-9527的博客-CSDN博客 &#x1f305;主页&#xff1a;猫咪-9527-CSDN博客 “欲穷千里目&#xff0c;更上一层楼。会当凌绝顶&#xff0c;一览众山小。” 目录 实验二单表查询 1.实验数据如下 student 表&#xff08;学生表&#…

SDL —— 将sdl渲染画面嵌入Qt窗口显示(附:源码)

🔔 SDL/SDL2 相关技术、疑难杂症文章合集(掌握后可自封大侠 ⓿_⓿)(记得收藏,持续更新中…) 效果 使用QWidget加载了SDL的窗口,渲染器使用硬件加速跑GPU的。支持Qt窗口缩放或显示隐藏均不影响SDL的图像刷新。   操作步骤 1、在创建C++空工程时加入SDL,引入头文件时需…

C语言之链表增删查改

1.知识百科 链表&#xff08;Linked List&#xff09;是计算机科学中一种基础的数据结构&#xff0c;通过节点&#xff08;Node&#xff09;的链式连接来存储数据。每个节点包含两部分&#xff1a;存储数据的元素和指向下一个节点的指针&#xff08;单链表&#xff09;或前后两…

Windows环境下AnythingLLM安装与Ollama+DeepSeek集成指南

前面已经完成了Ollama的安装并下载了deepseek大模型包&#xff0c;下面介绍如何与anythingLLM 集成 Windows环境下AnythingLLM安装与OllamaDeepSeek集成指南 一、安装准备 1. 硬件要求 如上文说明 2. 前置条件 已安装Ollama并下载DeepSeek模型&#xff08;如deepseek-r1:…

当贝AI知识库评测 AI如何让知识检索快人一步

近日,国内领先的人工智能服务商当贝AI正式推出“个人知识库”功能,这一创新性工具迅速引发行业关注。在信息爆炸的时代,如何高效管理个人知识资产、快速获取精准答案成为用户的核心需求。当贝AI通过将“闭卷考试”变为“开卷考试”的独特设计,为用户打造了一个高度个性化的智能…

HarmonyOS NEXT——【鸿蒙原生应用加载Web页面】

鸿蒙客户端加载Web页面&#xff1a; 在鸿蒙原生应用中&#xff0c;我们需要使用前端页面做混合开发&#xff0c;方法之一是使用Web组件直接加载前端页面&#xff0c;其中WebView提供了一系列相关的方法适配鸿蒙原生与web之间的使用。 效果 web页面展示&#xff1a; Column()…

嵌入式开发场景中Shell脚本执行方式的对比

‌Shell脚本执行方式对比表‌ ‌执行方式‌‌命令示例‌‌是否需要执行权限‌‌是否启动子Shell‌‌环境变量影响范围‌‌适用场景‌‌嵌入式开发中的典型应用‌‌直接执行脚本‌./script.sh是是子Shell内有效独立运行的脚本&#xff0c;需固定环境自动化构建脚本&#xff08;…

MES系统需要采集的数据及如何采集

​数据采集在企业信息化建设中占据着举足轻重的地位&#xff0c;是实现物料跟踪、生产计划制定、产品历史记录维护以及其他生产管理活动的基石。数据的准确性和实时性直接关系到企业信息化能否成功落地&#xff0c;是企业迈向高效生产的关键因素。 数据收集对于MES制造执行系统…

闭环管理:借助数字化管理平台实现客户反馈的价值升级

在竞争激烈的市场环境中&#xff0c;客户反馈已成为企业优化服务、提升竞争力的核心资源。如何高效处理客户反馈&#xff0c;将其转化为企业持续改进的动力&#xff0c;是每个企业面临的重要课题。作为服务管理数字化转型服务商&#xff0c;瑞云服务云为大中型企业提供了一套完…

C++Primer学习(13.6 对象移动)

13.6 对象移动 新标准的一个最主要的特性是可以移动而非拷贝对象的能力。如我们在13.1.1节(第440页)中所见&#xff0c;很多情况下都会发生对象拷贝。在其中某些情况下&#xff0c;对象拷贝后就立即被销毁了。在这些情况下&#xff0c;移动而非拷贝对象会大幅度提升性能。 如我…

Uni-app页面信息与元素影响解析

获取窗口信息uni.getWindowInfo {pixelRatio: 3safeArea:{bottom: 778height: 731left: 0right: 375top: 47width: 375}safeAreaInsets: {top: 47, left: 0, right: 0, bottom: 34},screenHeight: 812,screenTop: 0,screenWidth: 375,statusBarHeight: 47,windowBottom: 0,win…