【LLM多模态】MiniGPT4模型结构和训练流程

note

  • 图生文应用场景:比如电商领域根据产品图像生成产品描述、娱乐领域中根据电影海报生成电影介绍等
  • MiniGPT-4将预训练的大语言模型和视觉编码器参数同时冻结,只需要单独训练线性投影层,使视觉特征和语言模型对齐。
  • MiniGPT4的视觉编码器:使用了与BLIP-2相同的预训练视觉语言模型,该模型由2个部分组成:
    • 视觉编码器ViT(vision transformer):提取图像中的基本视觉特征。miniGPT-4使用了EVA-CLIP中的ViT-G/14进行实现(初始化该模块的代码如下)
    • 图文对齐模块Q-former:进一步将视觉编码与文本编码对齐,得到语言模型可以理解的向量编码
  • minigpt4主要对blip2的第二步训练(视觉到文本生成)改进,Linear Layer修改了输出维度,同时对LLM模型输入时,增加了prompt,提高了模型的问答能力。
    • Linear Layer: 由于vit输出的编码向量维度默认为768,此处就是一个升维操作,变成4096(对比blip2,这里是2560)。
    • img embed:图像经过vit和Q-Former之后,得到图像的embeding编码,编码最后一维为768,经过Linear Layer,转成4096维。

文章目录

  • note
  • 零、
  • 一、MiniGPT模型
    • 1. Vicuna 模型
    • 2. 视觉编码器
    • 3. 线性投影层
  • 二、训练过程
    • 1. 预训练
    • 2. 微调训练
  • 三、MiniGPT-v2模型(待更新)
  • Reference

零、

一、MiniGPT模型

项目链接:https://github.com/Vision-CAIR/MiniGPT-4
对应信息: 地址:https://github.com/Vision-CAIR/MiniGPT-4,https://huggingface.co/Vision-CAIR/MiniGPT-4/tree/main

《MiniGPT-v2: large language model as a unified interface for vision-language multi-task learning》,https://arxiv.org/abs/2310.09478

《MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models》,https://arxiv.org/abs/2304.10592

多模态LLM的任务类型:
在这里插入图片描述

MiniGPT-4模型架构:三部分,预训练的大语言模型,预训练的视觉编码器以及一个单一的线性投影层。
在这里插入图片描述

1. Vicuna 模型

decoder类型的语言模型,其在miniGPT-4中任务是理解输入进来的文本和图像数据,对多模信息有感知理解能力,生成符合指令的文本描述。MiniGPT-4 并不从头开始训练大语言模型,而是直接利用现有的 Vicuna-13B 或 Vicuna-7B 版本,冻结所有的参数权重,降低计算开销。

2. 视觉编码器

使用了与BLIP-2相同的预训练视觉语言模型,该模型由2个部分组成:

  • 视觉编码器ViT(vision transformer):提取图像中的基本视觉特征。miniGPT-4使用了EVA-CLIP中的ViT-G/14进行实现(初始化该模块的代码如下)
  • 图文对齐模块Q-former:进一步将视觉编码与文本编码对齐,得到语言模型可以理解的向量编码

(1)视觉编码器ViT:miniGPT-4使用了EVA-CLIP中的ViT-G/14进行实现

# miniGPT-4使用了EVA-CLIP中的ViT-G/14进行实现def init_vision_encoder(cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze):logging.info('Loading VIT')assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"if not freeze:precision = "fp32"  # fp16 is not for trainingvisual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision)ln_vision = LayerNorm(visual_encoder.num_features)if freeze:for name, param in visual_encoder.named_parameters():param.requires_grad = Falsevisual_encoder = visual_encoder.eval()visual_encoder.train = disabled_trainfor name, param in ln_vision.named_parameters():param.requires_grad = Falseln_vision = ln_vision.eval()ln_vision.train = disabled_trainlogging.info("freeze vision encoder")logging.info('Loading VIT Done')return visual_encoder, ln_vision

