深度学习之用CelebA_Spoof数据集搭建一个活体检测-一些模型训练中的改动带来的改善

实验背景

在前面的深度学习之用CelebA_Spoof数据集搭建一个活体检测-模型搭建和训练,我们基于CelebA_Spoof数据集构建了一个用SqueezeNe框架进行训练的活体2D模型,采用了蒸馏法进行了一些简单的工作。在前面提供的训练参数中,主要用了以下几个参数:

lr=0.001, step_size=10, gamma=0.1, alpha=0.5, T=3.0, epochs=300, log_interval=5, batch_size=128, save_interval=10, eval_interval=5

但是,效果并不是很好,得到的评估指标为:

{'Accuracy': 0.8841447074586869, 'Precision': 0.9865851252712566, 'Recall': 0.8468018456588917, 'F1': 0.9113647235700131, 'FPR': 0.027303754266211604, 'ROC_AUC': 0.9661331014932475, 'EER': 0.076691427424212, 'PR_AUC': 0.9862052738456543, 'AP': 0.9862056134292535}

对于一个好的活体检测模型来说,各项指标都不是很好。对于这个指标,就需要进一步进行全面的分析了,如:预处理、训练的各个参数的玄学调整,模型结构深度,蒸馏中的权重比等等之类。在各种折腾后都得不到比较好的改变,于是想在特征上进行改进,如果人工再加点特征试试,会是怎样?这突发奇想就想到了:傅里叶变换。为什么用它,因为非活体的照片很多都是翻拍的,那么因为相机或者屏幕的闪烁,可能会出现一些条纹或者频域上的特征,这些就有可能很好的区分这两类图片。为了提升模型对伪造攻击的识别能力,我们尝试在训练过程中加入傅里叶变换作为辅助特征。

方法对比

基线模型训练时候的训练过程(无傅里叶变换)

直接是普通的蒸馏训练过程,正常的损失计算。

# 传统RGB图像预处理
def _compute_loss(self, student_out, teacher_out, targets):current_T = max(1.0, self.args.T * (0.95 ** (self.current_epoch/10)))"""计算蒸馏损失"""# KL散度损失kl_loss = nn.KLDivLoss(reduction='batchmean')(torch.log_softmax(student_out/self.args.T, dim=1),torch.softmax(teacher_out/self.args.T, dim=1)) * (current_T ** 2)# 交叉熵损失ce_loss = self.criterion(student_out, targets)total_loss = self.args.alpha * kl_loss + (1 - self.args.alpha) * ce_lossreturn total_lossdef train_epoch(self, train_loader, epoch):try:"""完整训练逻辑"""self.student.train()self.current_epoch = epochtotal_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(self.device), target.to(self.device)if self.gpu == 0 and batch_idx % 200 == 0:current_lr = self.optimizer.param_groups[0]['lr']print(f'当前学习率: {current_lr:.6f}')  # 添加这行打印学习率self.optimizer.zero_grad()# 前向传播student_out = self.student(data)with torch.no_grad():teacher_out = self.teacher(data)        # 计算损失loss = self._compute_loss(student_out, teacher_out, target)# 反向传播loss.backward()self.optimizer.step()# 统计指标total_loss += loss.item()_, predicted = student_out.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 主进程打印日志if self.gpu == 0 and batch_idx % 200 == 0:avg_loss = total_loss / (batch_idx + 1)acc = 100. * correct / totalprint(f'Epoch {epoch} Batch {batch_idx}/{len(train_loader)} 'f'Loss: {avg_loss:.4f} | Acc: {acc:.2f}%')self.scheduler.step()return {'loss':total_loss / len(train_loader),'accuracy': 100.0 * correct / total}except Exception as e:if "NCCL" in str(e):print(f"NCCL错误发生,尝试恢复训练...")torch.distributed.destroy_process_group()torch.distributed.init_process_group(backend='nccl')return {'loss': 0, 'accuracy': 0}else:print(f"训练过程中发生错误: {str(e)}")raise e

改进模型(加入傅里叶变换)

