从零学习大模型(九)-----P-Tuning(下)

代码展示P-Tuning的全过程

import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset# 1. 数据准备
dataset = load_dataset("imdb")# 2. 构建提示
def add_prompt(examples):examples['text'] = ["这段文本的情感是:'{}'".format(text) for text in examples['text']]return examplesdataset = dataset.map(add_prompt)# 3. 模型选择
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)# 4. 添加可训练的嵌入向量
class PromptEmbedding(nn.Module):def __init__(self, prompt_length, embedding_dim):super(PromptEmbedding, self).__init__()self.prompt_embedding = nn.Parameter(torch.randn(prompt_length, embedding_dim))def forward(self, x):prompt = self.prompt_embedding.unsqueeze(0).repeat(x.size(0), 1, 1)  # 扩展到batch大小return torch.cat((prompt, x), dim=1)# 定义新模型
class P_Tuning_BERT(nn.Module):def __init__(self, base_model, prompt_length):super(P_Tuning_BERT, self).__init__()self.base_model = base_modelself.prompt_embedding = PromptEmbedding(prompt_length, base_model.bert.config.hidden_size)def forward(self, input_ids, attention_mask=None, labels=None):# 获取原始的输入嵌入embeddings = self.base_model.bert.embeddings(input_ids)# 添加prompt嵌入embeddings = self.prompt_embedding(embeddings)outputs = self.base_model.bert(inputs_embeds=embeddings, attention_mask=attention_mask)logits = self.base_model.classifier(outputs[1])  # 只取池化输出return (logits,)# 设置P-Tuning模型
prompt_length = 5  # Prompt的长度
p_tuning_model = P_Tuning_BERT(model, prompt_length)# 冻结原模型参数
for param in p_tuning_model.base_model.parameters():param.requires_grad = False# 5. 数据预处理
def tokenize_function(examples):return tokenizer(examples['text'], truncation=True, padding=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)# 6. 微调过程
training_args = TrainingArguments(output_dir='./results',evaluation_strategy='epoch',learning_rate=2e-5,per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=3,weight_decay=0.01,
)trainer = Trainer(model=p_tuning_model,args=training_args,train_dataset=tokenized_datasets['train'],eval_dataset=tokenized_datasets['test'],
)# 7. 训练模型
trainer.train()# 8. 测试模型
trainer.evaluate()# 9. 应用模型
def predict(text):p_tuning_model.eval()inputs = tokenizer("这段文本的情感是:'{}'".format(text), return_tensors="pt", truncation=True, padding=True)with torch.no_grad():outputs = p_tuning_model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])logits = outputs[0]predicted_class = torch.argmax(logits, dim=-1)return "积极" if predicted_class.item() == 1 else "消极"# 测试应用
print(predict("这家餐厅的服务很好。"))

P-Tuning的实验结果

文本分类任务

  • P-Tuning:在小数据集上,F1-score为0.85,训练时间为1小时。
  • 全参数微调:在相同数据集上,F1-score为0.88,训练时间为3小时,但在验证集上过拟合(F1-score为0.80)。

对话生成任务

  • P-Tuning:生成的回复自然性评分为4.2/5,训练时间为2小时。
  • 全参数微调:生成的回复自然性评分为4.5/5,但训练时间为5小时。

P-Tuning的优点

1. 计算效率高

  • 参数更新少:P-Tuning仅更新与提示相关的嵌入向量,减少了训练过程中需要优化的参数数量。这意味着在同样的计算资源下,可以更快速地进行实验和模型调整。

2. 减少过拟合风险

  • 冻结预训练模型的参数:通过冻结大部分模型参数,P-Tuning降低了在小数据集上过拟合的风险。对于数据量有限的任务,P-Tuning能够更好地泛化。