miniGPT-4使用了EVA-CLIP中的ViT-G/14进行实现(初始化该模块的代码如上),其中:

  • img_size 表示输入图像的尺寸;
  • drop_path_rate 表示使用 drop_path 的比例,这是一种正则化技术;
  • use_grad_checkpoint 表示是否使用梯度检查点技术来减少内存使用;
  • precision表示训练过程中的精度设置。

该函数通过创建 ViT 视觉编码器模型,将输入图像转换为特征表示。

(2)图文对齐模块Q-former:通常使用预训练的BERT模型,通过计算图像编码和查询(一组可学习的参数)之间的交叉注意力,更好将图像emb和文本emb对齐。初始化该模块代码如下:

def init_Qformer(cls, num_query_token, vision_width, freeze):# 使用预训练的bert模型配置q-formerencoder_config = BertConfig.from_pretrained("bert-base-uncased")encoder_config.encoder_width = vision_width# insert cross-attention layer every other blockencoder_config.add_cross_attention = Trueencoder_config.cross_attention_freq = 2# 设置查询长度encoder_config.query_length = num_query_tokenQformer = BertLMHeadModel(config=encoder_config)# 创建查询标记并初始化,是一组可训练的参数,用于查询图像和文本之间的关系query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)Qformer.cls = NoneQformer.bert.embeddings.word_embeddings = NoneQformer.bert.embeddings.position_embeddings = Nonefor layer in Qformer.bert.encoder.layer:layer.output = Nonelayer.intermediate = Noneif freeze:for name, param in Qformer.named_parameters():param.requires_grad = FalseQformer = Qformer.eval()Qformer.train = disabled_trainquery_tokens.requires_grad = Falselogging.info("freeze Qformer")# 返回初始化的q-former模型、查询标记return Qformer, query_tokens

3. 线性投影层

在这里插入图片描述

  • 视觉编码器虽然已经在广泛的图像-文本任务中做了预训练,但它们本质上没有针对 LLaMA、Vicuna 等大语言模型做过微调。为了弥补视觉编码器和大语言模型之间的差距,MiniGPT-4 增加了一个可供训练的线性投影层,期望通过训练将编码的视觉特征与 Vicuna 语言模型对齐。
  • 通过定义一个可训练的线性投影层,将 Q-Former 输出的图像特征映射到大语言模型的表示空间,以便结合后续的文本输入做进一步的处理和计算。
  • miniGPT-4模型的前向传播过程如下:
self.llama_proj = nn.Linear(img_f_dim, self.llama_model.config.hidden_size
)def encode_img(self, image):device = image.deviceif len(image.shape) > 4:image = image.reshape(-1, *image.shape[-3:])with self.maybe_autocast():# 使用视觉编码器对图像编码后,再使用LayerNorm标准化image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)# 默认使用冻结的q-formerif self.has_qformer:# 创建图像的注意力掩码image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)# 扩展查询标记以匹配图像特征的维度query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)# 使用q-former模块计算查询标记和图像特征的交叉注意力,以更好的对齐图像和文本query_output = self.Qformer.bert(query_embeds=query_tokens,encoder_hidden_states=image_embeds,encoder_attention_mask=image_atts,return_dict=True,)# 通过线性投影层将q-former的output映射到语言模型的输入inputs_llama = self.llama_proj(query_output.last_hidden_state)else:image_embeds = image_embeds[:, 1:, :]bs, pn, hs = image_embeds.shapeimage_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))inputs_llama = self.llama_proj(image_embeds)# 创建语言模型的注意力掩码atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)# 返回最终输入语言模型中的图像编码、注意力掩码return inputs_llama, atts_llama

miniGPT-4将预训练的大语言模型和视觉编码器参数同时冻结,只需要单独训练线性投影层,使视觉特征和语言模型对齐。

二、训练过程

1. 预训练

  • 预训练数据:Conceptual Caption[175, 176]、SBU[177] 和 LAION[178] 的组合数据集进行模型预训练
  • 预训练共进行了约 2 万步,批量大小为 256,覆盖了 500 万个图像-文本
    对,在 4 张 A100 上训练了 10 个小时。
