深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

  • 介绍
  • 公式
  • 交叉熵函数存在什么问题
  • 带权重的交叉熵函数
  • 代码

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i l o g ( y ^ i ) + ( 1 − y i ) l o g ( 1 − y ^ i ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)] H(y,y )=N1i=1N[yilog(y i)+(1yi)log(1y i)]
其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类
多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i 1 l o g ( y ^ i 1 ) + y i 2 l o g ( y ^ i 2 ) + . . . + y i m l o g ( y ^ i m ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_{i1}log(\widehat{y}_{i1})+y_{i2}log(\widehat{y}_{i2})+...+y_{im}log(\widehat{y}_{im})] H(y,y )=N1i=1N[yi1log(y i1)+yi2log(y i2)+...+yimlog(y im)]
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^m[y_{ij}log(\widehat{y}_{ij})] H(y,y )=N1i=1Nj=1m[yijlog(y ij)]
这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = [ y i 1 , y i 2 , . . . , y i m ] y_i=[y_{i1},y_{i2},...,y_{im}] yi=[yi1,yi2,...,yim]

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是“平等”地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ w j y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^m[w_{j} y_{ij}log(\widehat{y}_{ij})] H(y,y )=N1i=1Nj=1m[wjyijlog(y ij)]
其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}_{ij}}{\sum_{1}^N\widehat{y}_{ij}} wj=1Ny ijN1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

import torch
import torch.nn as nnclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None, size_average=True):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, size_average)def forward(self, preds, targets):return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)
weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。
size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/815649.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于深度学习的生活垃圾智能分类系统(微信小程序+YOLOv5+训练数据集+开题报告+中期检查+论文)

摘要 本文基于Python技术,搭建了YOLOv5s深度学习模型,并基于该模型研发了微信小程序的垃圾分类应用系统。本项目的主要工作如下: (1)调研了移动端垃圾分类应用软件动态,并分析其优劣势;分析了深…

【S32K3 MCAL配置】-4.1-CAN Driver:如何解决CAN帧发送丢帧问题

"><--返回「Autosar_MCAL高阶配置」专栏主页--> 案例背景:如何解决:同一时刻,连续调用多次CanIf_Transmit / Can_Write API,同时发送不同CANID帧,出现丢帧问题。 目录(共9页精讲,基于评估板: NXP S32K312EVB-Q172,手把手教你S32K3从入门到精通) 实现的架…

LeetCode-热题100:5. 最长回文子串

题目描述 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 如果字符串的反序与原始字符串相同&#xff0c;则该字符串称为回文字符串。 示例 1&#xff1a; 输入&#xff1a; s “babad” 输出&#xff1a; “bab” 解释&#xff1a; “aba” 同样是符合题意的答案…

鸿蒙开发学习笔记第一篇--TypeScript基础语法

目录 前言 一、ArkTS 二、基础语法 1.基础类型 1.布尔值 2.数字 3.字符串 4.数组 5.元组 6.枚举 7.unkown 8.void 9.null和undefined 10.联合类型 2.条件语句 1.if语句 1.最简单的if语句 2.if...else语句 3.if...else if....else 语句 2.switch语句 5.函数…

Java 入门教程||Java 关键字

Java 关键字 Java教程 - Java关键字 Java中的关键字完整列表 关键词是其含义由编程语言定义的词。 Java关键字和保留字&#xff1a; abstract class extends implements null strictfp true assert const false import package super try …

二叉排序树的增删改查(java版)

文章目录 1. 基本节点2. 二叉排序树2.1 增加节点2.2 查找&#xff08;就是遍历&#xff09;就一起写了吧2.3 广度优先遍历2.4 删除&#xff08;这个有点意思&#xff09;2.5 测试样例 最后的删除&#xff0c;目前我测试的是正确的 1. 基本节点 TreeNode: class TreeNode{pri…

bugku-web-文件包含2

页面源码 <!-- upload.php --><!doctype html><html><head><meta charset"utf-8"/><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-widt…

Zabbix_Agent一键安装脚本(包含ansible-playbook批量执行脚本)

为了快速安装配置zabbix_agent&#xff0c;特地写了此shell脚本&#xff0c;脚本实现功能如下&#xff1a; 1、自动检测操作系统类型&#xff0c;目前支持Ubuntu/Centos/Redhat 2、自动获取安装agent主机IP地址 3、交互判断主机IP是否可用&#xff0c;用户输入正确IP地址 4、输…

技术周刊的转变:如何平衡热爱与现实?

