在本系列的开篇,就对整个项目训练所需要的所有模块都进行了一个简要的介绍,尤其是针对训练中需要引入的各个结构,进行一个串联操作。
而在之前的数据构建篇和网络模型篇中,都对其中的每一个组块进行了分别的验证,预先在未开始训练前,检验其中的正确性,避免到训练时候,问题连连。
通过这一系列文章的学习后,我相信绝大部分的模块都已经介绍过了。包括:
- 综述篇中对优化器、模型获取和保存模型进行了介绍;
- 在数据流模块中,学习了如何导入数据,验证数据流;
- 网络模型那里,损失函数
loss
的调用。
本篇其实存在的最大意义,就在于将这些零零散散的东西,拼接成一个整体。至于推理阶段,将单独新开一节,放到后面。通过这个系列的学习,也能多一些思考,加深一些感悟。
一、损失函数
在分割任务中,把目标分割任务的mask
,转化为对像素点的分类任务。所以在计算损失的时候,论文里面的损失函数采用的就是交叉熵损失函数。
在后续的损失改进中,多引入dice loss
或focal loss
。我们就从交叉熵损失函数开始,探讨下它为什么可以应用在分割任务中。
本文继续沿着在网络模型评估阶段,使用的交叉熵损失函数,定义如下。对于其他分割的损失函数,参考这篇文章:【AI面试】CrossEntropy Loss 、Bal