transformers - 预测中间词

代码


from transformers import AutoTokenizer#加载编码器
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base', use_fast=True)print(tokenizer)#编码试算
tokenizer.batch_encode_plus(['hide new secretions from the parental units','contains no wit , only labored gags'
])

PreTrainedTokenizerFast(name_or_path='distilroberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 10800, 5069, 117, 22094, 2156, 129, 6348, 3995, 821, 8299, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

加载数据

from datasets import load_dataset, load_from_disk#加载数据
dataset = load_dataset(path='glue', name='sst2')
# dataset = load_from_disk('datas/glue/sst2')#分词,同时删除多余的字段
def f(data):return tokenizer.batch_encode_plus(data['sentence'])dataset = dataset.map(f,batched=True,batch_size=1000,num_proc=4,remove_columns=['sentence', 'idx', 'label'])#过滤掉太短的句子
def f(data):return [len(i) >= 9 for i in data['input_ids']]dataset = dataset.filter(f, batched=True, batch_size=1000, num_proc=4)#截断句子,同时整理成模型需要的格式
def f(data):b = len(data['input_ids'])data['labels'] = data['attention_mask'].copy()for i in range(b):#裁剪长度到9data['input_ids'][i] = data['input_ids'][i][:9]data['attention_mask'][i] = [1] * 9data['labels'][i] = [-100] * 9#input_ids最后一位是2data['input_ids'][i][-1] = 2#每一句话第4个词为mask#tokenizer.get_vocab()['<mask>'] -> 50264data['labels'][i][4] = data['input_ids'][i][4]data['input_ids'][i][4] = 50264return datadataset = dataset.map(f, batched=True, batch_size=1000, num_proc=4)dataset, dataset['train'][0]

import torch
from transformers.data.data_collator import default_data_collator#能够实现随机mask的collate_fn
#如果要使用这个工具类,在数据预处理时就不需要设置数据中的mask,然后让labels=input_ids.copy即可
#from transformers import DataCollatorForLanguageModeling
#data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm_probability=0.1)#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'],batch_size=8,collate_fn=default_data_collator,shuffle=True,drop_last=True,
)for i, data in enumerate(loader):breaklen(loader), data

(5534,{'input_ids': tensor([[    0, 12196,   128,    29, 50264, 10132,    59,  9326,     2],[    0,  1250,     5,  3768, 50264, 34948, 16658,     8,     2],[    0,   627,   936,    16, 50264,   240, 12445,  2129,     2],[    0,  3654,   350, 13185, 50264,    45,   350,  8794,     2],[    0,   560,    28,    56, 50264,  3541, 34261,    19,     2],[    0,   560,   224,    14, 50264,   473,   295,    75,     2],[    0,     6, 14784,  1054, 50264,    10,   686,   865,     2],[    0,  9006,  1495,  2156, 50264, 23317,  4780,     8,     2]]),'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1]]),'labels': tensor([[-100, -100, -100, -100,  144, -100, -100, -100, -100],[-100, -100, -100, -100,   32, -100, -100, -100, -100],[-100, -100, -100, -100,    5, -100, -100, -100, -100],[-100, -100, -100, -100, 2156, -100, -100, -100, -100],[-100, -100, -100, -100,   31, -100, -100, -100, -100],[-100, -100, -100, -100,   24, -100, -100, -100, -100],[-100, -100, -100, -100,   34, -100, -100, -100, -100],[-100, -100, -100, -100,   10, -100, -100, -100, -100]])})

from transformers import AutoModelForCausalLM, RobertaModel#加载模型
#model = AutoModelForCausalLM.from_pretrained('distilroberta-base')#定义下游任务模型
class Model(torch.nn.Module):def __init__(self):super().__init__()self.pretrained = RobertaModel.from_pretrained('distilroberta-base')decoder = torch.nn.Linear(768, tokenizer.vocab_size)decoder.bias = torch.nn.Parameter(torch.zeros(tokenizer.vocab_size))self.fc = torch.nn.Sequential(torch.nn.Linear(768, 768),torch.nn.GELU(),torch.nn.LayerNorm(768, eps=1e-5),decoder,)#加载预训练模型的参数parameters = AutoModelForCausalLM.from_pretrained('distilroberta-base')self.fc[0].load_state_dict(parameters.lm_head.dense.state_dict())self.fc[2].load_state_dict(parameters.lm_head.layer_norm.state_dict())self.fc[3].load_state_dict(parameters.lm_head.decoder.state_dict())self.criterion = torch.nn.CrossEntropyLoss()def forward(self, input_ids, attention_mask, labels=None):logits = self.pretrained(input_ids=input_ids,attention_mask=attention_mask)logits = logits.last_hidden_statelogits = self.fc(logits)loss = Noneif labels is not None:shifted_logits = logits[:, :-1].reshape(-1, tokenizer.vocab_size)shifted_labels = labels[:, 1:].reshape(-1)loss = self.criterion(shifted_logits, shifted_labels)return {'loss': loss, 'logits': logits}model = Model()#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)out = model(**data)out['loss'], out['logits'].shape



测试

#测试
def test():model.eval()loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],batch_size=8,collate_fn=default_data_collator,shuffle=True,drop_last=True,)correct = 0total = 0for i, data in enumerate(loader_test):#保存下数据中的label,后面计算正确率要用label = data['labels'][:, 4].clone()#从数据中抹除掉label,防止模型作弊data['labels'] = None#计算with torch.no_grad():out = model(**data)#[8, 10, 50265] -> [8, 10]out = out['logits'].argmax(dim=2)[:, 4]correct += (label == out).sum().item()total += 8if i % 10 == 0:print(i)print(label)print(out)if i == 50:breakprint(correct / total)for i in range(8):print(tokenizer.decode(data['input_ids'][i]))print(tokenizer.decode(label[i]), tokenizer.decode(out[i]))test()

