【transformers.Trainer填坑】在自定义compute_metrics时logits和labels数据维度不一致问题

问题描述

我在使用 transformers.Trainer 训练我的模型时,我自定义了 compute_loss 函数和compute_metrics函数,我的模型是一个简单的二分类模型。

在自定义 compute_loss 时这样写的:

def compute_loss(self, model, inputs, return_outputs=False):"""重写 Trainer.compute_loss:1) 提取字典中的 images, bboxes, locs, labels 等2) 用 vision_encoder 先处理图像,得到特征3) 用下游 model 做预测4) 计算并返回 loss"""# 前向传播outputs, labels = model(**inputs)  # (bz, num_classes), or (bz*num_frames, num_classes)batch_size = inputs['labels'].shape[0]outputs = outputs.squeeze()  # (bz*num_frames)if batch_size == 1:outputs = outputs.unsqueeze(0)# 计算 lossloss = self.loss_func(outputs, labels.float())if self.state.global_step % 10 == 0 and self.state.global_step > 0:# 以50个step为间隔打印pred_probs = torch.sigmoid(outputs)preds = (pred_probs > 0.5).int()logger.info(f"[global_step={self.state.global_step}] preds={preds.tolist()} / labels={labels.tolist()} / loss={loss.item():.4f}")# compute metricaccuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())precision = precision_score(labels.cpu().numpy(), preds.cpu().numpy())recall = recall_score(labels.cpu().numpy(), preds.cpu().numpy())logger.info(f"[global_step={self.state.global_step}] accuracy={accuracy:.4f} / precision={precision:.4f} / recall={recall:.4f}")# 返回 (loss, outputs) 或者只返回 lossreturn (loss, outputs) if return_outputs else loss

于是就出现了报错,像这样的:

File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3754, in predictoutput = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)output = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)

原因

该问题是 transformers.Trainer 内部有一段对outputs的操作造成的:

if isinstance(outputs, dict):logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:logits = outputs[1:]

这里当 outputs 不是字典时,会把第一个位置的元素offset掉。

解决

Refer to here
所以,我们应该在返回那里这样写:

return (loss, {"label": outputs}) if return_outputs else loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/69919.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文学习记录之《CLR-VMB》

目录 一、基本介绍 二、介绍 三、方法 3.1 FWI中的数据驱动方法 3.2 CLR-VMB理论 3.3 注意力块 四、网络结构 4.1 网络架构 4.2 损失函数 五、实验 5.1 数据准备 5.2 实验设置 5.3 训练和测试 5.4 定量分析 5.5 CLR方案的有效性 5.6 鲁棒性 5.7 泛化性 六、讨…

【STM32】舵机SG90

1.舵机原理 舵机内部有一个电位器,当转轴随电机旋转,电位器的电压会发生改变,电压会带动转一定的角度,舵机中的控制板就会电位器输出的电压所代表的角度,与输入的PWM所代表的角度进行比较,从而得出一个旋转…

算法刷题-链表系列-移除链表、设计链表、翻转列表

题目要求 所有主要考察对链表的增删查改的功能 总结 对于有些从头遍历到尾的方法,创建一个头结点使得所有的结点能以统一的方式且全部被遍历到,不会出现头结点不被遍历的问题。对于遍历的条件,有的时候curNode ! nullptr,有的时…

Django项目中创建app并快速上手(pycharm Windows)

1.打开终端 我选择的是第二个 2.运行命令 python manage.py startapp 名称 例如: python manage.py startapp app01 回车,等待一下,出现app01的文件夹说明创建成功 3.快速上手 1.app注册 增加一行 "app01.apps.App01Config"&#…

Windows系统安装搭建悟空crm客户管理系统 教程

1、在安装悟空 CRM 之前,需要确保你的 Windows 系统上已经安装了以下软件: Web 服务器:推荐使用 Apache 或 Nginx,这里以 Nginx 为例。你可以使用集成环境套件如 XAMPP 来简化安装过程,它包含了 Nginx 、MySQL、PHP 等…

深度学习框架探秘|TensorFlow vs PyTorch:AI 框架的巅峰对决

在深度学习框架中,TensorFlow 和 PyTorch 无疑是两大明星框架。前面两篇文章我们分别介绍了 TensorFlow(点击查看) 和 PyTorch(点击查看)。它们引领着 AI 开发的潮流,吸引着无数开发者投身其中。但这两大框…

java每日精进 2.13 Ganache(区块链本地私有化部署)