加入的傅里叶变换该怎么加呢,我们只在训练过程中加入,那么得到的特征中具有较好区分性就行,所以不需要将输入图像数据都进行傅里叶变换,这样也防止在后续的推理过程中都需要进行傅里叶变换,增加无畏的动作和减少更多的特征内卷。
训练过程中,正常的输入图像数据,正常的教师学生模型的特征求取,但是同时采用傅里叶变换对图像数据进行预处理,用于后续的损失函数加入。
训练参数为:

lr=0.001, step_size=10, gamma=0.9, alpha=0.5, T=3.0, epochs=300, log_interval=5, batch_size=128, save_interval=10, eval_interval=5
    def _fourier_transform(self, x):x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)))  # 中心化频谱x = torch.abs(x)# 动态调整滤波区域h, w = x.shape[-2:]crow, ccol = h//2, w//2mask = torch.ones_like(x)mask[..., crow-10:crow+10, ccol-10:ccol+10] = 0.3  # 部分保留中心低频return torch.log(1 + 10*x*mask)  # 增强高频特征
def _compute_loss(self, student_out, teacher_out, targets):current_T = max(1.0, self.args.T * (0.95 ** (self.current_epoch/10)))"""计算蒸馏损失"""# KL散度损失kl_loss = nn.KLDivLoss(reduction='batchmean')(torch.log_softmax(student_out/self.args.T, dim=1),torch.softmax(teacher_out/self.args.T, dim=1)) * (current_T ** 2)# 交叉熵损失ce_loss = self.criterion(student_out, targets)total_loss = self.args.alpha * kl_loss + (1 - self.args.alpha) * ce_loss#if self.gpu == 0:  # 仅主进程打印#    print(f"原始损失 - KL: {kl_loss.item():.4f} | CE: {ce_loss.item():.4f}")# 添加频域分支损失if self.use_freq:base_weight = self.freq_weight  # 基础权重dynamic_weight = min(0.25, 0.15 + self.current_epoch*0.001)freq_loss = self.criterion(self.freq_pred, targets) * base_weight * dynamic_weighttotal_loss += freq_loss#if self.gpu == 0:#    print(f"频域分支损失: {freq_loss.item():.4f} (权重: {self.freq_weight})")return total_lossdef train_epoch(self, train_loader, epoch):try:"""完整训练逻辑"""self.student.train()self.current_epoch = epochtotal_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(self.device), target.to(self.device)if self.gpu == 0 and batch_idx % 200 == 0:current_lr = self.optimizer.param_groups[0]['lr']print(f'当前学习率: {current_lr:.6f}')  # 添加这行打印学习率self.optimizer.zero_grad()# 前向传播student_out = self.student(data)with torch.no_grad():teacher_out = self.teacher(data)if self.use_freq:# 频域处理with torch.no_grad():freq_data = self._fourier_transform(data)self.freq_pred = self.freq_branch(freq_data).squeeze()          # 计算损失loss = self._compute_loss(student_out, teacher_out, target)# 反向传播loss.backward()self.optimizer.step()# 统计指标total_loss += loss.item()_, predicted = student_out.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 主进程打印日志if self.gpu == 0 and batch_idx % 200 == 0:avg_loss = total_loss / (batch_idx + 1)acc = 100. * correct / totalprint(f'Epoch {epoch} Batch {batch_idx}/{len(train_loader)} 'f'Loss: {avg_loss:.4f} | Acc: {acc:.2f}%')self.scheduler.step()return {'loss':total_loss / len(train_loader),'accuracy': 100.0 * correct / total}except Exception as e:if "NCCL" in str(e):print(f"NCCL错误发生,尝试恢复训练...")torch.distributed.destroy_process_group()torch.distributed.init_process_group(backend='nccl')return {'loss': 0, 'accuracy': 0}else:print(f"训练过程中发生错误: {str(e)}")raise e

性能指标对比

指标基线模型傅里叶增强模型提升幅度
Accuracy88.41%93.40%+4.99%
Precision98.66%98.52%-0.14%
Recall84.68%91.86%+7.18%
F1 Score91.14%95.08%+3.94%
ROC AUC96.61%97.83%+1.22%
EER7.67%5.84%-1.83%

