论文辅助笔记:LLM-Mob metric测量

0 导入库

import os
import pandas as pd
from sklearn.metrics import f1_score
import ast
import numpy as np

1 基本的metric计算方式

1.1 get_acc1_f1

def get_acc1_f1(df):#计算top1 prediction的准确度和f1 scoreacc1 = (df['prediction'] == df['ground_truth']).sum() / len(df)f1 = f1_score(df['ground_truth'], df['prediction'], average='weighted')#根据支持度(每个标签的真实实例数)加权平均return acc1, f1

1.2 get_is_correct

def get_is_correct(row):#计算ground truth是否在top k prediction内pred_list = row['prediction']if row['ground_truth'] in pred_list:row['is_correct'] = Trueelse:row['is_correct'] = Falsereturn row

1.3 get_is_correct10

def get_is_correct10(row):#计算ground truth是否在top 10,top 5, top 1 prediction内pred_list = row['top10']if row['ground_truth'] in pred_list:row['is_correct10'] = Trueelse:row['is_correct10'] = Falsepred_list = row['top5']if row['ground_truth'] in pred_list:row['is_correct5'] = Trueelse:row['is_correct5'] = Falsepred = row['top1']if pred == row['ground_truth']:row['is_correct1'] = Trueelse:row['is_correct1'] = Falsereturn row

1.4 first_nonzero

def first_nonzero(arr, axis, invalid_val=-1):mask = arr!=0return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)
#在给定轴上找到数组中第一个非零元素的索引。如果没有非零元素,则返回一个无效值

1.5 get_ndcg

#计算归一化折扣累积增益(NDCG),这是评估排名质量的一种方式,尤其用于推荐系统和信息检索
def get_ndcg(prediction, targets, k=10):"""Calculates the NDCG score for the given predictions and targets.Args:prediction (Nxk): list of lists. the softmax output of the model.targets (N): torch.LongTensor. actual target place id.Returns:the sum ndcg score"""for _, xi in enumerate(prediction):#首先遍历prediction列表中的每个子列表xiif len(xi) < k:xi += [-5 for _ in range(k-len(xi))]#如果xi的长度小于k,则将其通过添加特定值(-5)扩展到k的长度elif len(xi) > k:xi = xi[:k]#如果长度大于k,则截断至kelse:pass'''确保每个预测列表都有k个元素,方便后续操作'''n_sample = len(prediction)prediction = np.array(prediction)targets = np.broadcast_to(targets.reshape(-1, 1), prediction.shape)#targets被重塑并广播到与prediction相同的形状,以便可以逐元素比较hits = first_nonzero(prediction == targets, axis=1, invalid_val=-1)#调用first_nonzero函数,该函数返回prediction中与targets相等的元素的第一个索引位置#如果没有匹配的元素,则返回一个预先设定的无效值(-1)hits = hits[hits>=0]ranks = hits + 1#计算每个有效命中的排名(即索引位置加1,因为索引是从0开始的ndcg = 1 / np.log2(ranks + 1)#计算每个排名的折扣增益,使用公式1 / np.log2(ranks + 1)return np.sum(ndcg) / n_sample#计算所有样本的平均NDCG分数

2 Top10预测指标衡量

2.1 文件列表获取

output_dir = 'results/geolife/top10_wot'
file_list = [file for file in os.listdir(output_dir) if file.endswith('.csv')]
file_list

file_path_list = [os.path.join(output_dir, file) for file in file_list]
file_path_list

iter_df = pd.read_csv(file_path_list[0])
iter_df 

 

2.2. 创建结果dataframe

df = pd.DataFrame({'user_id': None,'ground_truth': None,'prediction': None,'reason': None
}, index=[])
df

for file_path in file_path_list:iter_df = pd.read_csv(file_path)if output_dir[-1] != '1':pred_series = iter_df['prediction'].apply(lambda x: ast.literal_eval(x))  # A pandas seriesiter_df['top10'] = pred_series.apply(lambda x: x[:10] if type(x) == list else [x] * 10)iter_df['top5'] = pred_series.apply(lambda x: x[:5] if type(x) == list else [x] * 5)iter_df['top1'] = pred_series.apply(lambda x: x[0] if type(x) == list else x)#如果预测的结果是列表类型(也就是预测top k),那么保存前k个元素的list#如果预测的结果是int类型(预测最有可能的location),那么复制这个元素k次df = pd.concat([df, iter_df], ignore_index=True)
df

2.3 调用get_is_correct10

df = df.apply(func=get_is_correct10, axis=1)
df

2.4 结果计算