3. 灵活性和适应性强

  • 任务适应性:可以通过简单地调整提示内容来适应不同的任务,无需修改整个模型架构。这使得在多任务场景中,P-Tuning能够快速切换和调整。
  • Prompt设计自由:研究者可以根据具体任务设计不同的提示,以探索对模型性能的影响。这种灵活性允许在多个任务之间共享同一模型,而只需修改提示。

4. 易于实现和部署

  • 实现简单:相较于全参数微调,P-Tuning的实现更加简便,尤其是在不需要重新训练整个模型的情况下。只需在输入中添加提示即可。
  • 资源需求低:由于只更新部分参数,P-Tuning对计算资源的需求较低,可以在较小的硬件上进行训练和部署。

5. 在小数据集上的表现良好

  • 数据效率高:P-Tuning特别适用于小数据集场景,在这些场景中,训练整个模型可能导致性能下降,而P-Tuning可以利用预训练的知识,有效提升模型的性能。

6. 提升模型的可解释性

  • 可解释性增强:由于P-Tuning强调了提示的作用,研究者可以更清晰地理解模型如何通过特定提示来做出不同的决策。这对于分析模型的行为和结果非常有帮助。

7. 迁移学习效果好

  • 知识迁移:P-Tuning能够有效地利用预训练模型中存储的知识,通过适当的提示,将这种知识迁移到新任务中。这使得在许多下游任务中,P-Tuning能够实现与全参数微调相当甚至更好的性能。

P-Tuning的局限性

1. 提示设计的依赖性

  • 提示的有效性:P-Tuning的性能高度依赖于提示的设计和选择。不同的提示可能会导致模型产生不同的预测结果。如果提示设计不当,可能会影响模型的理解和预测能力。
  • 提示选择的挑战:设计有效的提示需要领域知识和经验,这对于非专业人士来说可能是一个挑战。

2. 学习到的提示嵌入的复杂性

  • 提示嵌入的可解释性:虽然P-Tuning提供了一定的可解释性,但学习到的提示嵌入的具体意义和如何影响模型决策可能仍然不够清晰。研究者可能难以解读这些嵌入的具体作用。
  • 相似性问题:不同任务或数据集可能会导致提示嵌入相似性较高,导致模型在迁移到新任务时表现不佳。

3. 数据集和任务的限制

  • 适用性问题:P-Tuning在小数据集上表现良好,但在大规模和复杂任务中,可能无法完全发挥预训练模型的潜力。在某些情况下,全参数微调可能仍然是更优的选择。
  • 数据分布差异:如果训练和测试数据的分布差异较大,P-Tuning的效果可能受到影响,特别是如果提示未能充分捕捉任务的关键特征。

4. 对训练资源的需求

  • 额外的训练时间:尽管P-Tuning的训练参数较少,但学习提示嵌入仍然需要一定的训练时间和计算资源。在资源有限的情况下,可能仍需权衡使用全参数微调与P-Tuning的选择。

5. 任务特定性

  • 领域适应性:某些领域的特定任务可能不适合使用P-Tuning,尤其是在需要高度专业化的知识和上下文理解的情况下。全参数微调可能更好地适应这些特定的领域。

6. 模型性能的极限

  • 性能瓶颈:由于只更新部分参数,P-Tuning在某些情况下可能无法突破预训练模型的性能极限。在需要极高性能的任务中,全参数微调可能更能挖掘模型的潜力。

P-Tuning的未来发展方向

1. 大规模模型的适应性

  • 模型架构的调整:为适应更大规模的模型,P-Tuning可以通过调整提示嵌入的维度和数量来保持与模型的对齐。这意味着需要为每个新任务设计适当的嵌入结构。
  • 分层提示:对于大型模型,可以设计分层的提示结构,允许在不同层次上进行信息传递,从而使模型更有效地利用提示信息。

2. 多任务学习

  • 共享提示嵌入:在多任务设置中,可以设计共享的提示嵌入,以便在不同任务之间传递信息。这有助于提高模型的训练效率,并减少为每个任务单独训练提示的需求。
  • 动态提示调整:利用动态生成的提示来适应不同任务的需求。通过实时分析任务特征,生成适合特定任务的提示,从而增强模型的适应性。