关键发现

  1. 召回率显著提升:傅里叶变换帮助模型更好地捕捉伪造痕迹,使召回率提高了7.18%

  2. 等错误率降低:EER从7.67%降至5.84%,表明系统整体性能更均衡

  3. 特征互补性:虽然单独看频域特征效果有限,但与空间特征结合产生了协同效应

实现建议

在本次实验,我是保留了中心低频,增强了高频特征,当然也可以不这么干,毕竟有些低频的信息也有用,需要多次验证采取最好的。代码中添加的过滤如下:

        # 动态调整滤波区域h, w = x.shape[-2:]crow, ccol = h//2, w//2mask = torch.ones_like(x)mask[..., crow-10:crow+10, ccol-10:ccol+10] = 0.3  # 部分保留中心低频return torch.log(1 + 10*x*mask)  # 增强高频特征

结论

傅里叶变换的引入使模型在保持高精确度的同时,显著提升了召回能力。这只是在调整模型过程中的一点小改善,当然还有其他更好的方法,SqueezeNe的模型结构还是浅,如果没有更多的限制,可以加深加大,这样效果会更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/906106.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年PMP 学习二十 第13章 项目相关方管理

第13章 项目相关方管理 序号过程过程组过程组1识别相关方启动2规划相关方管理规划3管理相关方参与与执行4监控相关方参与与监控 相关方管理,针对于团队之外的相关方的,核心目标是让对方为了支持项目,以达到项目目标。 文章目录 第13章 项目相…

GO语言语法---For循环、break、continue

文章目录 1. 基本for循环(类似其他语言的while)2. 经典for循环(初始化;条件;后续操作)3. 无限循环4. 使用break和continue5 . 带标签的循环(可用于break/continue指定循环)1、break带标签2、continue带标签…

CSS- 4.4 固定定位(fixed) 咖啡售卖官网实例

本系列可作为前端学习系列的笔记,代码的运行环境是在HBuilder中,小编会将代码复制下来,大家复制下来就可以练习了,方便大家学习。 HTML系列文章 已经收录在前端专栏,有需要的宝宝们可以点击前端专栏查看! 点…

分布式微服务系统架构第132集:Python大模型,fastapi项目-Jeskson文档-微服务分布式系统架构

加群联系作者vx:xiaoda0423 仓库地址:https://webvueblog.github.io/JavaPlusDoc/ https://1024bat.cn/ https://github.com/webVueBlog/fastapi_plus 这个错误是由于 Python 3 中已经将线程的 isAlive() 方法更名为 is_alive(),但你的调试工…

react路由中Suspense的介绍

好的,我们来详细解释一下这个 AppRouter 组件的代码。 这个组件是一个在现代 React 应用中非常常见的模式,特别是在使用 React Router v6 进行路由管理和结合代码分割(Code Splitting)来优化性能时。 JavaScript const AppRout…

C语言内存函数与数据在内存中的存储

