大模型微调示例四之Llama-Factory-DPO - 教程

news/2025/9/23 18:41:37/文章来源:https://www.cnblogs.com/yfceshi/p/19107713

大模型微调示例四之Llama-Factory-DPO

  • 一、强化学习数据处理
  • 二、配置训练文档
  • 三、模型预测

一、强化学习数据处理

原始数据地址:https://nijianmo.github.io/amazon/index.html

第一步:读取 video game 信息

import codecs, json, re
from random import shuffle
# 第一步:读取 video game 信息
# key 是 productID,value是 title
games = {
}
cc = 0
with codecs.open('./data/src_data/meta_Video_Games.json', mode='r') as fin:
for line in fin:
tmp_info = json.loads(line.strip())
# asin - ID of the product
# title - name of the product
games[tmp_info["asin"]] = tmp_info["title"]
if len(games) % 10000 == 0:
print(f'Length of games: {
len(games)
}')

第二步:读取用户评分信息

# key 是 userid,value 是评价的游戏和评分
user_reviews = {
}
cc = 0
with codecs.open('./data/src_data/Video_Games_5.json', mode='r') as fin:
for line in fin:
tmp_info = json.loads(line.strip())
# reviewerID - ID of the reviewer
reviewer_id = tmp_info["reviewerID"]
time_info = re.split(', | ', tmp_info["reviewTime"])
review_time = time_info[2] + '-' + time_info[0] + '-' + time_info[1]
# asin - ID of the product
product_id = tmp_info["asin"]
# overall - rating of the product
rating = tmp_info["overall"]
# if cc > 1000:
# break
# print(tmp_info)
# print(user_reviews)
if product_id in games.keys():
product_title = games[product_id]
if reviewer_id in user_reviews.keys():
user_reviews[reviewer_id].append((product_title, rating, review_time))
else:
user_reviews[reviewer_id] = [(product_title, rating, review_time)]
if len(user_reviews) % 10000 == 0:
print(f'Length of user_reviews: {
len(user_reviews)
}')
cc += 1
user_reviews_sorted = {
}
for k, v in user_reviews.items():
# 首先去重
v = list(set(v))
# 然后根据评价时间从小到大排序,表示用户的评价历史
v_sorted = sorted(v, key=lambda x: x[2])
# 选择具有7个及以上的评论样本
if len(v) >= 7:
# print(f'v: {v}, v_sorted: {v_sorted}')
user_reviews_sorted[k] = v_sorted
print(f'Length of user_reviews_sorted: {
len(user_reviews_sorted)
}')

第三步 训练数据生成

# 总样本
samples = []
# 指令
instruction = "You are an assistant working on Video Games recommendations. Given the user's history of Video Games they have shopped, which includes the \"Title\" of the Video Games and the \"Rating\" the user rate (the Rating value is like or dislike), please decide whether the user likes to shop the target Video Games by outputting the order of their titles."
samples = []
cc = 0
for k, v in user_reviews_sorted.items():
# print('-'*10)
# print(v)
sample_input = "User shopped Video Games histories (Title and Rating): \n"
# 前面的当作对话历史
for vv in v[0: -2]:
# 当 rating 大于 3.0 的时候设置为 like
if vv[1] >
3.0:
rating = 'like'
# 当 rating 小于等于 3.0 的时候设置为 dislike
else:
rating = 'dislike'
sample_input += "<Title: {}, Rating: {}>\n".format(vv[0], rating)sample_input += "Based on the Video Games histories, please sort the following two Video Games titles. The one in the front is what the user like and should be recommended to user: \n"# 最后两个设置为需要预测的目标sample_input += "<Title: " + v[-2][0] + '>\n'sample_input += "<Title: " + v[-1][0] + '>\n'# print(f'v[-1][1]: {v[-1][1]}, v[-2][1]: {v[-2][1]}')# 保证有一个是 like,有一个是 dislikeif (v[-1][1] >3.0 and v[-2][1] <= 3.0) or (v[-1][1] <= 3.0 and v[-2][1] >3.0):# print(f'v[-1][1] != v[-2][1]: {v[-1][1]}, {v[-2][1]}')if v[-1][1] > v[-2][1]:# likeoption1 = v[-1][0]# dislikeoption2 = v[-2][0]else:# likeoption1 = v[-2][0]# dislikeoption2 = v[-1][0]# chosen 是 like 在前面chosen = "<Title: " + option1 + '>\n' + "<Title: " + option2 + '>'# rejected 是 dislike 在前面rejected = "<Title: " + option2 + '>\n' + "<Title: " + option1 + '>'sample = {"instruction": instruction,"input": sample_input,"chosen": chosen,"rejected": rejected}# print(f'--------')# print(v)# print(sample)samples.append(sample)if len(samples) % 10000 == 0:print(f'Length of samples: {len(samples)}')# cc += 1# if cc > 10:# breakprint(f'Length of samples: {len(samples)}')

