ChatGLM2-INT4 + Lora 结构适配和改造

Lora 是目前公认的最好的微调方法,一方面,它并不像AdapterTuning 一样,改变原有模型的架构,不便于在不同框架之间迁移;另一方面,它不像 PTuning 一样改变所有任务下的单词生成概率,严重破坏已习得的知识。

ChatGLM2-INT4 这个量化版本使用自定义的QuantizedLinear作为线性模块。如果我们要使用 PEFT 库添加 Lora 参数时,它就会不认识,因为它是为torch.nn.Linear准备的,于是我们只能自己写个模块来实现这件事情。

一、编写LoraQuantizedLinear

LoraQuantizedLinear是我们自己的带Lora的线性层,包含QuantizedLinear所有参数/配置和Lora的主要配置项。

通过阅读quantization.py,我们确定QuantizedLinear的参数或配置有四个:

  • weight:量化后权重,形状为[OutDim, InDim]。注意在 INT4 模式下,一个 INT8 元素当两个 INT4 使用,InDim是 FP16 版本的一半。
  • weight_scale:量化的缩放系数,形状为[OutDim]。也就是说这个量化针对每一个隐藏状态确定一个范围,而不是整个参数。
  • bias:不量化的偏置,形状为[OutDim]
  • weight_bit_width:量化位数,4 或者 8。

新的线性层应该原样保存这些参数,并且应该包含Lora的三个主要配置:

  • r:较低的维度大小
  • alpha:和r一起组成缩放系数
  • dropout_rate:前置 Dropout 的比例

新的线性层创建如下:

class LoraQuantizedLinear(torch.nn.Module):def __init__(self, q_linear, lora_r=32, lora_alpha=32, lora_dropout_rate=0.0):super().__init__()# 保存原始参数和Lora配置self.lora_r = lora_rself.lora_alpha = lora_alphaself.lora_dropout_rate = lora_dropout_rateself.weight_bit_width = q_linear.weight_bit_widthself.weight = q_linear.weightself.weight_scale = q_linear.weight_scaleself.bias = q_linear.bias# 冻结原始参数self.weight.requires_grad = Falseself.weight_scale.requires_grad = Falseif self.bias is not None: self.bias.requires_grad = False# 创建 Lora 参数,FP16out_dim, in_dim = self.weight.shape# INT4 模型下,InDim 是原始大小的一半if self.weight_bit_width == 4: in_dim *= 2# LoraA 正态初始化self.lora_a = torch.nn.Parameter(torch.empty([self.lora_r, in_dim],device=self.weight.device,dtype=torch.float16,))torch.nn.init.kaiming_normal_(self.lora_a)# LoraB 全零初始化self.lora_b = torch.nn.Parameter(torch.zeros([out_dim, self.lora_r],device=self.weight.device,dtype=torch.float16,))self.lora_dropout = torch.nn.Dropout(self.lora_dropout_rate)self.lora_scale = self.lora_alpha / self.lora_r

正向传播过程中,先用之前的forward()方法完成原始输出的计算,之后再手动编写 Lora 输出的计算,然后相加并返回:

# class LoraQuantizedLineardef forward(self, input):ori_output = QuantizedLinear.forward(self, input)lora_output = (self.lora_dropout(input.half()) @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1) * self.lora_scale)return ori_output + lora_output.to(ori_output.dtype)

合并方法,我们将原始参数解量化,将Lora两个参数相乘后与之相加,然后再次量化。

# class LoraQuantizedLineardef merge(self):# H = XW + b + XAB * s => H = X(W + AB * s) + b# 将 int 原始参数转成 fp16weight = extract_weight_to_half(self.weight, self.weight_scale, self.weight_bit_width)# 合并 lora 参数weight += self.lora_b @ self.lora_a * self.lora_scale# 再转回 intweight, weight_scale = half_weight_to_int(weight, self.weight_bit_width)self.weight = torch.nn.Parameter(weight, requires_grad=False)self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)# 重新初始化 lora 两个矩阵torch.nn.init.kaiming_normal_(self.lora_a)torch.nn.init.zeros_(self.lora_b)

half_weight_to_int取自QuantizedLinear类的构造器:

def half_weight_to_int(weight: torch.Tensor, weight_bit_width: int):assert weight_bit_width in [4, 8]assert weight.ndim == 2weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)weight = torch.round(weight / weight_scale[:, None]).to(torch.int8)if weight_bit_width == 4:weight = compress_int4_weight(weight)return weight, weight_scale

二、辅助方法

之后实现几个辅助方法,完成参数挂载卸载和合并:

首先是attach_lora:将所有的QuantizedLinear改成LoraQuantizedLinear

我们搜索模型的所有模块,再搜索它的直接子模块,如果任何东西是QuantizedLinear,就把它替换掉。之后把非 Lora 参数冻结。

