deepspeed+transformers模型微调

一、目录

  1. 代码讲解

二、实现。

1、代码讲解,trainer 实现。
transformers通过trainer 集成deepspeed功能,所以中需要进行文件配置,即可实现deepspeed的训练。
微调代码: 参数定义—>数据处理---->模型创建/评估方式---->trainer 框架训练
注意: V100 显卡,不包括float16 精度训练。

import deepspeed
deepspeed.ops.op_builder.CPUAdamBuilder().load()
import nltk
import torch
import evaluate
import datasets
import numpy as np
from nltk.tokenize import sent_tokenize
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
nltk.download("punkt")
import gc
import torch######################################定义参数####################################
dataset_name = "samsum" # 数据集名称
#model_name="google/flan-t5-xxl" # 模型名称
model_name="google/flan-t5-xl" # 模型名称
max_input_length = 256
max_gen_length = 128
output_dir = "checkpoints"
num_train_epochs = 5
learning_rate = 5e-5
deepspeed_config = "ds_config.json" #          deepspeed配置文件
per_device_train_batch_size=5 # batch size设置为1,因为太大导致OOM
per_device_eval_batch_size=5
gradient_accumulation_steps=10 # 由于单卡的batch size为1,为了扩展batch size,使用梯度累加#################################加载数据集,与数据预处理#########################################
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = datasets.load_dataset(dataset_name)
print(dataset["train"][0])def preprocess(examples):dialogues = ["summarize:" + dia for dia in examples["dialogue"]]# summaries = [summ for summ in examples["summary"]]model_inputs = tokenizer(dialogues, max_length=max_input_length, truncation=True)labels = tokenizer(text_target=examples["summary"], max_length=max_gen_length, truncation=True)model_inputs["labels"] = labels["input_ids"]return model_inputstokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["dialogue", "summary", "id"])
# print(tokenized_dataset["train"]["input_ids"][0]) # 打印结果    对map后的数据进行查看。def collate_fn(features):batch_input_ids = [torch.LongTensor(feature["input_ids"]) for feature in features]batch_attention_mask = [torch.LongTensor(feature["attention_mask"]) for feature in features]batch_labels = [torch.LongTensor(feature["labels"]) for feature in features]batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)batch_attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0)batch_labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100)return {"input_ids": batch_input_ids,"attention_mask": batch_attention_mask,"labels": batch_labels}##############################加载模型,采用seq2seqLM模型,并进行测试##############################
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#用于测试的代码
#dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn)
# batch = next(iter(dataloader))
# #print(batch)
# # 用于测试的代码
# dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn)
# batch = next(iter(dataloader))
# output = model(**batch)
#print(output)
#############################################模型训练,并采用trainer 架构####################################
print("==========train....================")
metric = evaluate.load("rouge")
def compute_metrics(eval_preds):preds, labels = eval_predsif isinstance(preds, tuple):preds = preds[0]decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)labels = np.where(labels != -100, labels, tokenizer.pad_token_id)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)result = {k: round(v * 100, 4) for k, v in result.items()}prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]result["gen_len"] = np.mean(prediction_lens)return resulttraining_args = Seq2SeqTrainingArguments(output_dir=output_dir,per_device_train_batch_size=per_device_train_batch_size,per_device_eval_batch_size=per_device_eval_batch_size,gradient_accumulation_steps=gradient_accumulation_steps,eval_accumulation_steps=1, # 防止评估时导致OOMpredict_with_generate=True,learning_rate=learning_rate,num_train_epochs=num_train_epochs,# logging & evaluation strategieslogging_dir="logs",logging_strategy="steps",logging_steps=50, # 每50个step打印一次logevaluation_strategy="steps",eval_steps=500, # 每500个step进行一次评估save_steps=500,save_total_limit=2,load_best_model_at_end=True,deepspeed=deepspeed_config,                        # deepspeed配置文件的位置report_to="all"
)trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_dataset["train"],eval_dataset=tokenized_dataset["validation"],data_collator=collate_fn,compute_metrics=compute_metrics,
)
trainer.train()
gc.collect()
torch.cuda.empty_cache()
# 打印验证集上的结果
# print(trainer.evaluate(tokenized_dataset["validation"]))
# # 打印测试集上的结果
# print(trainer.evaluate(tokenized_dataset["test"]))
# 保存最优模型
trainer.save_model("best.pt")
#export NCCL_IB_DISABLE=1; export NCCL_P2P_DISABLE=1; NCCL_DEBUG=INFO deepspeed --include=localhost:0,1 test1.py

配置文件:ds_config.json