def preparing_embedding(self, samples):### prepare input tokensif 'image' in samples:# 对输入图像进行编码img_embeds, img_atts = self.encode_img(samples["image"])else:img_embeds = img_atts = Noneif 'conv_q' in samples:# handeling conversation datasetsconv_q, conv_a = samples['conv_q'], samples['conv_a']connect_sym = samples['connect_sym'][0]conv_q = [q.split(connect_sym)for q in conv_q]conv_a = [a.split(connect_sym) for a in conv_a]conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)else:# 生成文本指令if "instruction_input" in samples:instruction = samples["instruction_input"]elif self.prompt_list:instruction = random.choice(self.prompt_list)else:instruction = Noneif hasattr(self, 'chat_template') and self.chat_template:instruction = [self.prompt_template.format(instruct) for instruct in instruction]if 'length' in samples:# the input is a image train (like videos)bsz, pn, hs = img_embeds.shapeimg_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)# 将指令包装到提示中cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])else:cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)### prepare target tokens# 配置tokenizer以正确处理文本输入self.llama_tokenizer.padding_side = "right"text = [t + self.end_sym for t in samples["answer"]]# 使用tokenizer对文本进行编码regress_tokens = self.llama_tokenizer(text,return_tensors="pt",padding="longest",truncation=True,max_length=self.max_txt_len,add_special_tokens=False).to(self.device)regress_token_ids = regress_tokens.input_idsregress_atts = regress_tokens.attention_maskpart_targets = regress_token_ids.masked_fill(regress_token_ids == self.llama_tokenizer.pad_token_id, -100)# 连接图像编码、图像注意力、文本编码和文本注意力regress_embeds = self.embed_tokens(regress_token_ids)return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targetsdef forward(self, samples, reduction='mean'):# prepare the embedding to condition and the embedding to regresscond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \self.preparing_embedding(samples)# concat the embedding to condition and the embedding to regressinputs_embeds, attention_mask, input_lens = \self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)# get bos token embeddingbos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_idbos_embeds = self.embed_tokens(bos)bos_atts = cond_atts[:, :1]# add bos token at the begining# 获得整体的输入编码和注意力掩码inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)attention_mask = torch.cat([bos_atts, attention_mask], dim=1)# ensemble the final targets# 创建完整的目标序列,用于计算损失targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],dtype=torch.long).to(self.device).fill_(-100)for i, target in enumerate(part_targets):targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target  # plus 1 for bos# 在自动混合精度环境下,计算语言模型的输出with self.maybe_autocast():outputs = self.llama_model(inputs_embeds=inputs_embeds,attention_mask=attention_mask,return_dict=True,labels=targets,reduction=reduction)loss = outputs.lossreturn {"loss": loss}

2. 微调训练

  • 预训练后的模型一般不能直接生成符合用户意图的文本输出,多模态LLM这里一样和语言模型类似可以进行指令微调和RLHF
  • 优化策略1:改prompt让多模态LLM回答详细:
###Human: <Img><ImageFeature></Img> Describe this image in detail.
Give as many details as possible. Say everything you see. ###Assistant:
  • 优化策略2:筛选高质量SFT图文对微调数据,用如下prompt+chatGPT的方法进行筛选,修正文本中的语义、语法错误or结构问题。最终miniGPT4作者从5k条图文文本对数据中筛出3.5k数据。
Fix the error in the given paragraph.
Remove any repeating sentences, meaningless characters, not English sentences, and so on.
Remove unnecessary repetition. Rewrite any incomplete sentences.
Return directly the results without explanation.
Return directly the input paragraph if it is already correct without explanation.
  • 优化策略3:SFT阶段中query可以多样化,比如“详细描述该图像”、“你可以为我描述该图像的内容吗”、“解释这张图为啥有趣?”等。微调训练知识在训练数据和文本提示上与预训练过程略有不同。
    • 微调:只需要 400 个训练步骤,批量大小为 12,使用单张 A100 训练 7 分钟即可完成

三、MiniGPT-v2模型(待更新)

Reference

[1] https://github.com/Vision-CAIR/MiniGPT-4
[2] MiniGPT-4 知识点汇总
[3] 【vlm多模态大模型】minigpt-4详细解析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/9386.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos 8.5 Node v20.12.2 npm 安装及环境配置 配置淘宝最新镜像地址

