focal loss的几种实现版本(Keras/Tensorflow)

起源于在工作中使用focal loss遇到的一个bug,我仔细的学习多个靠谱的focal loss讲解及实现版本

通过测试,我发现了这样一个奇怪的现象,几乎每个版本的focal loss实现对同样的输入计算出的loss都是不同的。

通过仔细的比对和思考,我总结了三种我认为正确的focal loss实现方法,并将代码分享出来。

完整的代码我整理到了我的github代码库AI-Toolbox中,代码戳这里

何为focal loss

focal loss 是随网络RetinaNet一起提出的一个令人惊艳的损失函数 paper 下载,主要针对的是解决正负样本比例严重偏斜所产生的模型难以训练的问题。

这里假设你对focal loss有所了解,简单回顾下公式 ,focal loss的定义如下:
focal loss
其中
pt
公式中γ{\gamma}γα{\alpha}α是两个可以调节的超参数。

γ{\gamma}γ的含义更好理解一些,其作用是削弱那些模型已经能够较好预测的样本产生损失的权重,使模型更专注于学习那些较难的hard case。

αt{\alpha}_tαt的定义,原文中的表述是:

For notational convenience, we define αt analogously to how we defined pt

也就是说,αt{\alpha}_tαt的定义可以同理于ptp_tpt的定义。它的作用是平衡类别之间的权重。

这里补充一句,网上能够找到的各种不同版本的focal loss实现,分歧基本都出现在这里。由于focal loss最初是伴随着目标检测中判断某个区域是物体or背景(二分类问题)出现的,当我们使用focal loss来解决更一般化的问题时(比如多分类问题、多标签预测问题),αt{\alpha}_tαt 如何定义便会产生分歧,很难说哪种是绝对正统的,因为不同的定义赋予了损失函数不同的功能,可以针对不同的问题。

让我们来看看,我总结的三种实现版本。

focal loss for binary classification

针对二分类版本的 focal loss 实现

def binary_focal_loss(gamma=2, alpha=0.25):"""Binary form of focal loss.适用于二分类问题的focal lossfocal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.References:https://arxiv.org/pdf/1708.02002.pdfUsage:model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)"""alpha = tf.constant(alpha, dtype=tf.float32)gamma = tf.constant(gamma, dtype=tf.float32)def binary_focal_loss_fixed(y_true, y_pred):"""y_true shape need be (None,1)y_pred need be compute after sigmoid"""y_true = tf.cast(y_true, tf.float32)alpha_t = y_true*alpha + (K.ones_like(y_true)-y_true)*(1-alpha)p_t = y_true*y_pred + (K.ones_like(y_true)-y_true)*(K.ones_like(y_true)-y_pred) + K.epsilon()focal_loss = - alpha_t * K.pow((K.ones_like(y_true)-p_t),gamma) * K.log(p_t)return K.mean(focal_loss)return binary_focal_loss_fixed

在使用本损失函数前,假设你已经将每个样本使用sigmoid映射成了一个0-1之间的数,代表二分类的概率。

在keras中使用此函数作为损失函数,只需在编译模型时指定损失函数为focal loss:

model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=optimizer)

focal loss for multi category 版本1

针对多分类问题或多标签问题的 focal loss 实现1.

前面已经提到网上不同的实现版本中αt{\alpha}_tαt的定义存在一定的分歧

当我们使用αt{\alpha}_tαt来控制不同类别 / 标签 的权重时,实现代码如下:

def multi_category_focal_loss1(alpha, gamma=2.0):"""focal loss for multi category of multi label problem适用于多分类或多标签问题的focal lossalpha用于指定不同类别/标签的权重,数组大小需要与类别个数一致当你的数据集不同类别/标签之间存在偏斜,可以尝试适用本函数作为lossUsage:model.compile(loss=[multi_category_focal_loss1(alpha=[1,2,3,2], gamma=2)], metrics=["accuracy"], optimizer=adam)"""epsilon = 1.e-7alpha = tf.constant(alpha, dtype=tf.float32)#alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)#alpha = tf.constant_initializer(alpha)gamma = float(gamma)def multi_category_focal_loss1_fixed(y_true, y_pred):y_true = tf.cast(y_true, tf.float32)y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)ce = -tf.log(y_t)weight = tf.pow(tf.subtract(1., y_t), gamma)fl = tf.matmul(tf.multiply(weight, ce), alpha)loss = tf.reduce_mean(fl)return lossreturn multi_category_focal_loss1_fixed