{"fp16": {"enabled": "auto"},"optimizer": {"type": "AdamW","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"scheduler": {"type": "WarmupLR","params": {"warmup_min_lr": "auto","warmup_max_lr": "auto","warmup_num_steps": "auto"}},"zero_optimization": {"stage": 3,"offload_optimizer": {"device": "cpu","pin_memory": true},"offload_param": {"device": "cpu","pin_memory": true},"overlap_comm": true,"contiguous_gradients": true,"sub_group_size": 1e9,"reduce_bucket_size": "auto","stage3_prefetch_bucket_size": "auto","stage3_param_persistence_threshold": "auto","stage3_max_live_parameters": 1e9,"stage3_max_reuse_distance": 1e9,"stage3_gather_16bit_weights_on_model_save": false},"gradient_accumulation_steps": "auto","gradient_clipping": "auto","steps_per_print": 2000,"train_batch_size": "auto","train_micro_batch_size_per_gpu": "auto","wall_clock_breakdown": false
}

启动: 单机多卡

export NCCL_IB_DISABLE=1; export NCCL_P2P_DISABLE=1; NCCL_DEBUG=INFO deepspeed --include=localhost:0,1 test1.py>output.log 2>&1 &

二、代码讲解,peft微调+trainer 实现。

import os
import torch
import random
import datasets
import numpy as np
from typing import Dict
from transformers import (AutoModelForCausalLM,AutoTokenizer,DataCollatorForSeq2Seq,TrainingArguments,Trainer
)
from peft import (LoraConfig,TaskType,get_peft_model,get_peft_model_state_dict,
)def set_random_seed(seed):if seed is not None and seed > 0:random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.random.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueset_random_seed(1234)# 1. 设置参数
# LoRA参数
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
# 训练参数
EPOCHS=3
LEARNING_RATE=5e-5
OUTPUT_DIR="./checkpoints"
BATCH_SIZE=4 # 2
GRADIENT_ACCUMULATION_STEPS=3
# 其他参数
MODEL_PATH = "bigscience/bloomz-7b1-mt"
DATA_PATH = "./data/belle_open_source_1M.train.json"
MAX_LENGTH = 512
PATTERN = "{}\n{}"
DS_CONFIG = "ds_zero2_config.json"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) # 加载tokenizer
# 加载数据
dataset = datasets.load_dataset("json", data_files=DATA_PATH)
# print(dataset["train"][0])# 2. tokenize
def tokenize(text: str, add_eos_token=True):result = tokenizer(text,truncation=True,max_length=MAX_LENGTH,padding=False,return_tensors=None)# 判断是否要添加eos_tokenif (result["input_ids"][-1] != tokenizer.eos_token_idand len(result["input_ids"]) < MAX_LENGTHand add_eos_token):result["input_ids"].append(tokenizer.eos_token_id)result["attention_mask"].append(1)result["labels"] = result["input_ids"].copy()return resultdef preprocess(example: Dict, train_on_inputs: bool = False):prompt = example["input"]response = example["target"]text = PATTERN.format(prompt, response)tokenized_inp = tokenize(text)# 若train_on_inputs为False,则将label中与input相关的token替换为-100if not train_on_inputs:tokenized_prompt = tokenize(prompt,add_eos_token=False)prompt_tokens_len = len(tokenized_prompt["input_ids"])tokenized_inp["labels"] = [-100]*prompt_tokens_len + tokenized_inp["labels"][prompt_tokens_len:]return tokenized_inptrain_data = dataset["train"].shuffle().map(preprocess, remove_columns=["id", "input", "target"])
print(train_data[0])# pad_to_multiple_of=8表示padding的长度是8的倍数
collate_fn = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)# 2. 加载模型
evice_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
# device_map指定模型加载的GPU;troch_dtype=torch.float16表示半精度加载模型
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map=device_map)# 3. LoRA相关
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=LORA_R,                         # LoRA中低秩近似的秩lora_alpha=LORA_ALPHA,            # 见上文中的低秩矩阵缩放超参数lora_dropout=LORA_DROPOUT,       # LoRA层的dropout
)
# 转换模型
model = get_peft_model(model, lora_config)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# 打印模型中的可训练参数
model.print_trainable_parameters()# 4. 训练参数
args = TrainingArguments(output_dir=OUTPUT_DIR, # checkpoint的存储目录per_device_train_batch_size=BATCH_SIZE, # 单设备上的batch sizegradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, # 梯度累加的step数warmup_steps=100,num_train_epochs=EPOCHS,learning_rate=LEARNING_RATE,fp16=True, # 使用混合精度训练logging_steps=50,evaluation_strategy="no", # 不进行评估save_strategy="steps",save_steps=2000, # 保存checkpoint的step数save_total_limit=5, # 最多保存5个checkpointdeepspeed=DS_CONFIG                               #deepspeed 配置
)# 5. 模型训练
trainer = Trainer(model=model,train_dataset=train_data,eval_dataset=None,args=args,data_collator=collate_fn
)
trainer.train()
model.save_pretrained("best_model")
{"train_micro_batch_size_per_gpu": "auto","gradient_accumulation_steps": "auto","steps_per_print": 50,"gradient_clipping": 1.0,"zero_optimization": {"stage": 2,"offload_optimizer": {"device": "cpu"},"contiguous_gradients": true,"overlap_comm": true},"zero_allow_untested_optimizer": true,"fp16": {"enabled": true,"loss_scale": 0,"loss_scale_window": 1000,"hysteresis": 2,"min_loss_scale": 1},"optimizer": {"type": "Adam","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"activation_checkpointing": {"partition_activations": true,"contiguous_memory_optimization": true},"wall_clock_breakdown": false
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835086.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Apache SeaTunnel 正式发布2.3.5版本,功能增强及多个Bug修复

经过两个月的筹备&#xff0c;我们在2.3.4版本基础上进行了新一轮的迭代&#xff0c;本次更新不仅修复了多个关键问题&#xff0c;还引入了若干重要功能增强和性能优化。 在此&#xff0c;我们先提前感谢社区成员的贡献和支持&#xff0c;如果你想升级最新的版本&#xff0c;快…

websocket最大数量的限制问题

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码&#xff1a; https://gitee.com/nbacheng/ruoyi-nbcio 演示地址&#xff1a;RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码&#xff1a; h…

安全 | 开源入侵防御系统 Snort

目录 Snort 概要 入侵预防系统模式 数据包记录器和嗅探器模式 网络安全学习路线 &#xff08;2024最新整理&#xff09; 学习资料的推荐 1.视频教程 2.SRC技术文档&PDF书籍 3.大厂面试题 特别声明&#xff1a; Snort 概要 Snort 概要 是世界上最重要的开源入…

《十日终焉》中的定律整理-向虫队学习(举例+持续更新)

1、二八定律 二八定律&#xff0c;又称帕累托法则&#xff0c;也叫巴莱多定律。 是19世纪末20世纪初意大利经济学家巴莱多发明的。其中指出&#xff0c;约仅有20%的因素影响80%的结果。也就是说&#xff1a;所有变因中&#xff0c;最重要的仅有20%&#xff0c;虽然剩余的80%占…

实习体验报告怎么写:AI产品助理实习经历

笔灵实习体验报告模版免费分享&#xff0c;更多需要可以点击使用⬇️ https://ibiling.cn/scene/inex?fromcsdnsx 一、实习背景与目的在过去的几个月里&#xff0c;我有幸在一家知名科技公司实习&#xff0c;担任AI产品助理的角色。这次实习让我有机会深入了解AI领域&#x…

JAVA 双亲委派之一

JAVA 双亲委派之一 JVM类加载流程 java语言系统内置了众多类加载器&#xff0c;从一定程度上讲&#xff0c;只存在两种不同的类加载器&#xff1a;一种是启动类加载器&#xff0c;此类加载由C实现&#xff0c;是JVM的一部分&#xff1b;另一种就是所有其他的类加载器&#xf…

ASP.NET学生成绩管理系统

摘要 本系统依据开发要求主要应用于教育系统&#xff0c;完成对日常的教育工作中学生成绩档案的数字化管理。开发本系统可使学院教职员工减轻工作压力&#xff0c;比较系统地对教务、教学上的各项服务和信息进行管理&#xff0c;同时&#xff0c;可以减少劳动力的使用&#xf…

【网络】gateway 可以提供的一些功能之三 “ 支持Eureka服务发现 ”

一、Eureka是干什么的 Eureka就像是一个电话簿&#xff0c;但是用来存储和管理各种微服务的地址信息。它帮助微服务之间相互发现和交流&#xff0c;就像你想找某人电话号码一样&#xff0c;只需查看电话簿就能找到他们的联系方式。Eureka也可以帮助系统在服务出现问题时自动发现…

STM32CubeMX软件使用(超详细)

1、Cube启动页介绍 2、芯片选择页面介绍 3、输入自己的芯片型号&#xff0c;这里以STM32U575RIT6举例 4、芯片配置页码介绍 5、芯片外设配置栏详细说明 6、点击ClockConfiguration进行时钟树的配置&#xff0c;选择时钟树后可以选择自己想使用的时钟源&#xff0c;也可以直接输…

unreal engine4 创建动画蒙太奇

UE4系列文章目录 文章目录 UE4系列文章目录前言一、创建动画蒙太奇 前言 动画蒙太奇的官方解释&#xff1a;Animation Montages are animation assets that enable you to combine animations in a single asset and control playback using Blueprints.You can use Animation…

与Apollo共创生态:助力自动驾驶迈向新台阶

引言Apollo七周年大会企业协同工具链携手伙伴共创生态未来展望与总结 引言 2024年4月19日&#xff0c;一场智能汽车未来的盛宴正朝我们走来——Apollo开放平台的七周年大会。 此次大会主题为“破晓•拥抱智变时刻”其中“破晓”象征着新时代的曙光&#xff0c;意味着智能汽车技…

最新版在线客服系统源码

源码介绍 首发最新在线客服系统源码&#xff0c;优化更好并且重构源码布局UI 性能不吃cpu并发快,普通1H2G都能带动最新版只要是服务器都能带动 搭建即可使用,操作简单,易懂 修复了老版本bug 内附有搭建教程 gofly.v1kf.com 运行环境 Nginx 1.20 MySQL 5.7 演示截图

Spring IoCDI(3)—DI详解

目录 一、属性注入 二、构造方法注入 小结&#xff1a;构造函数的注入 三、Setter注入 四、三种注入的优缺点分析&#xff08;面试题&#xff09; 1、属性注入 优点&#xff1a; 缺点&#xff1a; 2、构造方法注入&#xff08;Spring4.X推荐&#xff09; 优点&#x…

netty配置SSL、netty配置https(开发)

netty配置SSL、netty配置https&#xff08;开发&#xff09; 我们在开发下使用ssl&#xff0c;所用的证书将不被客户端信任。 方案一 快速。使用netty提供的临时签发证书 private static SslContext sslContext null;public ServerChannelHandler(RouterConfig config) {th…

使用Python在PowerPoint演示文稿之间复制样式(复制幻灯片母版)

在专业演示文稿设计与制作领域&#xff0c;多场演示间保持一致性至关重要。在PowerPoint演示文稿之间复制幻灯片母版成为了一项关键技巧&#xff0c;用以维持统一的视觉风格&#xff0c;确保品牌形象的一致性&#xff0c;并提升观众的参与度。这一做法不仅能节省宝贵的时间&…

TCP 连接,一端断电和进程崩溃有什么区别?

TCP 连接&#xff0c;一端断电和进程崩溃有什么区别&#xff1f; 前言主机崩溃进程崩溃有数据传输的场景客户端主机宕机&#xff0c;又迅速重启客户端主机宕机&#xff0c;一直没有重启 总结 前言 有的小伙伴在面试腾讯的时候&#xff0c;遇到了这么个问题&#xff1a; 这个属…

Java入门基础学习笔记11——关键字和标识符

1、关键字 关键字是java中已经被赋予特定意义的&#xff0c;有特殊作用的一些单词&#xff0c;不可以把这些单词作为标识符来使用。 注意&#xff1a;关键字是java用了的&#xff0c;我们就不能用来作为&#xff1a;类名、变量名、否则会报错。 标识符&#xff1a; 标识符就是…

机器学习的一些知识点分享

下面数据集中&#xff0c;第2个样本的第4个属性的值是&#xff08; &#xff09;。 A 52 B 男 C 50 D 49 本题得分&#xff1a; 2分 正确答案&#xff1a; D 2.单选题 (2分) 10-折交叉验证是把数据集分成&#xff08; &#xff09;个子集&#xff0c;将其中&#xff…

EmploLeaks:一款针对企业安全的组织员工信息收集OSINT工具

关于EmploLeaks EmploLeaks是一款针对企业安全的组织员工信息收集OSINT工具&#xff0c;在该工具的帮助下&#xff0c;企业内部的安全人员和管理员可以有效地收集组织内员工的各种信息&#xff0c;并以此来判断组织内部的网络安全态势。 工作机制 首先&#xff0c;该工具会在…

MySQL库操作 表操作【详细解析】

MySQL MySQL是一个数据库软件 mysql mysql是一个“客户端—服务器”结构的软件 (1) a.客户端&#xff1a;主动发起请求的一方&#xff08;Client&#xff09; b.服务器&#xff1a;被动接收请求的一方&#xff08;Server&#xff09; 客户端和服务器之间通过网络 进行通信 (…