基于llama-index对embedding模型进行微调

QA对话目前是大语言模型的一大应用场景,在QA对话中,由于大语言模型信息的滞后性以及不包含业务知识的特点,我们经常需要外挂知识库来协助大模型解决一些问题。在外挂知识库的过程中,embedding模型的召回效果直接影响到大模型的回答效果,因此,在许多场景下,我们都需要微调我们的embedding模型来提高我们的召回效果。下面,我们就基于llama-index对BAAI/bge-base-zh-v1.5模型进行微调,关于该模型的介绍,可以参考https://huggingface.co/BAAI/bge-base-zh-v1.5。

平台介绍

对embedding模型进行微调的过程中需要使用GPU加速训练,由于家境贫寒,我这里就使用了Google colab提供的免费T4GPU进行微调测试。如果大家没办法使用这个,可以使用国内一些公司的GPU云平台,租便宜的GPU就行,微调这个模型所耗费的GPU资源不多。以下所有训练代码皆是在Jupter-notebook上编写并执行的。

依赖安装

安装一些依赖库,有些依赖需要制定版本,否则存在不兼容的问题。

!pip install langchain==0.0.300 llmx==0.0.15a0 openai==0.28.1 llama_index==0.8.23.post1 pypdf sentence-transformers

训练样本准备

我们当前的使用场景是QA问答场景,因此训练数据的格式最好也是问答的格式。我这里由于没有现成的问答样本(人工整理比较耗时),因此我就摘取了《明朝那些事儿》这个小说里面的部分章节,然后让GPT-3.5针对文章内容进行提问,从而形成问答对。代码如下

import json
import openai
import osfrom llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode
from llama_index import (VectorStoreIndex,SimpleDirectoryReader,ServiceContext,Response
)def load_corpus(docs, for_training=False, verbose=False):parser = SimpleNodeParser.from_defaults()if for_training:nodes = parser.get_nodes_from_documents(docs[:5], show_progress=verbose)else:nodes = parser.get_nodes_from_documents(docs[6:], show_progress=verbose)if verbose:print(f'Parsed {len(nodes)} nodes')return nodesSEC_FILE = ['embedding_test.txt'] # embedding_test.txt是我训练样本的文件名,即我摘取了部分小说章节并直接保存为了txt文件。print(f"Loading files {SEC_FILE}")reader = SimpleDirectoryReader(input_files=SEC_FILE)
docs = reader.load_data()
print(f'Loaded {len(docs)} docs')docs_nodes = load_corpus(docs, for_training=True, verbose=True)len(docs_nodes)train_nodes = docs_nodes[:75]  # 人工选择3分之2作为训练集
print(f'Loaded {len(train_nodes)} train docs')
val_nodes = docs_nodes[76:] # 剩下三分之一作为验证集
print(f'Loaded {len(val_nodes)} val docs')

构造训练集和测试集

使用GPT3.5基于小说内容生成对应的问题,最后生成train_dataset.json作为训练集,val_dataset.json作为验证集。

from llama_index.finetuning import (generate_qa_embedding_pairs,EmbeddingQAFinetuneDataset,
)
from llama_index.llms import OpenAIos.environ["OPENAI_API_KEY"] = "sk-************"
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_base = "https://************"prompt="""下方是上下文信息。---------------------
{context_str}
---------------------根据提供的上下文信息和没有先验知识的原则,仅基于以下查询生成问题。你是一名教师/教授。你的任务是为即将到来的测验/考试设置{num_questions_per_chunk}个问题。这些问题应在文档中多样化,且仅限于所提供的上下文信息。
"""train_dataset = generate_qa_embedding_pairs(train_nodes, qa_generate_prompt_tmpl=prompt)
val_dataset = generate_qa_embedding_pairs(val_nodes, qa_generate_prompt_tmpl=prompt)train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

微调Embedding模型

这里的微调都是使用的默认参数,在实际微调过程中,可根据实际情况进行调整。

from llama_index.finetuning import SentenceTransformersFinetuneEngine
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
finetune_engine = SentenceTransformersFinetuneEngine(train_dataset,model_id="BAAI/bge-base-zh-v1.5",model_output_path="test_model",val_dataset=val_dataset,
)
finetune_engine.finetune() #由于模型较小,且训练样本较少,微调过程非常快
embed_model = finetune_engine.get_finetuned_model()
embed_model

评估微调后的模型

在评估阶段,我们对比了微调前、后的BAAI/bge-base-zh-v1.5模型以及OPENAI的ada002的Embedding模型的召回效果,代码如下:

from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd
def evaluate(dataset,embed_model,top_k=5,verbose=False,
):corpus = dataset.corpusqueries = dataset.queriesrelevant_docs = dataset.relevant_docsservice_context = ServiceContext.from_defaults(embed_model=embed_model)nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]index = VectorStoreIndex(nodes, service_context=service_context, show_progress=True)retriever = index.as_retriever(similarity_top_k=top_k)eval_results = []for query_id, query in tqdm(queries.items()):retrieved_nodes = retriever.retrieve(query)retrieved_ids = [node.node.node_id for node in retrieved_nodes]expected_id = relevant_docs[query_id][0]is_hit = expected_id in retrieved_ids  # assume 1 relevant doceval_result = {"is_hit": is_hit,"retrieved": retrieved_ids,"expected": expected_id,"query": query_id,}eval_results.append(eval_result)return eval_results

注意,在执行下面的代码前,需要先在当前项目的目录下创建results文件夹,否则会导致程序执行失败。

from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformerdef evaluate_st(dataset,model_id,name,
):corpus = dataset.corpusqueries = dataset.queriesrelevant_docs = dataset.relevant_docsevaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name=name)model = SentenceTransformer(model_id)return evaluator(model, output_path="results/")

OPENAI-ada002

ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)
df_ada = pd.DataFrame(ada_val_results)
hit_rate_ada = df_ada['is_hit'].mean()
hit_rate_ada

ada002模型的最终评测结果为0.9285714285714286

原始BAAI/bge-base-zh-v1.5

bge = "local:BAAI/bge-base-zh-v1.5"
bge_val_results = evaluate(val_dataset, bge)
df_bge = pd.DataFrame(bge_val_results)
hit_rate_bge = df_bge['is_hit'].mean()
hit_rate_bge

原始的bge-base-zh-v1.5模型的评测结果为0.7663744588744589

微调后的BAAI/bge-base-zh-v1.5

finetuned = "local:test_model"
val_results_finetuned = evaluate(val_dataset, finetuned)
df_finetuned = pd.DataFrame(val_results_finetuned)
hit_rate_finetuned = df_finetuned['is_hit'].mean()
hit_rate_finetuned

微调后模型的最终评测结果为0.975。即微调后,我们的embedding模型在当前数据集的召回效果由0.766上升到0.975注意,得分并不是越高越好,需考虑是否过拟合,可以在其他数据集上再评测下。

以上,即是一次简单的微调过程。感谢技术的发展和开源大佬们的贡献,使得人工智能的应用门槛越来越低。

参考资料

  1. https://colab.research.google.com/github/wenqiglantz/nvidia-sec-finetuning/blob/main/embedding-finetuning/finetune_embedding_nvidia_sec.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/312036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[蓝桥杯2019初赛]最大降雨量-模拟

题目描述 由于沙之国长年干旱,法师小明准备施展自己的一个神秘法术来求雨。 这个法术需要用到他手中的49 张法术符,上面分别写着1 至49 这49 个数字。 法术一共持续7 周,每天小明都要使用一张法术符,法术符不能重复使用。 每周&a…

用户登录的电话号码和密码进行测试PythonGUI实验

用户登录的电话号码和密码进行测试PythonGUI实验: 1.要求:对用户登录的电话号码和密码进行测试 2.电话号码:分为首字母不为0,长度必须为11为,类型全部为数字 3.密码:分为长度为6-12位,类型为AS…

武汉坚守第十二日——爆发期的困守待破

已经到了第十二天,不能出门的日子,浑身都是难受的,尤其对于我这样一个日常一天要步行10公里每年4000公里步行的人来说,坚持了四年的习惯变成这样,真心不舒服,于是在室内开始了一些自救——自救运动&#xf…

[蓝桥杯2019初赛]完全二叉树的权值-完全二叉树的性质

注意: j < n不能少!!! 代码如下&#xff1a; #include <iostream> using namespace std; const int N 100010; typedef long long LL; int a[N];int main() {int n;LL maxv -1e18;cin >> n;int depth 0;for (int i 1; i < n; i)cin >> a[i];for …

机器学习前的热身(必备)

机器学习前的热身&#xff08;必备&#xff09; 备注&#xff1a; 本内容参考李航的《统计学习方法第二版》第一章 注&#xff1a;如果本篇内容存在错误&#xff0c;望大家留言批评指正。

WTM系列视频教程:初体验

WTM系列视频教程第一章&#xff1a;初体验文字摘要&#xff1a;“如果你没用过wtm&#xff0c;今天的教程肯定能让你眼前一亮&#xff0c;这个框架竟然这么牛逼么&#xff1f;开发速度这么快么&#xff1f;”“至于为什么叫WTM&#xff0c;他的全称是WalkingTec MVVM&#xff0…

[蓝桥杯2019初赛]修改数组-并查集

代码如下&#xff1a; #include <iostream> using namespace std; const int N 1000010; int a[N];int find(int x) {if (a[x] ! x)a[x] find(a[x]);return a[x]; }int main() {int n;cin >> n;for (int i 1; i < N; i)a[i] i;for (int i 1; i < n; i)…

python通过tkinter和json界面库实现考研知识点统计

python通过tkinter和json界面库实现考研知识点统计 使用下列代码前必须安装tkinter和json库 一、实现简单界面&#xff1a; """ from tkinter import * import test class mainView():def __init__(self):self.initializeUI()def initializeUI(self):self.wi…