注意,你需要将α{\alpha}α指定为一个数组,数组大小需要与类别个数一致,代表着每一个类别对应的权重。

当你的数据集不同类别/标签之间存在偏斜,可以尝试适用本函数作为loss。

我们将核心函数copy出来做一个简单的测试,来验证α{\alpha}α平衡类别间权重的有效性。

import os
from keras import backend as K
import tensorflow as tf
import numpy as npos.environ["CUDA_VISIBLE_DEVICES"] = '0'def multi_category_focal_loss1(y_true, y_pred):epsilon = 1.e-7gamma = 2.0#alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)y_true = tf.cast(y_true, tf.float32)y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)ce = -tf.log(y_t)weight = tf.pow(tf.subtract(1., y_t), gamma)fl = tf.matmul(tf.multiply(weight, ce), alpha)loss = tf.reduce_mean(fl)return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0]])
Y_pred = np.array([[0.3, 0.99, 0.8, 0.97, 0.85], [0.9, 0.05, 0.1, 0.09, 0]], dtype=np.float32)
print(K.eval(multi_category_focal_loss1(Y_true, Y_pred)))

假设我们正在处理一个5个输出的多label预测问题,按照上面的示例,假设我们的模型对于第一个label相比于其它标签的预测很糟糕(这可能是由于第一个label出现的概率很小,在算损失时没有话语权导致的)。

上面代码的运算结果是1.2347984

我们使用α{\alpha}α来调节第一个label的权重,尝试将α{\alpha}α修改为:

alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)

重新运行,损失增大为2.4623184,说明损失函数成功的放大了第一个类别的权重,会使模型更重视第一个label的正确预测。

focal loss for multi category 版本2

针对多分类问题或多标签问题的 focal loss 实现2.

当我们使用 αt{\alpha}_tαt 来控制真值y_true为 1 or 0 时的权重时

即 y = 1 时的权重为 α{\alpha}α, y = 0时的权重为 1−α1-{\alpha}1α

实现代码如下:

def multi_category_focal_loss2(gamma=2., alpha=.25):"""focal loss for multi category of multi label problem适用于多分类或多标签问题的focal lossalpha控制真值y_true为1/0时的权重1的权重为alpha, 0的权重为1-alpha当你的模型欠拟合,学习存在困难时,可以尝试适用本函数作为loss当模型过于激进(无论何时总是倾向于预测出1),尝试将alpha调小当模型过于惰性(无论何时总是倾向于预测出0,或是某一个固定的常数,说明没有学到有效特征)尝试将alpha调大,鼓励模型进行预测出1。Usage:model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)], metrics=["accuracy"], optimizer=adam)"""epsilon = 1.e-7gamma = float(gamma)alpha = tf.constant(alpha, dtype=tf.float32)def multi_category_focal_loss2_fixed(y_true, y_pred):y_true = tf.cast(y_true, tf.float32)y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)ce = -tf.log(y_t)weight = tf.pow(tf.subtract(1., y_t), gamma)fl = tf.multiply(tf.multiply(weight, ce), alpha_t)loss = tf.reduce_mean(fl)return lossreturn multi_category_focal_loss2_fixed

注意,你需要将α{\alpha}α指定为一个数组,数组大小需要与类别个数一致,代表着每一个类别对应的权重。

当你的模型欠拟合,学习存在困难时,可以尝试适用本函数作为loss

当模型过于激进(无论何时总是倾向于预测出1),尝试将alpha调小

当模型过于“懒惰”时(无论何时总是倾向于预测出0,或是某一个固定的常数,说明没有学到有效特征),尝试将alpha调大,鼓励模型预测出1。

同样地,我们将核心函数copy出来做一个简单的测试,来验证α{\alpha}α平衡0-1权重的有效性。