acc1 = (df['is_correct1']).sum() / len(df)
acc5 = (df['is_correct5']).sum() / len(df)
acc10 = (df['is_correct10']).sum() / len(df)
f1 = f1_score(df['ground_truth'], df['top1'], average='weighted')
preds = df['top10'].tolist()
targets = np.array(df['ground_truth'].tolist())
ndcg = get_ndcg(prediction=preds, targets=targets, k=10)print("Acc@1: ", acc1)
print("Acc@5: ", acc5)
print("Acc@10: ", acc10)
print("Weighted F1: ", f1)
print("NDCG@10: ", ndcg)
'''
Acc@1:  0.3295750216825672
Acc@5:  0.8291413703382481
Acc@10:  0.8736629083550159
Weighted F1:  0.21629743615527502
NDCG@10:  0.6276420364672752
'''

3 Top1

3.1 读取文件+创建df

output_dir = 'results/geolife/top1'
file_list = [file for file in os.listdir(output_dir) if file.endswith('.csv')]file_path_list = [os.path.join(output_dir, file) for file in file_list]df = pd.DataFrame({'user_id': None,'ground_truth': None,'prediction': None,'reason': None
}, index=[])pd.read_csv(file_path_list[0])

 

3.2  读取prediction 结果


for file_path in file_path_list:iter_df = pd.read_csv(file_path)df = pd.concat([df, iter_df], ignore_index=True)df['prediction'] = df['prediction'].apply(lambda x: int(x))
df['ground_truth'] = df['ground_truth'].apply(lambda x: int(x))
df

3.3 计算metric

