大模型LORA微调总结

大模型LORA微调总结

  • 大模型微调总结
    • 模型加载
      • 使用deepspeed
      • 不使用deepspeed
      • 使用lora
      • 加载分词器
    • 数据加载
    • 构建source和target
    • 构建input_ids和labels
    • 标签补齐
    • 构建训练器
    • LORA模型推理
      • 模型加载
      • 多batch推理构建
      • lora微调推理
      • 合并模型权重

大模型微调总结

模型加载

使用deepspeed

 model = transformers.AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,cache_dir=training_args.cache_dir,torch_dtype='auto',# if model_args.model_name_or_path.find("falcon") != -1 else Falsetrust_remote_code=True)

不使用deepspeed

model = transformers.AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,cache_dir=training_args.cache_dir,device_map='auto',torch_dtype='auto',# if model_args.model_name_or_path.find("falcon") != -1 else Falsetrust_remote_code=True)

使用lora

from peft import LoraConfig, get_peft_model
LORA_R = 32
# LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = [
"o_proj","gate_proj", "down_proj", "up_proj"
]config = LoraConfig(
r=LORA_R,
# lora_alpha=LORA_ALPHA,
target_modules=TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
#加载配置
model = get_peft_model(model, config)
#打印训练参数比例
model.print_trainable_parameters()

加载分词器

tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

数据加载

通过Hugging Face的dateset库进行加载数据

使用dateset可以轻松加载数据,样例如下所示:

from datasets import load_dataset
dataset = load_dataset('csv', data_files='my_file.csv')
dataset = load_dataset('csv', data_files=['my_file_1.csv', 'my_file_2.csv', 'my_file_3.csv'])
dataset = load_dataset('csv', data_files={'train':['my_train_file_1.csv','my_train_file_2.csv'],'test': 'my_test_file.csv'})

我们可以按下面方式加载数据

def load_dataset_from_own(data_path: Optional[str] = None,cache_dir: Optional[str] = "cache_data") -> Dataset:all_file_list = ['a.json','b.json','c.json']data_files = {'train': all_file_list}extension = all_file_list[0].split(".")[-1]datasets = load_dataset(extension,data_files=data_files,cache_dir=cache_dir,)['train']return datasets

构建source和target

  1. 构建prompt
PROMPT_DICT = {"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. ""Write a response that appropriately completes the request.\n\n""### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),"prompt_no_input": ("Below is an instruction that describes a task. ""Write a response that appropriately completes the request.\n\n""### Instruction:\n{instruction}\n\n### Response:"),
}
  1. 根据prompt构建source
sources = [prompt_input.format_map({'instruction': ins_data[i], 'input': input_data[i]}) if input_data[i] != "" else prompt_no_input.format_map({'instruction': ins_data[i]})for i in range(len_)]
#限制长度
sources = [i[:data_args.source_length] for i in sources]
  1. 根据prompt构建targets
targets = [f"{example[:data_args.target_length-1]}{tokenizer.eos_token}" for example in output]

构建input_ids和labels

输入需要构建的text,输出构建好的ids

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:"""Tokenize a list of strings."""tokenized_list = [tokenizer(text,return_tensors="pt",padding="longest",max_length=tokenizer.model_max_length,truncation=True,)for text in strings]#获得idsinput_ids = labels = [tokenized.input_ids[0]for tokenized in tokenized_list]#终止符设置ne_pad_token_id = IGNORE_INDEX if tokenizer.pad_token_id is None else tokenizer.pad_token_id#统计长度input_ids_lens = labels_lens = [tokenized.input_ids.ne(ne_pad_token_id).sum().item() for tokenized in tokenized_list]return dict(input_ids=input_ids,labels=labels,input_ids_lens=input_ids_lens,labels_lens=labels_lens,)

构建input_ids 和label

examples = [s + t for s, t in zip(sources, targets)]
#问题+答案、问题
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
#构建labels
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):label[:source_len] = IGNORE_INDEX

标签补齐

在动态batching中我们需要一个data collator完成padding。这里不适用DataCollatorWithPadding来进行补齐操作,因为这个函数仅对输入的键(包括input_ids, attention_mask, token_type_ids)进行补齐,不会对labels进行补齐操作。还有在对labels进行补齐操作时,使用的是-100而不是分词器的pad_token,这么做到的目的是在计算损失函数的时候忽略掉这些padding token。

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model,
label_pad_token_id=IGNORE_INDEX)

构建训练器

from transformers import DataCollatorForSeq2Seq, Trainer
trainer = Trainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,eval_dataset=None,data_collator=data_collator)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)