import os
from keras import backend as K
import tensorflow as tf
import numpy as npos.environ["CUDA_VISIBLE_DEVICES"] = '0'def multi_category_focal_loss2_fixed(y_true, y_pred):epsilon = 1.e-7gamma=2.alpha = tf.constant(0.5, dtype=tf.float32)y_true = tf.cast(y_true, tf.float32)y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)ce = -tf.log(y_t)weight = tf.pow(tf.subtract(1., y_t), gamma)fl = tf.multiply(tf.multiply(weight, ce), alpha_t)loss = tf.reduce_mean(fl)return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])
Y_pred = np.array([[0.9, 0.99, 0.8, 0.97, 0.85], [0.9, 0.95, 0.91, 0.99, 1]], dtype=np.float32)
print(K.eval(multi_category_focal_loss2_fixed(Y_true, Y_pred)))

仍然假设我们正在处理一个5个输出的多label预测问题

按照上面的示例,假设这次我们遇到的问题是,所有的标签都会有很高的概率出现1,这时我们的模型发现了一个投机取巧的办法,将每个结果都预测为1,即可得到很小的loss,于是模型严重的欠拟合。

上面代码的运算结果是0.093982555,如我们所料,损失并不大,这显然会影响模型成功收敛。

我们使用α{\alpha}α来抑制模型输出1的权重,尝试将α{\alpha}α修改为:

alpha = tf.constant(0.25, dtype=tf.float32)

重新运行,损失增大为0.14024596,说明损失函数成功的放大了这种投机行为的损失。

参考文献

focal loss paper
Keras自定义Loss函数
Keras中自定义复杂的loss函数
github: focal-loss-keras 实现1
github: focal-loss-keras 实现2
kaggle kernel: FocalLoss for Keras
Focal Loss理解
应用:Multi-class classification with focal loss for imbalanced datasets

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/499484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于ARM Cortex-M和Eclipse的SWO单总线输出

最近在MCU on Eclipse网站上看到Erich Styger所写的一篇有关通过SWD的跟踪接口SWO获取ARM Cortex-M相关信息的文章,文章结构明晰,讲解透彻,本人深受启发,特意将其翻译过来供各位同仁参考。当然限于个人水平,有不当之处…

包管理工具conda极简教程

包管理工具conda极简教程 conda的作用 Anaconda是目前非常流行的一个python包管理器,自带很多流行的python库,包括numpy,pandas等,当然还有conda。而Conda是一个开源的软件包管理系统和环境管理系统,用于安装多个版本…

PID控制器开发笔记之九:基于前馈补偿的PID控制器的实现

对于一般的时滞系统来说,设定值的变动会产生较大的滞后才能反映在被控变量上,从而产生合理的调节。而前馈控制系统是根据扰动或给定值的变化按补偿原理来工作的控制系统,其特点是当扰动产生后,被控变量还未变化以前,根…

借助百度识图爬取数据集

背景 一个能够实际应用的深度学习模型,背后的数据集往往都花费了大量的人力财力,通过聘用标注团队对真实场景数据进行标注生产出来,大多数情况不太可能使用网络来源的图片。但在项目初期的demo阶段,或者某些特定的场合下&#xf…

通过printf从目标板到调试器的输出

最近在SEGGER的博客上看到Johannes Lask写的一篇关于在调试时使用printf函数从目标MCU输出信息到调试器的文章,自我感觉很有启发,特此翻译此文并推荐给各位同仁。当然限于个人水平,有不当之处恳请指正。原文网址:https://blog.seg…

小心使用tf.image.resize_images,填坑经验分享给你

上上周,我在一个项目上线前对模型进行测试时出现了问题,这个问题困扰了我近两周,终于找到了问题根源,做个简短总结分享给你,希望对大家有帮助。 问题描述: 线上线下测试结果不一致,且差异很大…

PID控制器开发笔记之十:步进式PID控制器的实现

对于一般的PID控制系统来说,当设定值发生较大的突变时,很容易产生超调而使系统不稳定。为了解决这种阶跃变化造成的不利影响,人们发明了步进式PID控制算法。 1、步进式PID的基本思想 所谓步进式PID算法,实际就是在设定值发生阶跃…

AutoML 与 Bayesian Optimization 概述

1. AutoML 概述 AutoML是指对于一个超参数优化任务(比如规定计算资源内,调整网络结构找到准确率最高的网络),尽量减少人为干预,使用某种学习机制,来调节这些超参数,使得目标问题达到最优。 这…

使用Eclipse进行Makefile项目

最近在MCU on Eclipse网站上看到Erich Styger所写的一篇有关在Eclipse中使用Makefile创建项目的文章,文章讲解清晰明了非常不错,所以呢没人将其翻译过来供各位同仁参考。当然限于个人水平,有不当之处恳请指正。原文网址:https://m…