第四步 划分 train 和 test 保存样本

# 首先打乱
shuffle(samples)
train = samples[:int(len(samples)*0.8)]
test = samples[int(len(samples)*0.8):]
print(f'总样本数: {
len(samples)
},训练集样本数: {
len(train)
},测试集样本数: {
len(test)
}')
with open("./data/processed/rlhf_train.json", "w", encoding='utf-8') as save_file:
json.dump(train, save_file, indent=4)
with open("./data/processed/rlhf_test.json", "w", encoding='utf-8') as save_file:
json.dump(test, save_file, indent=4) # , sort_keys=True

二、配置训练文档

rlhf_train.yaml

### model
model_name_or_path: /ZhipuAI/glm-4-9b-chat
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 16
lora_alpha: 32
pref_beta: 0.1
pref_loss: orpo
### dataset
dataset: amazon_video_games
template: glm4
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: ./saves/amazon_video_games_orpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

rlhf_inference.yaml

model_name_or_path: /ZhipuAI/glm-4-9b-chat
adapter_name_or_path: ./saves/amazon_video_games_orpo
template: glm4
finetuning_type: lora

三、模型预测

import json
from openai import OpenAI
from tqdm import tqdm
# 加载模型
client = OpenAI(
api_key="EMPTY",
# 需要修改为大模型地址
base_url="http://10.114.16.65:8000/v1/"
)
# 加载测试数据
test_file_path = "./data/processed/rlhf_test.json"
with open(test_file_path, "r", encoding='utf-8') as test_file:
test_data = json.load(test_file)
print(len(test_data))
# 开始预测
labels = []
predictions = []
cc = 0
for each_test in tqdm(test_data):
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": each_test["instruction"]
},
{
"role": "user",
"content": each_test["input"],
}
],
model="glm4",
)
predictions.append(chat_completion.choices[0].message.content)
labels.append(each_test["chosen"])
if len(labels) % 100 == 0:
correct = 0
wrong = 0
for l, p in zip(labels, predictions):
l = l.strip()
p = p.strip()
# print(f'l: {l}, p: {p}')
if l == p:
correct += 1
else:
wrong += 1
# print(f'\nl: {l}, \np: {p}')
print(f'总样本数:{
len(labels)
},准确数:{correct
}, 错误数:{wrong
}, 准确率:{correct / len(labels)
}')
cc += 1
# if cc > 100:
# break
assert len(predictions) == len(labels)
correct = 0
wrong = 0
for l, p in zip(labels, predictions):
l = l.strip()
p = p.strip()
if l == p:
correct += 1
else:
wrong += 1
print(f'总样本数:{
len(labels)
},准确数:{correct
}, 错误数:{wrong
}, 准确率:{correct/len(labels)
}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/913506.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第9节-子查询-ALL - 详解

第9节-子查询-ALL - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", &q…

网站可以做无形资产潍坊网站seo外包

当我们在写前端页面的时候为了页面的美观我们通常会为页面设置图片背景&#xff0c;那么我们如何来设置全屏的背景图呢&#xff1f;&#xff1f;&#xff1f; 设置全屏背景图需要三个属性 background-image: url(img/untitled.png);background-repeat: no-repeat;background-s…

网站如何查看浏览量wordpress里的导航用什么

一.HTTPS是如何保证安全传输的 https通过使⽤对称加密、⾮对称加密、数字证书等⽅式来保证数据的安全传输。 客户端向服务端发送数据之前&#xff0c;需要先建⽴TCP连接&#xff0c;所以需要先建⽴TCP连接&#xff0c;建⽴完TCP连接后&#xff0c;服务端会先给客户端发送公钥…

长沙网站搭建公司联系方式网站建化

目录 1.查看网卡信息 2.修改yaml配置文件 3.应用新的网络配置 4.重新启动网络服务 文章内容 本文介绍Ubuntu 22.04.3 Server系统通过修改yaml配置文件配置静态 ip 的方法。 1.查看网卡信息 使用ifconfig命令查看网卡信息获取网卡名称​ 如果出现Command ifconfig not fo…

郑州网站建设公司排行榜wordpress主题如何汉化

人每时每刻都要呼吸&#xff0c;呼吸是生命得以存在的基础。不过人类赖以生存的氧气并不是地球上含量最高的气体&#xff0c;地球上含量最高的气体是氮气。在地球的大气之中&#xff0c;氮气的含量占到了78%&#xff0c;而氧气的含量排名第二&#xff0c;约为21%。我们经常会提…

自己免费怎么制作网站吗家纺网站模板

一个网站中&#xff0c;大部分网页的模块是重复的&#xff0c;比如顶部的导航栏&#xff0c;底部的备案信息。如果在每个页面中都重复的去写这些代码&#xff0c;会让项目变得臃肿&#xff0c;提高后期维护成本。比较好的做法是&#xff0c;通过模板继承&#xff0c;把一些重复…

html5响应式公司网站模版专业建设企业网站