【听歌】GDB入门教程之查看函数调用堆栈

写在前面&#xff1a;又到周末啦~上上周忍痛买了个雅马哈声卡和 AKG 话筒&#xff0c;这周六才正式打开试用了下&#xff0c;效果还不错&#xff0c;我自己还挺享受的。不过这玩意儿太高端&#xff0c;还不会用 AI 调音。小伙伴们感觉下这首加了一点点电音效果的歌曲如何呢等我…

[蓝桥杯2019初赛]年号字串-数论+模拟

题目描述 小明用字母A 对应数字1&#xff0c;B 对应2&#xff0c;以此类推&#xff0c;用Z 对应26。对于27以上的数字 小明用两位或更长位的字符串来对应&#xff0c;例如AA 对应27&#xff0c;AB 对应28&#xff0c;AZ 对应52&#xff0c;LQ 对应329。 请问2019 对应的字符串…

python通过tkinter界面库实现三角形成立的测试

python通过tkinter界面库实现三角形成立的测试 from tkinter import * from tkinter import messagebox login Tk() login.title(验证) login.geometry(800x600) Label(login,text实现三角形成立的验证).grid(row0,column0,columnspan2) Label(login,text边a&#xff1a;).gr…

研发协同平台持续集成实践

源宝导读&#xff1a;“持续集成”是敏捷最佳实践中&#xff0c;保证高质量交付的关键环节之一。本文将分享&#xff0c;在大规模研发在线协同的背景下&#xff0c;如何支撑在线持续集成的高性能和高可用。 一、什么是持续集成 在《持续集成》一书中&#xff0c;对持续集成的定…

ios::sync_with_stdio(false)的作用

默认的时候&#xff0c;cin与stdin总是保持同步的&#xff0c;也就是说这两种方法可以混用&#xff0c;而不必担心文件指针混乱&#xff0c; 所以一般会用ios::sync_with_stdio(false)来取消cin与stdin的同步&#xff0c;从而使cin达到和scanf相差无几的输入效率。 代码如下&a…

机器学习朴素贝叶斯算法+tkinter库界面实现好瓜坏西瓜分类

机器学习朴素贝叶斯算法tkinter库界面实现好瓜坏西瓜分类 一、界面实现 from tkinter import * from tkinter import ttk import NBdef main():win Tk()win.title(甜的西瓜挑选系统)win.geometry(1000x600)lb2 Label(win, text"色泽", font"tahoma 12 norma…

《ASP.NET Core 微服务实战》-- 读书笔记(第3章)

第 3 章 使用 ASP.NET Core 开发微服务 微服务定义 微服务是一个支持特定业务场景的独立部署单元。它借助语义化版本管理、定义良好的 API 与其他后端服务交互。它的天然特点就是严格遵守单一职责原则。 为什么要用 API 优先 所有团队都一致把公开、文档完备且语义化版本管理的…

[蓝桥杯2019初赛]数的分解-枚举

题目描述 把2019分解成3个各不相同的正整数之和&#xff0c;并且要求每个正整数都不包含数字2和4&#xff0c;一共有多少种不同的分解方法&#xff1f; 注意交换3个整数的顺序被视为同一种方法&#xff0c;例如1000100118 和1001100018 被视为同一种。 关键代码: int j i …

数据结构----------实现最小堆排序

数据结构----------实现最小堆排序 原理&#xff1a; 来源于---------------趣学数据结构 代码&#xff1a; #include<stdio.h> #include<stdlib.h> #define N 65535//最大个数排序 int r[N] { -1,1,4,590,4,2,8,7,5,89,67,5,2,1,67,86,54 };//存储要排序的数,第…

abp vnext2.0之核心组件模块加载系统源码解析

abp vnext是abp官方在abp的基础之上构建的微服务架构,说实话,看完核心组件源码的时候,很兴奋,整个框架将组件化的细想运用的很好,真的超级解耦.老版整个框架依赖Castle的问题,vnext对其进行了解耦,支持AutoFac或者使用.Net Core的默认容器.vnext依然沿用EF core为主,其余ORM为辅…

最大堆和最小堆排序

最大堆和最小堆排序 原理参考趣学数据结构 代码 #include<stdio.h> #include<stdlib.h> int r[] { -1,1,4,590,4,2,8,7,5,89,67,5,2,1,67,86,54 };//存储要排序的数,第一个元素不存储元素赋值为-1 int length sizeof(r) / sizeof(int);//待排序的数的个数 void s…

sstream应用举例

题目描述 把2019分解成3个各不相同的正整数之和&#xff0c;并且要求每个正整数都不包含数字2和4&#xff0c;一共有多少种不同的分解方法&#xff1f; 注意交换3个整数的顺序被视为同一种方法&#xff0c;例如1000100118 和1001100018 被视为同一种。 代码如下&#xff1a; …