acc1, f1 = get_acc1_f1(df)
print("Acc@1: ", acc1)
print("F1: ", f1)
'''
Acc@1:  0.4512864989881469
F1:  0.403742729579556
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/828684.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源数据集分享———猫脸码客

猫脸码客作为一个专注于开源数据集分享的公众号&#xff0c;致力于为广大用户提供丰富、优质的数据资源。我们精心筛选和整理各类开源数据集&#xff0c;涵盖机器学习、深度学习、自然语言处理等多个领域&#xff0c;以满足不同用户的需求。 (https://img-blog.csdnimg.cn/d98…

Exploiting CXL-based Memory for Distributed Deep Learning——论文泛读

ICPP 2022 Paper CXL论文阅读笔记整理 问题 深度学习&#xff08;DL&#xff09;正被广泛用于解决不同领域的科学应用中的复杂问题。DL应用程序使用大规模高性能计算&#xff08;HPC&#xff09;系统来训练给定的模型&#xff0c;需要消耗大量数据。这些工作负载具有很大的内…

Git for Windows 下载与安装

当前环境&#xff1a;Windows 8.1 x64 1 打开网站 https://git-scm.com/ &#xff0c;点击 Downloads 。 2 点击 Windows 。 3 选择合适的版本&#xff0c;这里选择了 32-bit Git for Windows Portable。 4 解压下载后的 PortableGit-2.44.0-32-bit.7z.exe &#xff0c;并将 P…

使用 Flask 和 WTForms 构建一个用户注册表单

在这篇技术博客中&#xff0c;我们将使用 Flask 和 WTForms 库来构建一个用户注册表单。我们将创建一个简单的 Flask 应用&#xff0c;并使用 WTForms 定义一个注册表单&#xff0c;包括用户名、密码、确认密码、邮箱、性别、城市和爱好等字段。我们还将为表单添加验证规则&…

好用的在线客服系统PHP源码(开源代码+终身使用+安装教程) 制作第一步

创建一个在线客服系统是一个涉及多个步骤的过程&#xff0c;包括前端界面设计、后端逻辑处理、数据库设计、用户认证、实时通信等多个方面。以下是使用PHP制作在线客服系统的第一步&#xff1a;需求分析和系统设计。演示&#xff1a;ym.fzapp.top 第一步&#xff1a;需求分析 确…

分布式技术在文本摘要生成中的应用

摘要 自然语言处理首先要应对的是如何表示文本以供机器处理&#xff0c;随着网络技术的发展和信息的公开&#xff0c;因特网上可供访问的数字文档成爆炸式的增长&#xff0c;文本摘要生成逐渐成为了自然语言处理领域的重要研究课题。本文主要介绍了分布式技术在文本摘要生成中…

基于springboot+vue+Mysql的广场舞团管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

猫头虎分享已解决Bug || TypeError: Cannot read property ‘map‘ of undefined**

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

智慧养猪场视频AI智能监控与可视化管理方案

在科技日新月异的今天&#xff0c;智能化、自动化已成为众多行业追求的方向。养猪业作为传统农业的重要组成部分&#xff0c;同样迎来了技术革新的春风。特别是随着人们对食品安全等问题的日益关注&#xff0c;养猪场视频监控监管方案的智能化升级显得尤为重要。 养猪场视频智…

Android11适配

1.分区存储 1.1.背景 Android 11 进一步增强了平台功能&#xff0c;为外部存储设备上的应用和用户数据提供了更好的保护。作为这项工作的一部分&#xff0c;平台引入了进一步的改进&#xff0c;以简化向分区存储的转换。 为了让用户更好地控制自己的文件&#xff0c;保护用户…

(C++) share_ptr 之循环引用

文章目录 &#x1f6a9;前言&#x1f6a9;循环引用&#x1f579;️例子1Code&#x1f62d;shared_ptr &#xff08;错误&#xff09;&#x1f602;weak_ptr &#xff08;正确&#xff09;&#x1f62d;unique_ptr &#xff08;错误&#xff09; &#x1f579;️例子2Code &…

Vu3+QuaggaJs实现web页面识别条形码

一、什么是QuaggaJs QuaggaJS是一个基于JavaScript的开源图像识别库&#xff0c;可用于识别条形码。 QuaggaJs的作用主要体现在以下几个方面&#xff1a; 实时图像处理与识别&#xff1a;QuaggaJs是一款基于JavaScript的开源库&#xff0c;它允许在Web浏览器中实现实时的图像…

LORA详解

参考论文&#xff1a; low rank adaption of llm 背景介绍&#xff1a; 自然语言处理的一个重要范式包括对一般领域数据的大规模预训练和对特定任务或领域的适应处理。在自然语言处理中的许多应用依赖于将一个大规模的预训练语言模型适配到多个下游应用上。这种适配通常是通过…

DiT论文精读Scalable Diffusion Models with Transformers CVPR2023

Scalable Diffusion Models with Transformers CVPR2023 Abstract idea 将UNet架构用Transformer代替。并且分析其可扩展性。 并且实验证明通过增加transformer的宽度和深度&#xff0c;有效降低FID 我们最大的DiT-XL/2模型在classconditional ImageNet 512、512和256、256基…

小程序AI智能名片S2B2C商城系统:四大主流商业模式深度解析与实战案例分享

在私域电商迅速崛起的大背景下&#xff0c;小程序AI智能名片S2B2C商城系统以其独特的商业模式和强大的功能&#xff0c;正成为品牌商们争相探索的新领域。在这一系统中&#xff0c;拼团模式、会员电商、社区团购和KOC营销等四种主流模式&#xff0c;为品牌商提供了多样化的营销…

【文章转载】Lance Martin的关于RAG的笔记

转载自微博黄建同学 从头开始学习 RAG&#xff0c;看Lance Martin的这篇笔记就行了&#xff0c;包含了十几篇论文和开源实现&#xff01; —— 这是一组简短的&#xff08;5-10 分钟视频&#xff09;和笔记&#xff0c;解释了我最喜欢的十几篇 RAG 论文。我自己尝试实现每个想…

C# GetField 方法应用实例

目录 关于 C# Type 类 GetField 方法应用 应用举例 心理CT设计题 类设计 DPCT类实现代码 小结 关于 C# Type 类 Type表示类型声明&#xff1a;类类型、接口类型、数组类型、值类型、枚举类型、类型参数、泛型类型定义&#xff0c;以及开放或封闭构造的泛型类型。调用 t…

WPS-EXCEL:快速删除多个线条对象

问题图 我需要将线条快速删除 方法一:使用定位对象功能 使用定位功能&#xff1a;按Ctrl G打开定位对话框。在对话框中&#xff0c;点击“定位条件”。 定位对象&#xff1a;在定位条件对话框中&#xff0c;勾选“对象”选项&#xff0c;然后点击“确定”。这样&#xff0c;…

CTF之变量1

拿到题目发现是一个php代码&#xff0c;意思是用get方式获取args参数。 至于下面那个正则表达式怎么绕过暂且不知&#xff0c;但是题目最上面告诉我们lag In the variable ! &#xff08;意思是flag就在变量中&#xff09;。 那我们就传入全局变量globals&#xff08;&#xf…

户外指南——时代产物

分类 一级分类&#xff1a; 衣&#xff1a;除了上述提到的&#xff0c;还包括衣物的材质、款式多样性、与身份地位的关联等。 食&#xff1a;还包括饮食的文化内涵、地域特色、对特殊饮食需求的满足等。 住&#xff1a;还包括居住空间的合理布局、智能家居的应用、与自然环境…