大家好&#xff0c;我是那个自己打脸自己的猫哥&#xff0c;本来说周刊不做订阅制的&#xff0c;现在却推出了订阅专栏。今天想为自己辩护一下&#xff0c;同时聊聊技术周刊今后的发展计划。 首先回顾一下我过去的想法吧&#xff0c;然后再解释为什么会突然出现转变。 出于对…

2024.4.12力扣每日一题——找到冠军 I

2024.4.12 题目来源我的题解方法一 哈希表方法二 列式遍历统计方法三 列式遍历优化统计 题目来源 力扣每日一题&#xff1b;题序&#xff1a;2923 我的题解 方法一 哈希表 哈希表存储不可能是冠军的队伍&#xff0c;最后没在哈希表中的队伍就是冠军。 时间复杂度&#xff1a…

Python学习之-Pandas详解

前言&#xff1a; Pandas 是一个开源的 Python 数据分析库&#xff0c;它提供了高性能、易于使用的数据结构和数据分析工具。Pandas提供 了方便的类表格和类SQL的操作&#xff0c;同时提供了强大的缺失值处理方法&#xff0c;可以方便的进行数据导入、选取、清洗、处理、合并、…

如何进行宏观经济预测

理性预期经济学提出了理性预期的概念&#xff0c;强调政府在制定各种宏观经济政策时&#xff0c;要考虑到各行为主体预期对政策实施有效性的影响&#xff0c;积极促成公众理性预期的形成&#xff0c;从而更好地实现宏观调控的目标。政府统计要深入开展统计分析预测研究&#xf…

poi-tl的使用(通俗易懂,全面,内含动态表格实现 包会!!)

最近在做项目时候有一个关于解析Html文件&#xff0c;然后将解析的数据转化成word的需求&#xff0c;经过调研&#xff0c;使用poi-tl来实现这个需求&#xff0c;自己学习花费了一些时间&#xff0c;现在将这期间的经验总结起来&#xff0c;让大家可以快速入门 poi-tl的介绍 …

979: 输出利用先序遍历创建的二叉树的后序遍历序列

解法&#xff1a; #include<iostream> using namespace std; struct TreeNode {char val;TreeNode* left;TreeNode* right;TreeNode(char c) :val(c), left(NULL), right(NULL) {}; }; TreeNode* buildTree() {char c;cin >> c;if (c #) {return NULL;}TreeNode*…

Android图形显示架构概览

图形显示系统作为Android系统核心的子系统&#xff0c;掌握它对于理解Android系统很有帮助&#xff0c;下面从整体上简单介绍图形显示系统的架构&#xff0c;如下图所示。 这个框架只包含了用户空间的图形组件&#xff0c;不涉及底层的显示驱动。框架主要包括以下4个图形组件。…

内网通如何去除广告,内网通免广告生成器

公司使用内网通内部传输确实方便&#xff01;但是会有广告弹窗推送&#xff01;这个很烦恼&#xff01;那么如何去除广告呢&#xff01; 下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1CVVdWexliF3tBaFgN1W9aw?pwdhk7m 提取码&#xff1a;hk7m ID&#xff1a;…

Uniapp小程序路由拦截器、navigator不被拦截

添加一个文件interceptor.js&#xff08;名字随意、位置随意&#xff09; import store from "./store";let config {//白名单页面whiteList: ["/pages/login/login","/pages/guides/guides","/pages/index/index"],//登录页loginPa…

mysql dll文件的缺失和Can‘t connect to MySQL server on ‘localhost‘ (10061)

个人笔记&#xff08;整理不易&#xff0c;有帮助&#xff0c;收藏点赞评论&#xff0c;爱你们&#xff01;&#xff01;&#xff01;你的支持是我写作的动力&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_weixin_42717928的博客-CSDN博客 个人随笔…

InternlM2

第一次作业 基础作业 进阶作业 1. hugging face下载 2. 部署 首先&#xff0c;从github上git clone仓库 https://github.com/InternLM/InternLM-XComposer.git然后里面的指引安装环境

【自研网关系列】请求服务模块和客户端模块实现

&#x1f308;Yu-Gateway&#xff1a;&#xff1a;基于 Netty 构建的自研 API 网关&#xff0c;采用 Java 原生实现&#xff0c;整合 Nacos 作为注册配置中心。其设计目标是为微服务架构提供高性能、可扩展的统一入口和基础设施&#xff0c;承载请求路由、安全控制、流量治理等…