1下载&#xff1a;Node.js — Download Node.js 2文件上传到服务器 rootlocalhost software]# tar xvf node-v20.12.2-linux-x64.tar.xz [rootlocalhost software]# mv node-v20.12.2-linux-x64/ /usr/local/node [rootlocalhost software]# vim /etc/profile export PA…

VSCode:设置搜索时的排除目录

VSCode搜索时默认会搜索目录下所有文件 $ tree . ├── a.c ├── m.c └── x └── b.c //a.c #include <stdio.h> #include <string.h>int main() {char s[] "hello\n";fprintf(stdout, s, strlen(s));return 0; } //m.c #include <stdio…

【LSTM】LSTM网络及参数学习笔记

图1 LSTM模型结构可视化 [6]. 图2 LSTM cell结构说明 图3 LSTM cell和num_units说明 [4]. 图4 LSTM的网络结构 1. LSTM 是对一个LSTM层的抽象&#xff0c;可以看成是由多个LSTM cell组成&#xff0c;是包含时间步的一个网络 2. LSTM cell 图2是LSTM在时间步上的结构&#xf…

【硬件开发】原型设计对于成功产品开发重要性及常见阶段

电子产品的设计与开发始于原型制作阶段。这些原型虽可能与最终产品极为相似&#xff0c;但总带有实验性质&#xff0c;因为电子原型的制作过程包括对新概念、新思想及新解决方案的测试。虽然存在出错的风险&#xff0c;跳过这一阶段可能会导致不必要的开支。不擅长电子硬件设计…

PELCO-D相机云台控制协议

pelco pelco D云台控制协议 参考手册 PELCO-D协议手册PELCO-D命令列表 PELCO-D格式 Pelco-D是由7个十六进制字节组成&#xff08;除非另有说明&#xff0c;本页中使用的所有字节数据均为十六进制格式&#xff09;。 Byte1Byte2Byte3Byte4Byte5Byte6Byte7Sync同步字节控制…

packageKit学习(一)

最近在学习packagekit&#xff0c;本系列主要讲述&#xff0c;如何使用packageKit接口。 1. 导入依赖 在使用packageKit 之前需要导入一些依赖和安装一些包&#xff0c;不然会报错&#xff0c;以下以报错信息讲解&#xff1a; cmakelist demo/updatesystemdemo/mainwindow.cpp…

element-ui 图片(图片压缩)与文件上传设置【添加请求头信息】

1.图片上传 <template><div><el-upload class"avatar-uploader" :action"upload /Api/upload" :show-file-list"false" :on-success"upSuccess":before-upload"beforeAvatarUpload" :on-exceed"handl…

2024年想要开一家抖音小店,需要多少钱?一篇详解!

大家好&#xff0c;我是电商糖果 随着抖音卖货的持续火爆&#xff0c;抖音小店也成了电商行业讨论度最大的项目之一。 不少朋友都想知道&#xff0c;如果今年开抖音小店大概需要多少钱。 糖果做小店的时间也比较长&#xff0c;也经营了多家小店。 对于开一家抖音小店需要多…

MADbench2

MADbench2 MADbench2是一款用于测试大规模并行架构的I/O、通信和计算子系统在真实科学应用压力下的综合性能的工具。 MADbench2 基于 MADspec 代码&#xff0c;该代码根据天空的噪声像素化图及其像素-像素噪声相关矩阵计算宇宙微波背景辐射的最大似然角功率谱。MADbench2 保留…

多规格产品应该如何设置呢?

今天一用户从供应商手中拿到产品价目表&#xff0c;但是设置起来蒙圈了&#xff0c;接下来我们就一起设置一下吧&#xff5e; 一、产品价格表 我们通过供应商手中拿到产品价目表是这个样子的&#xff1a; 我们可以看到此产品的销售客价根据不同地区导致的价格不同&#xff0…

ABAP小技巧汇总(自用)