def attach_lora(model, lora_r=32, lora_alpha=32, lora_dropout_rate=0.0):if model.lora_attached: return modellora_conf = dict(lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout_rate=lora_dropout_rate)for mod in model.modules():for name in dir(mod):submod = getattr(mod, name, None)if not isinstance(submod, QuantizedLinear):continuenew_submod = LoraQuantizedLinear(submod, **lora_conf)setattr(mod, name, new_submod)for name, param in model.named_parameters():if 'lora_' not in name:param.requires_grad = Falsemodel.lora_attached = Truereturn model

lora_state_dict:导出所有 Lora 参数。

def lora_state_dict(model):return {   k:vfor k, v in model.state_dict().items()if 'lora_' in k}

merge_lora:将所有LoraQuantizedLinear的参数合并。

def merge_lora(model):for mod in model.modules():if isinstance(mod, LoraQuantizedLinear):mod.merge()return model

detach_lora:搜索模型的所有模块,再搜索它的直接子模块,如果任何东西是LoraQuantizedLinear,就把它替换回QuantizedLinear

def detach_lora(model):if not model.lora_attached: return modelfor mod in model.modules():for name in dir(mod):submod = getattr(mod, name, None)if not isinstance(submod, LoraQuantizedLinear):continuenew_submod = QuantizedLinear.from_params(submod.weight_bit_width,submod.weight,submod.weight_scale,submod.bias,)setattr(mod, name, new_submod)model.lora_attached = Falsereturn model

这就需要给QuantizedLinear加一个工厂方法,接受它的四个参数并直接保存。原有的构造器与之不兼容,于是实现为类方法:

# class QuantizedLinear@classmethoddef from_params(cls, weight_bit_width, weight, weight_scale, bias):obj = cls.__new__(cls)super(cls, obj).__init__()obj.weight_bit_width = weight_bit_widthobj.weight = weightobj.weight_scale = weight_scaleobj.bias = biasreturn obj

由于这些方法是实现在单独的文件中的(例如lora.py),我们在modeling_chatglm.py中导入这个文件,然后添加到ChatGLMForConditionalGeneration类里面,便于使用:

from .lora import attach_lora, detach_lora, merge_lora, lora_state_dict# class ChatGLMForConditionalGeneration# def __init__self.lora_attached = Falseattach_lora = attach_loradetach_lora = detach_loramerge_lora = merge_loralora_state_dict = lora_state_dict

然后一切就完成了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/112908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nginx配合tomcat、resin等java应用服务器提供java支持

首先探讨一下为什么要使用nginx: 1、类似于apacheresin,nginx用于提供静态页面服务,比java服务器要强。虽然这些java服务器的性能都不赖,tomcat新版甚至还支持了epoll,但是用nginx来处理静态文件是一定比这些服务器更…

macos使用搭建算法竞赛c/c++的g++/gcc编译环境(homebrew,含万能头,改环境变量,vscode/clion可用)

文章目录 1、homebrew安装2、安装g3、改环境变量 1、homebrew安装 我没改镜像,直接网上脚本一键安装的,具体命令忘了,可能是这个 反正装这个的方法很多,网上一搜都有。 成功装上homebrew就行。 /bin/bash -c "$(curl -fsSL…

微信小程序6

一、什么是后台交互? 在小程序中,与后台交互指的是小程序前端与后台服务器之间的数据通信和请求处理过程。通过与后台交互,小程序能够获取服务器端的数据、上传用户数据、发送请求等。 与后台交互可以通过以下方式实现: 发起网络请…

redis的cluster

1.我们的哨兵模式中,当主节点挂掉以后,此时哨兵会重新进行选举,选举出新的主节点去对外提供写服务 在选举的过程中,他redis整个集群是不提供写服务的 (因为此时我们哨兵对外提供写服务的只有Master) 2.我们单节点的red…

ESP32集成开发环境Espressif-IDE安装 – Windows

陈拓 2023/10/15-2023/10/16 1. 概述 Espressif IDE是一个基于Eclipse CDT的集成开发环境(IDE),用于使用ESP-IDF框架开发物联网应用程序。这是一个专门为ESP-IDF构建的独立定制IDE。Espressif IDE附带了IDF Eclipse插件、重要的Eclipse CDT插…

【数据结构】线性表(八)队列:顺序队列及其基本操作(初始化、判空、判满、入队、出队、存取队首元素)

文章目录 一、队列1. 定义2. 基本操作 二、顺序队列0. 顺序表1. 头文件和常量2. 队列结构体3. 队列的初始化4. 判断队列是否为空5. 判断队列是否已满6. 入队7. 出队8. 存取队首元素9. 主函数10. 代码整合 堆栈Stack 和 队列Queue是两种非常重要的数据结构,两者都是特…

美格智能出席无锡智能网联汽车生态大会,共话数字座舱新势力

10月20日,2023世界物联网博览会期间,以“智 行天下 启未来”为主题的2023无锡智能网联汽车生态大会暨域控制器及智能座舱论坛在无锡举行。大会邀请行业权威专家,多家知名企业重磅嘉宾出席,融汇智能网联汽车思想智慧、创新技术、产…

