【llm对话系统】大模型 Llama 源码分析之 LoRA 微调

1. 引言

微调 (Fine-tuning) 是将预训练大模型 (LLM) 应用于下游任务的常用方法。然而,直接微调大模型的所有参数通常需要大量的计算资源和内存。LoRA (Low-Rank Adaptation) 是一种高效的微调方法,它通过引入少量可训练参数,固定预训练模型的权重,从而在保持性能的同时大大减少了计算开销。

本文将深入分析 LoRA 的原理,并结合 Llama 源码解读其实现逻辑,最后探讨 LoRA 的优势。

2. LoRA 原理

LoRA 的核心思想是:预训练模型中已经包含了大量的低秩 (low-rank) 特征,微调时只需要对这些低秩特征进行微调即可。

具体来说,LoRA 假设权重更新矩阵 ΔW 也是低秩的。对于一个预训练的权重矩阵 W ∈ R^(d×k),LoRA 将其更新表示为:

W' = W + ΔW = W + BA

其中:

  • W 是预训练的权重矩阵。
  • ΔW 是权重更新矩阵。
  • B ∈ R^(d×r)A ∈ R^(r×k) 是两个低秩矩阵,r 远小于 dkr 被称为 LoRA 的秩 (rank)。

在训练过程中,W 被冻结,只有 AB 是可训练的。

直观理解:

可以将 W 看作一个编码器,将输入 x 编码成一个高维表示 Wx。LoRA 认为,在微调过程中,我们不需要完全改变这个编码器,只需要通过 BA 对其进行一个低秩的调整即可。

3. Llama 中 LoRA 的实现

虽然 Llama 官方代码没有直接集成 LoRA,但我们可以使用一些流行的库 (例如 peft by Hugging Face) 来实现 Llama 的 LoRA 微调。peft 库提供了 LoraConfigget_peft_model 等工具,可以方便地将 LoRA 应用于各种 Transformer 模型。

3.1 使用 peft 库实现 Llama 的 LoRA 微调

以下是一个使用 peft 库实现 Llama 的 LoRA 微调的简化示例:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType# 加载预训练的 Llama 模型和分词器
model_name = "meta-llama/Llama-2-7b-hf"  # 假设使用 Llama 2 7B
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)# LoRA 配置
config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=8,  # LoRA 的秩lora_alpha=32,  # LoRA 的缩放因子lora_dropout=0.1,  # Dropout 比例target_modules=["q_proj", "v_proj"], # 需要应用 LoRA 的模块
)# 获取支持 LoRA 的模型
model = get_peft_model(model, config)# 打印可训练参数的比例
model.print_trainable_parameters()# ... (加载数据,进行训练) ...

代码解释:

  1. 加载预训练模型:使用 transformers 库加载预训练的 Llama 模型和分词器。
  2. LoRA 配置:创建一个 LoraConfig 对象,指定 LoRA 的配置参数:
    • task_type:任务类型,这里是因果语言模型 (Causal Language Modeling)。
    • r:LoRA 的秩。
    • lora_alpha:LoRA 的缩放因子,用于控制 LoRA 模块的权重。
    • lora_dropout:Dropout 比例。
    • target_modules: 指定需要应用 LoRA 的模块, 通常是注意力层中的 q_proj, v_proj, 还可以是k_proj, o_proj, gate_proj, up_proj, down_proj等。不同的模型需要根据实际情况配置。
  3. 获取支持 LoRA 的模型:使用 get_peft_model 函数将原始的 Llama 模型转换为支持 LoRA 的模型。
  4. 打印可训练参数:使用 model.print_trainable_parameters() 可以查看模型中可训练参数的比例,通常 LoRA 的可训练参数比例非常小。

3.2 peft 库中 LoRA 的实现细节 (部分)

peft 库中 LoraModel 类的部分代码 (为了清晰起见,已进行简化):

class LoraModel(torch.nn.Module):# ...def _find_and_replace(self, model):# ... (遍历模型的每一层) ...if isinstance(module, nn.Linear) and name in self.config.target_modules:new_module = Linear(module.in_features,module.out_features,bias=module.bias is not None,r=self.config.r,lora_alpha=self.config.lora_alpha,lora_dropout=self.config.lora_dropout,)# ... (将原模块的权重赋值给新模块) ...# ...class Linear(nn.Linear):def __init__(self,in_features: int,out_features: int,r: int = 0,lora_alpha: int = 1,lora_dropout: float = 0.0,**kwargs,):super().__init__(in_features, out_features, **kwargs)# LoRA 参数self.r = rself.lora_alpha = lora_alpha# 初始化 A 和 Bif r > 0:self.lora_A = nn.Parameter(torch.randn(r, in_features))self.lora_B = nn.Parameter(torch.zeros(out_features, r)) # B 初始化为全 0self.scaling = self.lora_alpha / self.rdef forward(self, x: torch.Tensor):result = F.linear(x, self.weight, bias=self.bias) # W @ xif self.r > 0:result += (self.lora_B @ self.lora_A @ x.transpose(-2, -1) # (B @ A) @ x).transpose(-2, -1) * self.scalingreturn result