1.TIMESTAMP搜索帮助 PARAMETERS:p_begin TYPE ty_screen-date_begiu MATCHCODE OBJECT cpe_timestamp, "开始时间戳p_end TYPE ty_screen-date_end MATCHCODE OBJECT cpe_timestamp. "结束时间戳 效果&#xff1a;

Git笔记-常用指令

Git笔记-常用指令 一、概述二、仓库管理二、缓存区操作1. 添加文件到缓存区2. 取消缓存文件3. 忽略列表 三、日志状态信息四、分支操作五、六、 一、概述 这里记录一些git常用的指令。 二、仓库管理 # 本地仓库初始化 git init# 克隆仓库 git clone git_url # git clone ht…

深入剖析Java的“幽灵之手“:NullPointerException的原因、解决与创意思考

1. 原因分析 java.lang.NullPointerException&#xff08;简称NPE&#xff09;是Java程序员在编程过程中经常会遇到的“幽灵之手”&#xff0c;它在毫无预警的情况下出现&#xff0c;让程序崩溃。NPE的根源在于尝试访问或修改一个null对象的成员或方法。以下是NPE出现的几个常…

怎么选择适合Selenium使用的网络代理

Selenium可以让你使用所有主流浏览器&#xff0c;访问你想测试的任何网站或服务。这种多功能性使 Selenium 不仅在测试方面不可或缺。例如&#xff0c;你可以将 Selenium 与 Python 结合使用&#xff0c;对网站进行搜刮。当然&#xff0c;为了不被拦截&#xff0c;你需要一个代…

2024数维杯数学建模竞赛B题完整思路代码和论文分析

2024数维杯数学建模B题完整代码和成品论文已更新&#xff0c;获取↓↓↓↓↓ https://www.yuque.com/u42168770/qv6z0d/bgic2nbxs2h41pvt?singleDoc# 2024数维杯数学建模竞赛B题完整思路代码论文分析如下&#xff1a; 问题分析 问题(1):分析正己烷不溶物(INS)对热解产率的…

win11个性化锁屏界面怎么关闭?

win11个性化锁屏界面关闭方法对于win11用户来说&#xff0c;关闭个性化锁屏界面是一个常见问题。本文将由php小编苹果详细介绍如何执行此操作&#xff0c;分步指导并提供操作截图。继续阅读以了解具体步骤。 win11个性化锁屏界面关闭方法 第一步&#xff0c;点击底部Windows图…

2024数维杯数学建模竞赛A题完整思路代码论文分析

2024数维杯数学建模A题完整代码和成品论文获取↓↓↓↓↓ https://www.yuque.com/u42168770/qv6z0d/bgic2nbxs2h41pvt?singleDoc# 2024数维杯数学建模竞赛A题完整思路代码论文分析如下&#xff1a; 问题分析 对A题4个小问题的分析如下: 第一个小问题的分析: 这一问题要求…

知识付费 管理系统,专业技术课程讲解视频怎么制作?制作事项有几条?

现在的网络课程&#xff0c;分为专业和非专业的两种&#xff0c;专业的就是要提供硬性技术的&#xff0c;如果是值了的老师&#xff0c;要制作专业技术课程讲解视频&#xff0c;那需要怎么制作?因为&#xff0c;专业课程的要求更为的严苛&#xff0c;所以&#xff0c;老师们也…

python数据分析常用基础语法

Python语言基础——语法基础 前言一、变量的介绍与使用变量的介绍变量命名规则变量的使用拓展 二、标识符标识符命名命名规则注意事项 三、数据类型数据类型的介绍数据类型的查看示例 四、输入与输出输入和输出的介绍format格式化输出占位符 五、代码缩进与注释代码缩进 前言 …

vue3 JSX的使用与警告【JSX 元素隐式具有类型 “any“,因为不存在接口 “JSX.IntrinsicElements“】解决办法

一、安装 pnpm i vitejs/plugin-vue-jsx -D 二、配置 1、tsconfig.json "compilerOptions":{"jsx":"preserve" } 2、vite.config.ts import VueJsx from "vitejs/plugin-vue-jsx"...plugin:[vue(),VueJsx() ] 三、简单使用案例…