需求:使用区块链实现数据村存储,记录一些不可篡改的交互信息,网络环境为内外网均需要部署; 1.准备工作(软件安装) 1.1 安装 Node.js 和 npm 1.2 安装 Ganache 地址如下:windows有可视化界面 &a…

w206基于Spring Boot的农商对接系统的设计与实现

🙊作者简介:多年一线开发工作经验,原创团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹赠送计算机毕业设计600个选题excel文…

chrome://version/

浏览器输入: chrome://version/ Google浏览器版本号以及安装路径 Google Chrome131.0.6778.205 (正式版本) (64 位) (cohort: Stable) 修订版本81b36b9535e3e3b610a52df3da48cd81362ec860-refs/branch-heads/6778_155{#8}操作系统Windows…

哈希槽算法与一致性哈希算法比较

Redis 集群模式使用的 哈希槽(Hash Slot) 算法与传统的 一致性哈希(Consistent Hashing) 算法在数据分布和节点管理上有显著的区别。以下是两者的详细比较: 1. Redis 哈希槽算法 1.1 基本原理 Redis 集群将整个数据集…

【BUUCTF逆向题】[WUSTCTF2020]level3(魔改base64)

一.[WUSTCTF2020]level3 打开IDA反汇编,发现就是base64加密 这里rand就是与&搭配设置奇偶数2分随机 但是根据提示不是标准base64加密 首先想到魔改密码表,追踪进去,发现没有什么变化啊 尝试对Base64字符串解码也不对 追踪密码表CtrlX发…

有关Java中的接口

学习目标 掌握接口语法理解接口多态熟练使用接口了解接口新特性掌握final关键字了解lambda语法 1.接口语法 1.1 接口概念 从功能上看, 实现接口就意味着扩展了某些功能 接口与类之间不必满足is-a的关系结构 从抽象上看, 接口是特殊的抽象父类 从规则上看, 接口定义者和实…

鸿蒙(openharmony) 5.0 光感接口崩溃

目录 1.背景 2.解决方案 1.背景 使用OpenHarmony 5.0调用光感接口崩溃,返回的值是undefined,接口如下: sensor.on(sensor.SensorId.AMBIENT_LIGHT, (data) => {if (data == null || data == undefined || data.intensity == null || data.intensity == undefined) {ret…

git用法(简易版)

介绍 git是一个版本管理工具 使用方法 建立仓库 第一步 git init:初始化仓库 第二步 git add .:将代码添加到暂存区 第三步 git commit -m "first":为修改添加备注 第四步 git remote add origin 你的url 第五步 git pus…

【C++八股】内存泄漏

内存泄漏(Memory Leak)是指程序在动态分配内存后,未能及时释放已分配的内存,导致这些内存无法被再次使用,从而造成系统内存的浪费。随着时间的推移,内存泄漏可能导致程序性能下降,甚至系统崩溃。…

sqli-labs时间盲注和布尔盲注

1、时间盲注和布尔盲注 在SQL注入攻击中,时间盲注(Time-Based Blind SQL Injection)和布尔盲注(Boolean-Based Blind SQL Injection)是两种常见的技术,用于在无法直接获取数据的情况下推断数据库信息。 2…

数据库脚本MySQL8转MySQL5

由于生产服务器版本上部署的是MySQL5,而开发手里的脚本代码是MySQL8。所以只能降版本了… 升级版本与降级版本脚本转换逻辑一样 MySQL5与MySQL8版本SQL脚本区别 大多数无需调整、主要是字符集与排序规则 MySQL5与MySQL8版本SQL字符集与排序规则 主要操作&…

Flutter 双屏双引擎通信插件加入 GitCode:解锁双屏开发新潜能

在双屏设备应用场景日益丰富的当下,移动应用开发领域迎来了新的机遇与挑战。如何高效利用双屏设备优势,为用户打造更优质的交互体验,成为开发者们关注的焦点。近日,一款名为 Flutter 双屏双引擎通信插件的创新项目正式入驻 GitCod…

Mysql进阶篇(mysqlcheck - 表维护程序)

mysqlcheck的作用 mysqlcheck客户端用于执行表维护,可以对表进行:分析、检查、优化或修复操作。 (1)分析的作用是查看表的关键字分布,能够让 sql 生成正确的执行计划(支持 InnoDB,MyISAM&#x…

如何使用qt开发一个xml发票浏览器,实现按发票样式显示

使用Qt开发一个按发票样式显示的XML发票浏览器,如下图所示样式: 一、需求: 1、按税务发票样式显示。 2、拖入即可显示。 3、正确解析xml文件。 二、实现 可以按照以下步骤进行: 1. 创建Qt项目 打开Qt Creator,创…