3. 增强训练方法

  • 自适应学习率:为不同任务的提示嵌入设置不同的学习率,以便更好地适应每个任务的特性。这可以通过监控每个任务的性能来动态调整学习率。
  • 数据增强:结合数据增强技术,在训练过程中引入多样化的训练样本,从而提高模型在新任务和大规模数据集上的泛化能力。

4. 集成方法

  • 与其他技术结合:将P-Tuning与其他微调技术(如LoRA、Adapter等)结合使用,可以进一步提升模型的性能。这些技术可以帮助在不大幅增加模型参数的情况下,增强模型对新任务的适应性。
  • 知识蒸馏:通过知识蒸馏技术,将大型模型的知识迁移到较小的模型中,同时利用P-Tuning进行微调,可以在资源有限的情况下实现较好的性能。

5. 任务定制化

  • 针对性任务提示设计:针对特定任务或领域设计专门的提示嵌入,以确保它们能有效捕捉任务特征。这可能包括对领域特定的语言和上下文的理解。
  • 领域适应性:在特定领域(如医疗、法律等)中,通过细化提示以增强对领域术语和上下文的理解,提升模型在特定领域任务上的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/58417.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是安全组件?

安全组件是信息系统中用于保护数据和系统安全的关键部分。它们通常包括一系列的软件和硬件组件,旨在提供身份验证、授权、数据加密、防病毒、入侵检测等功能。这些组件可以是独立的软件程序,也可以是嵌入到操作系统或应用程序中的模块,或者作…