众所周知&#xff0c;网络安全是一个非常重要的课题&#xff0c;而服务器是网络安全中最关键的环节。Linux被认为是一个比较安全的Internet服务器&#xff0c;作为一种开放源代码操作系统&#xff0c;一旦Linux系统中发现有安全漏洞&#xff0c;Internet上来自世界各地的志愿者…

软件工程感想

软件工程感想 在之前的概念里,我一直觉得软件开发就等于“写代码”——只要熟练掌握一门编程语言,能把想法用代码实现出来,就是一个合格的程序员了。然而,上了第一堂课之后,我发现自己之前的理解实在是太狭隘了。…

n8n+MySQL实现数据库查询!

为什么使用了 n8n 之后,会觉得惊喜? 因为使用他实在太方便了,但让这里的方便不单是本地部署、升级上的方便(dify 要启动 7 个服务,coze 要启动 9 个服务,而 n8n 一个服务就搞定了),而是他整体的便利性。例如他…

My Tricks

tricks 和注意事项 【数据删除】构造题!!! 杂项多测未清空 没开 long long 如果正面处理不方便,可以考虑拆单个的贡献然后用差分 跳来跳去的或要操作很多次的考虑倍增 判断等比数列时考虑正负性,并用比例的性质来…

完整教程:机器学习入门,支持向量机

完整教程:机器学习入门,支持向量机pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monac…

建设网站广州市wordpress发邮件接收验证码

如有错误或有补充&#xff0c;以及任何改进的意见&#xff0c;请在评论区留下您的高见&#xff0c;同时文中给出大部分命令的示例&#xff0c;即是您暂时无法在Linux中查看&#xff0c;您也可以知道各种操作的功能以及输出 如果觉得本文写的不错&#xff0c;不妨点个赞&#x…

网站建设构成技术要求教资注册网址

开发插件的都知道插件的content scripts和top window只共享Dom不共享window和其他数据&#xff0c;如果想拿挂载在window的数据还有点难度&#xff0c;下面会通过事件的方式传递cs和top window之间的数据写一个例子 代码 manifest.json 这里只搞了2个js&#xff0c;content.…

两学一做网站按钮图片小微企业名录

面向对象的基本原则&#xff1a;单一原则&#xff1a;引起一个类发生变化的原因只有一个 开放封闭原则&#xff1a;对于类的修改是封闭的&#xff0c;而对于类的扩展是开放的 依赖倒置原则&#xff1a;高层不依赖于底层&#xff0c;都应该依赖与抽象&#xff1b;抽象不依赖于…

做网站系统学校备案接入阿里云后原网站还能访问吗

1.2 课程架构介绍&#xff1a;STM32H5 芯片生命周期管理与安全调试 下面开始学习课程的第二节&#xff0c;简单介绍下STM32H5芯片的生命周期和安全调试&#xff0c;具体课程大家可以观看STM32官方录制的课程&#xff0c;链接&#xff1a;1.2. 课程架构介绍&#xff1a;STM32H5…

宁德商城网站建设wordpress找回文章

6.6&#xff1a;说明形参、局部变量以及局部静态变量的区别。编写一个函数&#xff0c;同时用到这三种形式。 Ans&#xff1a;形参及函数体内定义的变量&#xff0c;都是局部变量&#xff0c;必须进行初始化&#xff0c;否则会出现未定义行为&#xff0c;这是由于局部变量的生命…

移动互联网站建设修改网站的设计

本文主要介绍了Prompt设计、大语言模型SFT和LLM在手机天猫AI导购助理项目应用。 ChatGPT基本原理 “会说话的AI”&#xff0c;“智能体” 简单概括成以下几个步骤&#xff1a; 预处理文本&#xff1a;ChatGPT的输入文本需要进行预处理。 输入编码&#xff1a;ChatGPT将经过预…

月嫂网站模板企业网站推广服务协议

PHP 日期处理完全指南 引言 在PHP开发中,日期和时间处理是一个常见且重要的任务。PHP提供了丰富的内置函数来处理日期和时间,包括日期的格式化、计算、解析等。本文将详细介绍PHP中日期处理的相关知识,帮助读者全面理解和掌握这一技能。 1. PHP日期函数基础 1.1 date()函…

宁海县城镇建设局网站wordpress 提示

本文原文来自DataLearnerAI官方网站&#xff1a;ChatGPT内置隐藏debug功能&#xff1a;支持下载原始对话、可视化对话分支等 | 数据学习者官方网站(Datalearner) AIPRM的工作人员最近发现ChatGPT的客户端隐藏内置了一个新的debug特性&#xff0c;可以提高ChatGPT对话的问题调试…

网站系统里不能打印江苏高效网站制作公司

viewdata[alert]"alert(你好)"<script>viewdata[alert]</script> 在Controller存储数据 在 界面得到 关于 ViewData和ViewMode 点击这里 http://wanshiqian1221.blog.163.com/blog/static/6872130420095242016546/