LORA模型推理

模型加载

base_model_name_or_path = "internlm-7b"
lora_model_name_or_path ="checkpoint-9695"model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path,torch_dtype="auto",# device_map="auto",# if model_args.model_name_or_path.find("falcon") != -1 else Falsetrust_remote_code=True,
).cuda(0)model = PeftModel.from_pretrained(model, model_id=lora_model_name_or_path)
model.eval()
print("ok")tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, trust_remote_code=True, padding_side="left"
)

多batch推理构建

def batch_generate_data(text_input: List[str], use_train_model: bool = True, temp: float = 0.7
):text_input_format = [generate_input(i) for i in text_input]batch_inputs = tokenizer.batch_encode_plus(text_input_format, padding="longest", return_tensors="pt")batch_inputs["input_ids"] = batch_inputs["input_ids"].cuda()batch_inputs["attention_mask"] = batch_inputs["attention_mask"].cuda()if use_train_model:# with model.disable_adapter():outputs = model.generate(**batch_inputs,max_new_tokens=256,do_sample=True,temperature=temp,top_p=0.8,)else:with model.disable_adapter():outputs = model.generate(**batch_inputs,max_new_tokens=256,do_sample=True,temperature=temp,top_p=0.8,)outputs = tokenizer.batch_decode(outputs.cpu()[:, batch_inputs["input_ids"].shape[-1] :],skip_special_tokens=True,)return outputs

lora微调推理

text_input = ["工作压力太大怎么办\n"] * 32
# lora 训练结果
batch_generate_data(text_input, use_train_model=True, temp=0.8)
# 原来的模型
batch_generate_data(text_input, use_train_model=False, temp=0.8)

合并模型权重

model = model.merge_and_unload()
model.save_pretrained("internlm-7b-lml")
tokenizer.save_pretrained("internlm-7b-lml")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/627163.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年甘肃省职业院校技能大赛信息安全管理与评估 样题一 理论题

竞赛需要完成三个阶段的任务,分别完成三个模块,总分共计 1000分。三个模块内容和分值分别是: 1.第一阶段:模块一 网络平台搭建与设备安全防护(180 分钟,300 分)。 2.第二阶段:模块二…

一文让你对mysql索引底层实现明明白白

开篇: 图片是本人随笔画的,有点粗糙,望大家谅解,如有不对的地方,请联系我们,感谢 一、索引到底是什么 .索引是帮助mysql高效获取数据的排好序的数据结构 .索引是存储在文件里的 .数据结构: 二…

微信小程序怎么引入webview的url是本地的路径

当微信小程序访问类似http://10.27.0.15:8065/#/my这样的地址的时候会出问题。但是我们也不能每次把写的H5的代码发布在看效果啊? 只需要修改一个地方就可以啦。

Transformer 位置编码

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心&…

LLM(十)| Tiny-Vicuna-1B:Tiny Models轻量化系列Top One

在过去的一年里,见证了LLM的蓬勃发展,而模型的参数量也不断刷新记录,在2023年下半年,外界传言GPT-4是一个专家混合模型。因此,如果你想用人工智能做点什么,你需要IBM或NASA类似的计算能力:你怎么…

JAVA进化史: JDK16特性及说明

JDK 16于2021年3月发布。这个版本引入了一些新特性和改进,以下是其中一些主要特性 JEP 338: 引入了向量API(Vector API) 引入了向量API(Vector API),这是一个孵化器特性,用于提供更好地利用硬…

openharmony 编译LLVM编译器基础架构

1. 编译库地址 third_party_llvm-project: 管理员 liwentao_uiw dhy308 huanghuijin 2. 编译方法 git clone https://gitee.com/openharmony/third_party_llvm-project.gitcd third_party_llvm-projectmkdir buildcd buildcmake -G Ninja -DCMAKE_BUILD_TYPERelease ../llvm …

纯c++简易的迷宫小游戏

一个用c写的黑框框迷宫 适合新手入门学习 也适合大学生小作业 下面附上代码 总体思路 初始化游戏界面:设置迷宫的大小(WIDTH和HEIGH),生成迷宫地图(map),包括墙壁、空地、起点和终点。显示…

3、python布尔类型和条件表达式

使用布尔值进行分支逻辑! 文章目录 1.布尔类型1.1比较运算1.2组合布尔值2.条件语句2.1布尔转换1.布尔类型 Python有一种称为bool的变量类型。它有两个可能的值:True和False。 In [1]: x = True print(x) print(type(x)) True <class bool>除了直接在代码中使用True或…