J3学习打卡

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 DensNet模型 import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras import layers, models, initializersclass DenseLayer(lay…

基于微信小程序的小区管理系统设计与实现(lw+演示+源码+运行)

摘 要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。所以各大互联网厂商都瞄准移动互联网这个潮…

随机变量、取值、样本和统计量之间的关系

1. 随机变量 (Random Variable) 随机变量是用来量化随机现象结果的一种数学工具。随机变量是一个函数,它将实验结果映射到数值。随机变量可以是离散的或连续的。 离散随机变量:取有限或可数无限个值。例如,掷骰子的结果。连续随机变量&…

Matlab实现蚁群算法求解旅行商优化问题(TSP)(理论+例子+程序)

一、蚁群算法 蚁群算法由意大利学者Dorigo M等根据自然界蚂蚁觅食行为提岀。蚂蚁觅食行为表示大量蚂蚁组成的群体构成一个信息正反馈机制,在同一时间内路径越短蚂蚁分泌的信息就越多,蚂蚁选择该路径的概率就更大。 蚁群算法的思想来源于自然界蚂蚁觅食&a…

Pandas行转列与列装行

实际上,两种操作的核心代码确实非常相似,因为它们都涉及到将 JSON 数据解析并进行拆分。主要的区别在于操作的顺序和处理的对象: 一列转多列: 首先,我们将 JSON 数据列中的每个 JSON 对象解析为 Python 字典&#xff…

物联网智能项目实战:智能温室监控系统

物联网(Internet of Things, IoT)技术正在以前所未有的速度改变着我们的生活方式。通过将传感器、执行器和其他物理设备连接到互联网,物联网技术可以实现远程监测和控制。本文将通过一个具体的物联网智能项目——智能温室监控系统的实现&…

给哔哩哔哩bilibili电脑版做个手机遥控器

前言 bilibili电脑版可以在电脑屏幕上观看bilibili视频。然而,电脑版的bilibili不能通过手机控制视频翻页和调节音量,这意味着观看视频时需要一直坐在电脑旁边。那么,有没有办法制作一个手机遥控器来控制bilibili电脑版呢? 首先…

JavaEE初阶---网络原理之TCP篇(二)

文章目录 1.断开连接--四次挥手1.1 TCP状态1.2四次挥手的过程1.3time_wait等待1.4三次四次的总结 2.前段时间总结3.滑动窗口---传输效率机制3.1原理分析3.2丢包的处理3.3快速重传 4.流量控制---接收方安全机制4.1流量控制思路4.2剩余空间大小4.3探测包的机制 5.拥塞控制---考虑…

【C语言刷力扣】3216.交换后字典序最小的字符串

题目: 解题思路: 字典序最小的字符串:是指按照字母表顺序排列最前的字符串。即字符串在更靠前的位置出现比原字符串对应字符在字母表更早出现的字符。 枚举数组元素,尽早将较小的同奇偶的相邻字符交换。 char* getSmallestString…

定时器(多线程)

标准库中的定时器 • 标准库中提供了⼀个 Timer 类. Timer 类的 核⼼⽅法为 schedule . • schedule 包含两个参数. 第⼀个参数指定即将要执⾏的任务代码, 第⼆个参数指定多⻓时间之后 执⾏ (单位为毫秒). Timer timer new Timer (); timer.schedule( new TimerTas…

Linux(centOS)的安全命令

先全部列出来: 命令及其作用: - setenforce 0:将 SELinux 临时切换为宽松模式(permissive) - setenforce 1:将 SELinux 临时切换为强制模式(enforcing) - selinux的配置文件在/e…

Java:Map和Set练习

目录 查找字母出现的次数 只出现一次的数字 坏键盘打字 查找字母出现的次数 这道题的思路在后面的题目过程中能用到,所以先把这题给写出来 题目要求:给出一个字符串数组,要求输出结果为其中每个字符串及其出现次数。 思路:我…

【宠粉赠书】大模型项目实战:多领域智能应用开发

在当今的人工智能与自然语言处理领域,大型语言模型(LLM)凭借其强大的生成与理解能力,正在广泛应用于多个实际场景中。《大模型项目实战:多领域智能应用开发》为大家提供了全面的应用技巧和案例,帮助开发者深…

【商汤科技-注册/登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

Chromium HTML5 新的 Input 类型week对应c++

一、Input 类型: week week 类型允许你选择周和年。 <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>test</title> </head><body><form action"demo-form.php">选择周: <inp…

Nginx防盗链配置

1. 什么是盗链? 盗链是指服务提供商自己不提供服务的内容&#xff0c;通过技术手段绕过其它有利益的最终用户界面&#xff08;如广告&#xff09;&#xff0c;直接在自己的网站上向最终用户提供其它服务提供商的服务内容&#xff0c;骗取最终用户的浏览和点击率。受益者不提供…

Oracle+11g+笔记(8)-备份与恢复机制

Oracle11g笔记(8)-备份与恢复机制 8、备份与恢复机制 8.1 备份与恢复的方法 数据库的备份是对数据库信息的一种操作系统备份。这些信息可能是数据库的物理结构文件&#xff0c;也可能是某一部分数 据。在数据库正常运行时&#xff0c;就应该考虑到数据库可能出现故障&#…

基于Multisim的篮球比赛电子记分牌设计与仿真

一、设计任务与要求 设计一个符合篮球比赛规则的记分系统。 &#xff08;1&#xff09;有得1分、2分和3分的情况&#xff0c;电路要具有加、减分及显示的功能。 &#xff08;2&#xff09;有倒计时时钟显示&#xff0c;在“暂停时间到”和“比赛时间到”时&#xff0c;发出声光…

CTF-PWN: 什么是_IO_FILE?

重要概念:fopen()返回的是一个结构体的指针 _IO_FILE 结构体在什么时候被创建&#xff1f; _IO_FILE 结构体的实例是在程序使用标准 I/O 函数&#xff08;如 fopen、fclose、fread、fwrite 等&#xff09;时创建和管理的。这个结构体实际上是 GNU C Library (glibc) 用于处理…