0
tensor([   47, 14838,  5392,    28,    80,  4839,  3668,    29])
tensor([   47, 14633,   749,    28,    80,  4839,  3668,  2156])
10
tensor([ 101,  668,   16,   14,  352,  650, 3961,   16])
tensor([ 101,  773, 7897,   59, 2156, 7397, 3961,   16])
20
tensor([40485,    13,    29, 19303,    33,    16,   295,     9])
tensor([40485,    13,  4839, 16393,    33,  3391,   256,     9])
30
tensor([   53, 33469,  3315,  3723,     7, 24473, 40776,    41])
tensor([11248, 15923,  3315,  3723,     7, 24473, 40776,    41])
40
tensor([ 2435,     5,  2046,  2084, 25210,     9, 42661,     7])
tensor([ 2343,    42,  4265,  8003, 33709,  7021,  9021,     6])
50
tensor([  297, 22258,   998,    64,    10,  1499,    65,  2156])
tensor([  457, 22258,  6545,    64,    10, 10416,    65, 33647])
0.32598039215686275
<s>a strong first<mask>, slightly less</s>quarter  half
<s>( villene<mask> ) seems to</s>
uve uve
<s>going to the<mask> may be just</s>website  gym

from transformers import AdamW
from transformers.optimization import get_scheduler#训练
def train():optimizer = AdamW(model.parameters(), lr=2e-5)scheduler = get_scheduler(name='linear',num_warmup_steps=0,num_training_steps=len(loader),optimizer=optimizer)model.train()for i, data in enumerate(loader):out = model(**data)loss = out['loss']loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()optimizer.zero_grad()model.zero_grad()if i % 50 == 0:label = data['labels'][:, 4]out = out['logits'].argmax(dim=2)[:, 4]correct = (label == out).sum().item()accuracy = correct / 8lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, loss.item(), accuracy, lr)torch.save(model, 'models/2.预测中间词.model')train()

/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warningFutureWarning,
0 18.949838638305664 0.0 1.9996385977593064e-05
50 4.755198001861572 0.625 1.9815684857246115e-05
100 5.0272216796875 0.25 1.963498373689917e-05
150 4.625316143035889 0.125 1.9454282616552225e-05
200 3.663780927658081 0.5 1.927358149620528e-05
250 2.5342917442321777 0.375 1.909288037585833e-05
300 4.986537933349609 0.375 1.8912179255511386e-05
350 3.403028964996338 0.625 1.873147813516444e-05
400 4.041268348693848 0.125 1.8550777014817495e-05
450 3.2715964317321777 0.5 1.8370075894470547e-05
500 2.6591811180114746 0.5 1.81893747741236e-05
550 4.937175750732422 0.25 1.8008673653776656e-05
600 4.845945835113525 0.25 1.7827972533429708e-05
650 1.8658218383789062 0.625 1.7647271413082763e-05
700 3.9473319053649902 0.25 1.7466570292735818e-05
750 2.065851926803589 0.625 1.728586917238887e-05
800 2.957096576690674 0.5 1.7105168052041924e-05
850 4.987250804901123 0.25 1.692446693169498e-05
900 3.5697021484375 0.5 1.674376581134803e-05
950 2.898092746734619 0.5 1.6563064691001085e-05
1000 4.39031457901001 0.375 1.638236357065414e-05