一、c语言内存函数 1、memcpy函数是一个标准库函数,用于内存复制。功能上是用来将一块内存中的内容复制到另一块内存中。用户需要提供目标地址、源地址以及要复制的字节数。例如结构体之间的复制。 memcpy函数的原型是:void* memcpy(void* …

层次原理图

层次原理图简介 层次原理图(Hierarchical Schematic)是一种常用于电子工程与系统设计的可视化工具,通过分层结构将复杂系统分解为多个可管理的子模块。它如同“设计蓝图”,以树状结构呈现整体与局部的关系:顶层展现系…

流程编辑器Bpmn与LogicFlow学习

工作流技术如何与用户交互结合(如动态表单、任务分配)处理过 XML 与 JSON 的转换自定义过 bpmn.js 的样式(如修改节点颜色、形状、图标)扩展过上下文菜单(Palette)或属性面板(Properties Panel&…

LWIP的NETCONN接口

NETCONN接口简介 NETCONN API 使用了操作系统的 IPC 机制, 对网络连接进行了抽象,使用同一的接口完成UDP和TCP连接 NETCONN API接口是在RAW接口基础上延申出来的一套API接口 NETCONN实现原理 2.1,NETCONN控制块 2.2,NETCONN收…

Linux搜索

假如我们要搜索 struct sockaddr_in 我们在命令终端输入 cd/usr/include/ //进入头文件目录地址 /usr/include/ grep " struct sockaddr_in { " *-nir (*是在当前目录,n 是找出来显示行数…

2025长三角杯数学建模B题思路模型代码:空气源热泵供暖的温度预测,赛题分析与思路

2025长三角杯数学建模B题思路模型代码,详细内容见文末名片 空气源热泵是一种与中央空调类似的设备,其结构主要由压缩主机、热交换 器以及末端构成,依靠水泵对末端房屋提供热量来实现制热。空气源热泵作为热 惯性负载,调节潜力巨…

ssh免密码登录

创建秘钥和公钥 ssh-keygen -t rsa 输入上述命令后,直接按回车即可,完成后会在上面信息显示,生成的文件路径信息 id_rsa:秘钥 id_rsa.pub: 公钥 将公钥的内容copy到远端 将id_rsa.pub的内容拷贝到~/.ssh下的authori…

基于Bootstrap 的网页html css 登录页制作成品

目录 前言 一、网页制作概述 二、登录页面 2.1 HTML内容 2.2 CSS样式 三、技术说明书 四、页面效果图 前言 ‌Bootstrap‌是一个用于快速开发Web应用程序和网站的前端框架,由Twitter的设计师Mark Otto和Jacob Thornton合作开发。 它基于HTML、CSS和JavaScri…

20倍云台球机是一种高性能的监控设备

20倍云台球机是一种高性能的监控设备,其主要特点包括20倍光学变焦能力和云台旋转功能。以下是对20倍云台球机的详细分析: 一、主要特点 20倍光学变焦 : 摄像机镜头能够在保持图像清晰度的前提下,将监控目标放大20倍。 这一功能…

大型语言模型应用十大安全风险

40多页LLM应用的十大风险 这是一份关于LLM应用的十大风险(2025版),有一定的参考价值。 如果你时间充裕,可以听听播客,详细了解: 如果你只想快速了解10条分别是什么,可以直接看重点摘录&#xff…

一文掌握工业相机选型计算

目录 一、基本概念 1.1 物方和像方 1.2 工作距离和视场 1.3 放大倍率 1.4 相机芯片尺寸 二、公式计算 三、实例应用 一、基本概念 1.1 物方和像方 在光学领域,物方(Object Space)是与像方(Image Space)相对的…

《虚拟即真实:数字人驱动技术在React Native社交中的涅槃》

当React Native与数字人驱动技术相遇,它们将如何携手塑造社交应用中智能客服与虚拟主播的自然交互呢?这正是本文要深入探讨的话题。 React Native是Facebook开源的一个用于构建原生移动应用的框架,它允许开发者使用JavaScript和React编写代码…

使用AI 生成PPT 最佳实践方案对比

文章大纲 一、专业AI生成工具(推荐新手)**1. 推荐工具详解****2. 操作流程优化****3. 优势与局限**二、代码生成方案(开发者推荐)**1. Python-pptx进阶用法****2. GitHub推荐**三、混合工作流(平衡效率与定制)**1. 工具链升级****2. 示例Markdown结构**四、网页转换方案(…

前端-HTML元素

目录 HTML标签是什么? 什么是HTML元素? HTML元素有哪些分类方法? 什么是HTML头部元素 更换路径 注:本文以leetbook为基础 HTML标签是什么? HTML标签是HTML语言中最基本单位和重要组成部分 虽然它不区分大小写&a…

菱形继承原理

在C中,菱形继承的内存模型会因是否使用虚继承产生本质差异。我们通过具体示例说明两种场景的区别: 一、普通菱形继承的内存模型 class A { int a; }; class B : public A { int b; }; class C : public A { int c; }; class D : public B, public C { i…