微信小程序连接数据库与WXS的使用

微信小程序连接数据库与WXS的使用 1.搭建数据库连接,使用后端获取数据1.请求方式的封装2.化一下代码,这样写太繁琐了3.前端代码 四、WXS的使用1..解决数据显示数字问题2. 解决统计人数问题3.解决时间进制问题 ) 1.搭建数据库连接,使用后端获取数据 为了后期方便维护…

李m圆申论

听话出活 3小时 /处理7500字 /一共5题 /写出2200字 字写得好看点,符号也算字数,占一个格 基本思路:考什么范围答什么 。。。落后;资源闲置、缺乏 申论: 作文题:举例子 处理材料 摘抄: 有人出…

centos或aws linux部署java应用,环境搭建shell

目录 设置root密码开启密码登录安装docker安装nginx设置nginx自启动nginx配置https配置http集群tcp集群 安装docker设置docker自启动修改docker基础配置创建docker网关docker安装mysql单机版本主从版本 docker安装redis设置密码:不要密码: docker安装rab…

基础MySQL的语法练习

基础MySQL的语法练习 create table DEPT(DEPTNO int(2) not null,DNAME VARCHAR(14),LOC VARCHAR(13) );alter table DEPTadd constraint PK_DEPT primary key (DEPTNO);create table EMP (EMPNO int(4) primary key,ENAME VARCHAR(10),JOB VARCHAR(9),MGR …

【机器学习】集成模型/集成学习:多个模型相结合实现更好的预测

1. 概述 1.1 什么是集成模型/集成学习 "模型集成"和"集成学习"是相同的概念。它们都指的是将多个机器学习模型组合在一起,以提高预测的准确性和稳定性的技术。通过结合多个模型的预测结果,集成学习可以减少单个模型的偏差和方差&am…

13.3测试用例进阶

一.测试对象划分 1.界面测试(参考软件规格说明书和UI视觉稿) a.什么是界面 1)WEB站(浏览器) 2)app 3)小程序 4)公众号 b.测试内容 1)界面内容显示的一致性,完整性,准确性,友好性.比如界面内容对屏幕大小的自适应,换行,内容是否全部清晰展示. 2)验证整个界面布局和排版…

RunnerGo 支持UI自动化的测试平台

RunnerGo提供从API管理到API性能再到可视化的API自动化、UI自动化测试功能模块,覆盖了整个产品测试周期。 RunnerGo UI自动化基于Selenium浏览器自动化方案构建,内嵌高度可复用的测试脚本,测试团队无需复杂的代码编写即可开展低代码的自动化…

Leetcode——字符

520. 检测大写字母 class Solution { public:bool detectCapitalUse(string word) {int big 0, small 0, len word.length();for (int i 0; i < len; i) {if (word[i] > 65 && word[i] < 90) {big;}else {small;}}if (big len || small len) {return tr…

工业电子中的深力科分享一款PWM控制器 KA3525A

关于PWM控制器&#xff1a; PWM控制器是一种用于控制电机或其他设备的电路&#xff0c;它通过改变脉冲宽度调制&#xff08;PWM&#xff09;信号的占空比来控制设备的输出。PWM控制器可以使用单片机或开发板等设备来实现&#xff0c;通过设定占空比&#xff0c;可以轻松地控制…

关于数据库连接池和线程,记录几个问题

文章目录 1.HirakiPool - Connection is not available, request timed out after2.在一个线程内&#xff0c;调用多次dataSource.getConnection()这是为什么呢&#xff1f;是谁来实现的线程内连接唯一呢&#xff1f; 1.HirakiPool - Connection is not available, request tim…

【微信小程序调试工具试用】

【微信小程序调试工具试用】 试用大佬开发的dll拿到某物小程序sign签名 &#xff08;过于简单 大佬勿喷&#xff09;本次工具分享到此结束 什么是爬虫逆向&#xff1f; 试用大佬开发的dll拿到某物小程序sign签名 &#xff08;过于简单 大佬勿喷&#xff09; 1 如图 下面小程序…

MIKE水动力笔记17_MIKE文件转shp、统计每个单元格的面积

本文目录 前言Step 1 MIKE文件转shpStep 2 在ArcGIS中打开shp统计相应指标拓展&#xff1a;关于shp文件的介绍 前言 MIKE的工具箱中自带一个转shp的工具&#xff0c;然后可以拖进ArcGIS中很方便的统计每个单元格的面积和每个网格点的水深。 Step 1 MIKE文件转shp MIKE允许转…

SSM - Springboot - MyBatis-Plus 全栈体系(三十四)

第八章 项目实战 四、后台功能开发 1. 用户模块开发 1.1 jwt 和 token 介绍 1.1.1 token 介绍 令牌&#xff08;Token&#xff09;&#xff1a;在计算机领域&#xff0c;令牌是一种代表某种访问权限或身份认证信息的令牌。它可以是一串随机生成的字符或数字&#xff0c;用…