预测

model = torch.load('models/2.预测中间词.model')
test()

2022-12-08

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/2981.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

@PropertySource的使用

假设我们有一个名为 database.properties 的属性文件&#xff0c;内容如下&#xff0c;该文件位于项目的类路径 (resources 目录) 下&#xff1a; # database.properties db.urljdbc:mysql://localhost:3306/mydb db.usernameroot db.passwordpassword 然后&#xff0c;创建一…

STM32,复位和时钟控制

外部时钟 HSE 以后需要用到什么就这样直接拿去配就行了

【文件上传与包含漏洞综合利用】DVWA-文件上传-难度:High

实验过程和结果 步骤1&#xff1a;尝试直接上传php木马&#xff0c;失败&#xff0c;截图如下&#xff1a; 步骤2&#xff1a;将php木马后缀改为jpeg尝试上传&#xff0c;依旧失败&#xff0c;截图如下&#xff1a; 步骤3&#xff1a;将真实的jpeg图片1.jpeg上传&#xff0c;成…

CNPM、NPM 和 Yarn:JavaScript 包管理器的比较

在现代Web开发中&#xff0c;包管理器是不可或缺的工具&#xff0c;它们帮助开发者管理项目中使用的各种第三方库。在JavaScript世界里&#xff0c;最常见的包管理器有 NPM、Yarn 和 CNPM。本文将详细介绍这三者的不同之处&#xff0c;并用简单的例子来帮助初学者理解每种工具的…

企业微信hook接口协议,ipad协议http,外部联系人图片视频文件下载