代码解释:

  1. _find_and_replace 函数:遍历模型的每一层,找到需要应用 LoRA 的线性层 (例如,q_proj, v_proj),并将其替换为 Linear 层。
  2. Linear 类:继承自 nn.Linear,并添加了 LoRA 的参数 lora_Alora_B
    • lora_A 初始化为随机值。
    • lora_B 初始化为全 0,这是为了保证在训练开始时,LoRA 部分的输出为 0,不影响预训练模型的原始行为。
    • scaling 是一个缩放因子,用于控制 LoRA 模块的权重。
  3. forward 函数:
    • F.linear(x, self.weight, bias=self.bias) 计算原始的线性变换 W @ x
    • (self.lora_B @ self.lora_A @ x.transpose(-2, -1)).transpose(-2, -1) * self.scaling 计算 LoRA 部分的输出 (B @ A) @ x,并乘以缩放因子。
    • 将两者相加,得到最终的输出。

4. LoRA 的优势

  • 高效的参数利用:LoRA 只需微调少量的参数 (A 和 B),而冻结了预训练模型的大部分参数,大大减少了训练时的内存占用和计算开销。
  • 快速的训练速度:由于可训练参数较少,LoRA 的训练速度通常比全量微调快得多。
  • 防止过拟合:LoRA 的低秩约束起到了一定的正则化作用,有助于防止过拟合。
  • 性能相当:在许多任务上,LoRA 可以达到与全量微调相当的性能。
  • 易于部署:训练完成后,可以将 WBA 相加,得到新的权重矩阵 W',然后像使用原始的预训练模型一样进行部署,无需额外的计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894360.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 内核学习(5) --- Linux 内核底半部机制

目录 中断底半部软中断tasklet工作队列使用工作队列 中断底半部 当产生一个中断时,会进入中断处理程序,但中断处理程序必须快速、异步、简单的对硬件做出迅速响应并完成那些时间要求很严格的操作,因此,对于那些其他的、对时间要求…

3D图形学与可视化大屏:什么是材质属性,有什么作用?

一、颜色属性 漫反射颜色 漫反射颜色决定了物体表面对入射光进行漫反射后的颜色。当光线照射到物体表面时,一部分光被均匀地向各个方向散射,形成漫反射。漫反射颜色的选择会直接影响物体在光照下的外观。例如,一个红色的漫反射颜色会使物体在…

Jenkins未在第一次登录后设置用户名,第二次登录不进去怎么办?

Jenkins在第一次进行登录的时候,只需要输入Jenkins\secrets\initialAdminPassword中的密码,登录成功后,本次我们没有修改密码,就会导致后面第二次登录,Jenkins需要进行用户名和密码的验证,但是我们根本就没…

Qt常用控件 输入类控件

文章目录 1.QLineEdit1.1 常用属性1.2 常用信号1.3 例子1,录入用户信息1.4 例子2,正则验证手机号1.5 例子3,验证输入的密码1.6 例子4,显示密码 2. QTextEdit2.1 常用属性2.2 常用信号2.3 例子1,获取输入框的内容2.4 例…

有没有个性化的UML图例

绿萝小绿萝 (53****338) 2012-05-10 11:55:45 各位大虾,有没有个性化的UML图例 绿萝小绿萝 (53****338) 2012-05-10 11:56:03 例如部署图或时序图的图例 潘加宇 (35***47) 2012-05-10 12:24:31 "个性化"指的是? 你的意思使用你自己的图标&…

Go学习:字符、字符串需注意的点

Go语言与C/C语言编程有很多相似之处,但是Go语言中在声明一个字符时,数据类型与其他语言声明一个字符数据时有一点不同之处。通常,字符的数据类型为 char,例如 :声明一个字符 (字符名称为 ch) 的语句格式为 char ch&am…

本地部署 DeepSeek-R1 模型

文章目录 霸屏的AIDeepSeek是什么?安装DeepSeek安装图形化界面总结 霸屏的AI 最近在刷视频的时候,总是突然突然出现一个名叫 DeepSeek 的玩意,像这样: 这样: 这不经激起我的一顿好奇心,这 DeepSeek 到底是个…

断裂力学课程报告