Git commit 常用表情快速查询

git commit 的时候,添加表情符号可以更好的表明本次提交的性质,也更有趣。 常用表情符号如下: emoji emoji代码 commit说明 🎨 (调色板) :art: 改进代码结构/代码格式 ⚡️ (闪电) :zap: 提升性能 🐎 (赛马)…

C语言学习及应用笔记之一:C运算符优先级及使用问题

C语言中的运算符绝对是C语言学习和使用的一个难点,因为在2011版的标准中,C语言的运算符的数量超过40个,甚至比关键字的数量还要多。这些运算符有单目运算符、双目运算符以及三目运算符,又涉及到左结合和右结合的问题,真…

Docker用法整理

Docker教程推荐 两个不错的参考资料&#xff1a; https://yeasy.gitbooks.io/docker_practice/content/introduction/ https://www.cnblogs.com/bethal/p/5942369.html 镜像&#xff1a; 查看镜像 docker images ls 删除镜像 docker image rm <image id> 拉取镜像 …

使用FreeRTOS进行性能和运行时分析

在MCU on Eclipse网站上看到Erich Styger在2月25日发的博文&#xff0c;一篇关于使用FreeRTOS进行性能和运行分析的文章&#xff0c;本人觉得很有启发&#xff0c;特将其翻译过来以备参考。当然限于个人水平&#xff0c;有描述不当之处恳请指正。原文网址&#xff1a;https://m…

生成微信公众号对应二维码的两种简单方法

方法1 在浏览器中打开下面的链接 https://open.weixin.qq.com/qr/code?usernameName 其中Name替换为对应公众号的微信号 例如&#xff0c;我们打算生成公众号 AI算法联盟 的二维码 只需首先关注这个公众号 在其详细信息中&#xff0c;查找到微信号信息&#xff1a;AIReport…

在Amazon FreeRTOS V10中使用运行时统计信息

在MCU on Eclipse网站上看到Erich Styger在8月2日发的博文&#xff0c;一篇关于在Amazon FreeRTOS V10中使用运行时统计信息的文章&#xff0c;本人觉得很有启发&#xff0c;特将其翻译过来以备参考。原文网址&#xff1a;https://mcuoneclipse.com/2018/08/02/tutorial-using-…

github无法加载图片的解决办法

最近发现我的github上面项目README里面的图片全裂了&#xff0c;一直以为是github最近服务器不稳定。今天通过简单的查询&#xff0c;发现原来这个问题可以解决&#xff0c;但是不能永久有效&#xff0c;之后还会用到&#xff0c;因此记录在这里&#xff0c; 也分享给大家。 解…

leetcode 1.两数之和

题目 链接&#xff1a;https://leetcode-cn.com/problems/two-sum 给定一个整数数组 nums 和一个目标值 target&#xff0c;请你在该数组中找出和为目标值的那 两个 整数&#xff0c;并返回他们的数组下标。 你可以假设每种输入只会对应一个答案。但是&#xff0c;你不能重复…

C语言学习及应用笔记之二:C语言static关键字及其使用

C语言有很多关键字&#xff0c;大多关键字使用起来是很明确的&#xff0c;但有一些关键字却要相对复杂一些。我们这里要说明的static关键字就是如此&#xff0c;它的功能很强大&#xff0c;相应的使用也就更复杂。 一般来说static关键字的常见用法有三种&#xff1a;一是用作局…

μCUnit,微控制器的单元测试框架

在MCU on Eclipse网站上看到Erich Styger在8月26日发布的博文&#xff0c;一篇关于微控制器单元测试的文章&#xff0c;有很高的参考价值&#xff0c;特将其翻译过来以备学习。原文网址&#xff1a;https://mcuoneclipse.com/2018/08/26/tutorial-%CE%BCcunit-a-unit-test-fram…

leetcode No.15-16 三数之和相关问题

leetcode 15. 三数之和 题目 链接&#xff1a;https://leetcode-cn.com/problems/3sum 给定一个包含 n 个整数的数组 nums&#xff0c;判断 nums 中是否存在三个元素 a&#xff0c;b&#xff0c;c &#xff0c;使得 a b c 0 &#xff1f;找出所有满足条件且不重复的三元组…