外部联系人文件下载 参数名必选类型说明file_id是StringCDNkeyopenim_cdn_authkey是String认证keyaes_key是Stringaes_keysize是int文件大小 请求示例 {"url": "https://imunion.weixin.qq.com/cgi-bin/mmae-bin/tpdownloadmedia?paramv1_e80c6c6c0cxxxx3544d9…

AI作画算法原理详解

人工智能绘画&#xff08;AI绘画&#xff09;算法通常基于深度学习框架&#xff0c;尤其是生成对抗网络&#xff08;GANs&#xff09;。这些算法通过训练大量的艺术作品数据&#xff0c;学会生成新的图像&#xff0c;这些图像在风格和内容上与训练数据相似。 生成对抗网络&…

智慧火电厂合集 | 数字孪生助推能源革命

火电厂在发电领域中扮演着举足轻重的角色。主要通过燃烧如煤、石油或天然气等化石燃料来产生电力。尽管随着可再生能源技术的进步导致其比重有所减少&#xff0c;但直至 2023 年&#xff0c;火电依然是全球主要的电力来源之一。 通过图扑软件自主研发 HT for Web 产品&#xf…

[Algorithm][前缀和][和为K的子数组][和可被K整除的子数组][连续数组][矩阵区域和]详细讲解

目录 1.和为 K 的子数组1.题目链接2.算法原理详解3.代码实现 2.和可被 K 整除的子数组1.题目链接2.算法原理详解3.代码实现 3.连续数组1.题目链接2.算法原理详解3.代码实现 4.矩阵区域和1.题目链接2.算法原理详解3.代码实现 1.和为 K 的子数组 1.题目链接 和为 K 的子数组 2.…

牛客 题解

文章目录 day4_17**BC149** **简写单词**思路&#xff1a;模拟代码&#xff1a; dd爱框框思路&#xff1a;滑动窗口&#xff08;同向双指针&#xff09;代码&#xff1a; 除2&#xff01;思路&#xff1a;模拟贪心堆代码&#xff1a; day4_17 BC149 简写单词 https://www.now…

如何在 Ubuntu 14.04 上配置 StatsD 以收集 Graphite 的任意统计数据

介绍 Graphite 是一个图形库&#xff0c;允许您以灵活和强大的方式可视化不同类型的数据。它通过其他统计收集应用程序发送给它的数据进行图形化。 在之前的指南中&#xff0c;我们讨论了如何安装和配置 Graphite 本身&#xff0c;以及如何安装和配置 collectd 以编译系统和服…

【MATLAB源码-第197期】基于matlab的粒子群算法(PSO)结合人工蜂群算法(ABC)无人机联合卡车配送仿真。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 基于粒子群优化&#xff08;PSO&#xff09;算法的无人机联合卡车配送系统是一个高效的物流配送策略&#xff0c;旨在优化配送过程中的成本、时间和资源利用率。该系统融合了无人机和卡车的配送能力&#xff0c;通过智能算法…

mermaid 之 (Flowchart) 流程图

(Flowchart) 流程图是一种在Mermaid中常用的图形&#xff0c;用于描述一系列步骤和决策。以下是Mermaid中创建流程图的详细语法介绍&#xff1a; 前言 官网文档 基础语法 图的方向 graph TD&#xff1a;从上到下 (Top Down)graph LR&#xff1a;从左到右 (Left to Right)g…

Java23种设计模式-创建型模式之抽象工厂模式

抽象工厂模式(Abstract Factory Pattern)是一种创建型设计模式&#xff0c;它用于创建相关或相互依赖对象的一组&#xff0c;而无需指定其具体的类。这种模式特别适用于产品族的情况&#xff0c;即一组相互关联的产品对象。 存在四种角色&#xff1a; 角色1&#xff1a;抽象工…

Tiny11作者开源:利用微软官方镜像制作独属于你的Tiny11镜像

微软对Windows 11的最低硬件要求包括至少4GB的内存、双核处理器和64GB的SSD存储。然而&#xff0c;这些基本要求仅仅能保证用户启动和运行系统&#xff0c;而非流畅使用 为了提升体验&#xff0c;不少用户选择通过精简系统来减轻硬件负担&#xff0c;我们熟知的Tiny11便是其中…

【简单介绍下机器学习之sklearn基础】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

【机器学习】深度神经网络(DNN):原理、应用与代码实践

深度神经网络&#xff08;DNN&#xff09;&#xff1a;原理、应用与代码实践 一、深度神经网络&#xff08;DNN&#xff09;的基本原理二、DNN的优缺点分析三、DNN的代码实践四、总结与展望 在人工智能与机器学习的浪潮中&#xff0c;深度神经网络&#xff08;Deep Neural Netw…

演示在一台Windows主机上运行两个Mysql服务器(端口号3306 和 3307),安装步骤详解

目录 在一台Windows主机上运行两个Mysql服务器&#xff0c;安装步骤详解因为演示需要两个 MySQL 服务器终端&#xff0c;我只有一个 3306 端口号的 MySQL 服务器&#xff0c;所以需要再创建一个 3307 的。创建一个3307端口号的MySQL服务器1、复制 mysql 的安装目录2、修改my.in…

安全开发实战(4)--whois与子域名爆破

目录 安全开发专栏 前言 whois查询 子域名 子域名爆破 1.4 whois查询 方式1: 方式2: 1.5 子域名查询 方式1:子域名爆破 1.5.1 One 1.5.2 Two 方式2:其他方式 总结 安全开发专栏 安全开发实战​​http://t.csdnimg.cn/25N7H 前言 whois查询 Whois 查询是一种用…

MCU功耗测量

功耗测量 一、相关概念二、功耗的需求三、测量仪器仪表测量连接SMU功能SMU性能指标 四、功耗测量注意点板子部分存在功耗MCU方面&#xff0c;可能存在干扰项仪器仪表方面 一、相关概念 静态功耗和动态功耗&#xff1a;动态功耗为运行功耗&#xff0c;功耗测量注重每MHz下的功耗…

DevOps文化对团队有何影响?

DevOps文化对团队有很多积极影响&#xff0c;包括提高团队效率、促进沟通与协作、提高产品质量和推动创新等方面。然而&#xff0c;实施DevOps文化也需要一定的挑战&#xff0c;如改变团队成员的观念、引入新的工具和流程等。因此&#xff0c;团队需要充分了解DevOps文化的价值…