【K12】Python写串联电阻问题的求解思路解析

问题源代码 方法&#xff1a;calculate_circuit_parameter 构造题目&#xff1a; 模板&#xff1a; 已知电阻R1为 10Ω&#xff0c;电阻R2为 5Ω&#xff0c;电压表示数为2.5V&#xff0c;求电源电压U&#xff1f; 给合上面题目&#xff0c;利用Python程序&#xff0c;可以任…

LeetCode 76.最小覆盖子串Java

题目链接 这个是滑动窗口问题比较难的了&#xff0c;不太好想。 我借鉴了这个大佬的思想&#xff0c;用更容易理解的方式实现了一下&#xff0c;可能时间复杂度有点提高。 代码搭配详解使用&#xff1a;题解 这个是我的题解 class Solution {public String minWindow(String …

【论文笔记合集】卷积神经网络之深度可分离卷积(Depthwise Separable Convolution)

本文作者&#xff1a; slience_me 我看的论文地址&#xff1a;MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 内容 1. 标准卷积 假设输入为DFDFM&#xff0c;输出为输入为DFDFN&#xff0c;卷积核为DKDKM&#xff0c;共有N个卷积核进…

人机对话:程序设计,学哪种语言好?

人机对话&#xff1a;程序设计&#xff0c;学哪种语言好&#xff1f; 程序设计&#xff0c;学哪种语言好&#xff1f;学习目的&#xff1a;职业发展&#xff1a;个人兴趣&#xff1a; go语言怎么样&#xff1f;优点&#xff1a;缺点&#xff1a; 要开发手机APP&#xff0c;还需…

LeetCode刷题---随机链表的复制

解题思路&#xff1a; 使用哈希表来解决该问题 因为题中要求是深拷贝 首先对原链表遍历&#xff0c;将原链表每个节点和新链表每个节点形成对应关系&#xff0c;存入到哈希表中&#xff0c;key为原链表的节点&#xff0c;value为新链表的节点。 之后重置辅助链表指向原链表头节…

墨刀原型-实现轮播图功能

在墨刀中实现轮播图效果&#xff0c;可以按照以下步骤进行操作&#xff1a; 1.添加轮播图组件&#xff1a;在墨刀的组件面板中&#xff0c;找到轮播图组件并将其拖拽到画布上。 2.上传轮播图&#xff1a;在右侧的属性面板中&#xff0c;你可以上传你的轮播图图片。点击“”按钮…

动态pv(nfs方式挂载)

1、定义 发布pvc之后可以生成pv&#xff0c;还可以在共享服务器上直接生成挂载目录 pvc直接绑定和使用pv 2、动态pv依赖两个组件 &#xff08;1&#xff09;provisioner卷插件&#xff1a;k8s本身支持的动态pv创建不包括nfs&#xff0c;需要声明和安装一个外部插件provisio…

NET Core发布 HTTP Error 500.31 - Failed to load ASP.NET Core runtime

记录一下踩过的坑&#xff1a; 首先&#xff0c;不论是500.31还是500.30 &#xff0c;首先确保安装了三个文件 1.NET Core RunTime 2.NET SDK 3.NET Hosting 其次&#xff0c;确保三个文件的版本一致&#xff0c;如下&#xff1a; 要装就统一装同一个大版本&#xff0c;不要东…

Linux第28步_编译“修改正点原子TF-A源码中的Makefile并编译生成新的TF-A 固件”

了解学习内容&#xff1a; 1)、正点原子STM32MP157开发板使用的主控型号是STM32MP157DAA1&#xff1b; 2)、“linux /atk-mp1/atk-mp1/alientek_tf-a/tf-a-stm32mp-2.2.r1”目录下的文件是正点原子STM32MP157D开发板的“TF-A源码”。 3)、“linux /atk-mp1/atk-mp1/alientek…

字符串匹配

模板&#xff1a; KMP: 细节在代码中 看不懂的可以参照&#xff1a;如何更好地理解和掌握 KMP 算法? - 阮行止的回答 - 知乎 https://www.zhihu.com/question/21923021/answer/1032665486 package StringMatch.KMP;import java.util.ArrayList; import java.util.List;publ…

k8s的配置资源管理

Secret Secret用来保存密码、token密钥以及一些敏感的k8s资源。这类数据虽然可以存放在镜像当中&#xff0c;但是放在secret当中可以更方便控制。减少暴露的风险。 Secret的作用&#xff1a;保存加密的信息 Secret的类型 docker-registry()主要用于存储docker仓库的认证信息…