谈谈你对线弹性断裂力学和弹塑性断裂力学的认识 经过对本课程的学习,我首先认识到断裂力学研究的是宏观的断裂问题,而不是研究属于断裂物理研究范围的微观结构断裂机理。断裂力学从材料内部存在缺陷出发,研究裂纹的生成、亚临界拓展&#xff…

代码随想录刷题day22|(字符串篇)344.反转字符串、541.反转字符串 II

目录 一、题目思路 二、相关题目 三、总结与知识点 3.1 字符数组转换成字符串 一、题目思路 344反转字符串比较容易,双指针即可在空间复杂度为O(1)的基础上解决; 541反转字符串II :其中for循环中 i 每次的取值,不是 i&#…

【机器学习】自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

一、使用pytorch框架实现逻辑回归 1. 数据部分: 首先自定义了一个简单的数据集,特征 X 是 100 个随机样本,每个样本一个特征,目标值 y 基于线性关系并添加了噪声。将 numpy 数组转换为 PyTorch 张量,方便后续在模型中…

使用Visual Studio打包Python项目

1. 安装Visual Studio 首先,你需要在你的计算机上安装Visual Studio。 2. 创建项目 在Visual Studio中创建一个新的Python项目。 打开Visual Studio,点击“File”(文件) -> “New”(新建) -> “Pr…

TVM调度原语完全指南:从入门到微架构级优化

调度原语 在TVM的抽象体系中,调度(Schedule)是对计算过程的时空重塑。每一个原语都是改变计算次序、数据流向或并行策略的手术刀。其核心作用可归纳为: 优化目标 max ⁡ ( 计算密度 内存延迟 指令开销 ) \text{优化目标} \max…

51单片机——串口向电脑发送数据

引言 在电子技术领域,51 单片机作为一种广泛应用的微控制器,其串口通信功能具有重要意义。通过串口,51 单片机能够与电脑等外部设备进行数据交互,实现各种复杂的功能,为许多应用场景提供了可能。 51 单片机串口通信基…

高性能消息队列Disruptor

定义一个事件模型 之后创建一个java类来使用这个数据模型。 /* <h1>事件模型工程类&#xff0c;用于生产事件消息</h1> */ no usages public class EventMessageFactory implements EventFactory<EventMessage> { Overridepublic EventMessage newInstance(…

Java线程认识和Object的一些方法ObjectMonitor

专栏系列文章地址&#xff1a;https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标&#xff1a; 要对Java线程有整体了解&#xff0c;深入认识到里面的一些方法和Object对象方法的区别。认识到Java对象的ObjectMonitor&#xff0c;这有助于后面的Synchron…

基于YOLO11的肺结节检测系统

基于YOLO11的肺结节检测系统 (价格90) LUNA16数据集 数据一共 1186张 按照8&#xff1a;1&#xff1a;1随机划分训练集&#xff08;948张&#xff09;、验证集&#xff08;118张&#xff09;与测试集&#xff08;120张&#xff09; 包含 nodule 肺结节 1种…

C++ Primer 自定义数据结构

欢迎阅读我的 【CPrimer】专栏 专栏简介&#xff1a;本专栏主要面向C初学者&#xff0c;解释C的一些基本概念和基础语言特性&#xff0c;涉及C标准库的用法&#xff0c;面向对象特性&#xff0c;泛型特性高级用法。通过使用标准库中定义的抽象设施&#xff0c;使你更加适应高级…

FFmpeg源码:av_base64_decode函数分析

一、引言 Base64&#xff08;基底64&#xff09;是一种基于64个可打印字符来表示二进制数据的表示方法。由于log2 646&#xff0c;所以每6个比特为一个单元&#xff0c;对应某个可打印字符。3个字节相当于24个比特&#xff0c;对应于4个Base64单元&#xff0c;即3个字节可由4个…

白话DeepSeek-R1论文(三)| DeepSeek-R1蒸馏技术:让小模型“继承”大模型的推理超能力

最近有不少朋友来询问Deepseek的核心技术&#xff0c;陆续针对DeepSeek-R1论文中的核心内容进行解读&#xff0c;并且用大家都能听懂的方式来解读。这是第三篇趣味解读。 DeepSeek-R1蒸馏技术&#xff1a;让小模型“继承”大模型的推理超能力 当大模型成为“老师”&#xff0c…

curope python安装

目录 curope安装 测试: 报错:libc10.so: cannot open shared object file: No such file or directory 解决方法: curope安装 git clone : GitHub - Junyi42/croco at bd6f4e07d5c4f13ae5388efc052dadf142aff754 cd models/curope/ python setup.py build_ext --inplac…