生成任务,大模型

一个生成项目

输入:文字描述(但是给的数据集是一串数字,id,ct描述,医生描述)
输出:诊断报告

一、数据处理

import pandas as pd  #处理表格数据pre_train_file= "data/train.csv"train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据print(train_df.head())train_data = train_df.sample(frac=0.9, random_state=0, axis=0)   #采样0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]       #干啥的,  过来用train_data.to_csv("data/pro_train_data.csv", index=False,header=False)val_data.to_csv("data/pro_val_data.csv", index=False,header=False)

主要是用于从一个CSV文件中读取数据,并将其划分为训练集和验证集,然后将这两个数据集分别保存到新的CSV文件中。

代码逐行解释

导入必要的库
import pandas as pd  # 处理表格数据
  • pandas:一个强大的数据分析和处理库,特别适合处理表格数据(如CSV文件)。
定义文件路径并读取数据
pre_train_file = "data/train.csv"train_df = pd.read_csv(pre_train_file, header=None, names=["id", "input", "tgt"])  # 读入数据print(train_df.head())
  • pre_train_file:指定要读取的CSV文件路径。
  • pd.read_csv
    • header=None:表示CSV文件没有表头(第一行不是列名)。
    • names=["id", "input", "tgt"]:为每一列指定名称。
  • print(train_df.head()):打印前五行数据,以便检查读取是否正确。
数据划分
train_data = train_df.sample(frac=0.9, random_state=0, axis=0)  # 采样0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]  # 干啥的, 过来用
  • train_data

    • 使用 sample 方法随机采样90%的数据作为训练集。
    • frac=0.9:表示采样的比例为90%。
    • random_state=0:设置随机种子以确保结果可重复。
    • axis=0:表示沿行方向进行采样(默认行为)。
  • val_data

    • 使用 ~train_df.index.isin(train_data.index) 来获取不在训练集中的数据作为验证集。
    • isin(train_data.index) 返回一个布尔数组,指示哪些索引在训练集中。
    • ~ 取反操作符,返回不在训练集中的索引。
保存数据
train_data.to_csv("data/pro_train_data.csv", index=False, header=False)val_data.to_csv("data/pro_val_data.csv", index=False, header=False)
  • to_csv 方法
    • 将DataFrame保存为CSV文件。
    • index=False:不保存行索引。
    • header=False:不保存列名。

二、处理词表

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_argsargs = parse_args()         #设置 ,字典, 属性类  config  {}def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 训练集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datastrain_data = load_data('./data/train.csv')token2count = Counter()     #计数工具 哈希表for i in train_data:token2count.update(i)       #不需要知道原理tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
vocab.insert(102,"[SEP]")
vocab.insert(103,"[MASK]")
vocab.insert(104,"[EOS]")
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #获取模型词表# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #mn复杂度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))
new_vocabs = vocab
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")    #保存model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)      #模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)

1. 导入必要的库

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
  • sys:用于系统相关的操作(如命令行参数)。
  • torch:PyTorch的核心库,用于深度学习模型。
  • Counter:来自 collections 模块,用于统计元素出现的次数。
  • BertTokenizer, BartConfig, BartForConditionalGeneration:来自 transformers 库,分别用于分词、配置和加载预训练模型。
  • parse_args:自定义函数,用于解析命令行参数或配置文件,返回一个包含配置参数的对象。

2. 解析参数

args = parse_args()  # 设置,字典,属性类 config {}
  • parse_args:调用自定义函数解析配置参数,并将其存储在 args 对象中。假设 args 包含诸如 pre_model_path 等路径信息。

3. 定义数据加载函数

def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 训练集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datas
  • load_data 函数
    • 打开指定路径的文件并读取每一行。
    • 使用 strip() 去除每行的前后空白字符,并使用 split(",") 将其按逗号分割为列表。
    • 如果列表长度为3(假设是训练集),则将第二列和第三列的数据拆分为单词列表,并合并后添加到 datas 列表中。
    • 如果列表长度不为3,则仅处理第二列的数据,并将其拆分为单词列表后添加到 datas 列表中。
    • 返回 datas 列表。

4. 加载数据

train_data = load_data('./data/train.csv')
  • 调用 load_data 函数加载训练数据,并将结果存储在 train_data 变量中。

5. 统计词频

token2count = Counter()  # 计数工具 哈希表for i in train_data:token2count.update(i)  # 不需要知道原理
  • token2count:使用 Counter 类创建一个哈希表来统计每个单词出现的次数。
  • 遍历 train_data 中的每一行数据,并使用 update 方法更新 token2count,记录每个单词出现的次数。

6. 创建词汇表

tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0, "[PAD]")
vocab.insert(100, "[UNK]")
vocab.insert(101, "[CLS]")
vocab.insert(102, "[SEP]")
vocab.insert(103, "[MASK]")
vocab.insert(104, "[EOS]")
  • tail:筛选出频率大于等于 ct 的单词,并按字母顺序排序。注意这里 ct 设为0,因此所有单词都会被包含进来。
  • vocab:将 tail 赋值给 vocab
  • 插入特殊标记:在 vocab 中插入一些特殊的标记符号(如 [PAD], [UNK], [CLS], [SEP], [MASK], [EOS]),这些标记在自然语言处理任务中具有特定含义。

7. 保存词汇表

new_vocabs = vocab
with open(args.pre_model_path + '/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")  # 保存
  • new_vocabs:直接赋值为 vocab
  • 保存词汇表:将词汇表中的每个单词写入 vocab.txt 文件中,文件路径由 args.pre_model_path 指定。

8. 加载预训练模型并调整词汇表大小

model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)  # 模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path + '/pytorch_model.bin')bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
  • 加载预训练模型:使用 BartForConditionalGeneration.from_pretrained 加载预训练模型。
  • 调整词汇表大小:使用 resize_token_embeddings 方法调整模型的嵌入层大小以适应新的词汇表。
  • 保存模型状态:将模型的状态字典保存到 pytorch_model.bin 文件中,文件路径由 args.pre_model_path 指定。
  • 更新配置:更新 BartConfig 中的 vocab_size 属性,并保存配置。

三、自监督预训练

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        #日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
# os.environ['CUDA_VISIBLE_DEVICES']='0'def train_and_validate(args):# 1. load data  modelmodel = preModel(args)     #加载预训练模型optimizer, scheduler = build_optimizer(args, model)# model = model.to(args.device)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)# model = BalancedDataParallel(16, model, dim=0).to(args.device)# model = model.to(args.device)#-------ema here-----------------all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True,collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    #开始训练了for batch in train_dataloader:model.train()loss= model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')def main():args = parse_args()           #设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         #LINUXtrain_and_validate(args)if __name__ == '__main__':main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、日志记录以及模型保存等功能

1. 导入必要的库

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        # 日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
  • PreTrainDataset, loadData, MLM_Data:自定义模块,用于数据处理。
  • DataLoader, Dataset:PyTorch提供的类,用于数据加载和管理。
  • preModel:自定义模型类。
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • setup_device, setup_seed, setup_logging, build_optimizer:自定义工具函数,分别用于设置设备、随机种子、日志记录和优化器构建。
  • torch:PyTorch核心库。
  • time:用于时间相关操作。

2. 定义训练和验证函数

def train_and_validate(args):# 1. 加载数据和模型model = preModel(args)     # 加载预训练模型optimizer, scheduler = build_optimizer(args, model)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    # 开始训练了for batch in train_dataloader:model.train()loss = model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')
解释
  • 加载数据和模型

    • 使用 preModel 类加载预训练模型。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 根据 args.deviceargs.paral 参数决定是否使用多GPU并行训练。
  • 数据加载

    • 使用 loadData 函数加载所有数据。
    • 使用 MLM_Data 类将数据转换为适合训练的数据集格式。
    • 使用 DataLoader 创建数据加载器,支持批量加载和数据打乱。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 记录训练进度和剩余时间,并在特定步数时打印日志。
    • 每隔5个epoch保存一次模型。

3. 主函数

def main():args = parse_args()           # 设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
  • main 函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

四、微调

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES']='0'# 不需要完全理解,  知道每一块在做什么就行   知道之后,  以后再用到, 搬过去就行def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n>0 and tot>n:breaksource = source.cuda()pred = model(source[:, :args. input_l])pred = pred.cpu().detach().numpy()#print(pred.shape)for i in range(pred.shape[0]):# res.append({'image_id':tot, 'caption': [array2str(pred[i][2:], args)]})# gts[tot] = [array2str(targets[i][1:], args)]res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_scoredef train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)#-------ema here-----------------model.train()#-------------------------------# loss, results = validate(model, val_dataloader)# 3. trainingstep = 0best_score = args.best_score     #评估指标  准确率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args. input_l], targets[:, :args.output_l])loss  = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')def main():args = parse_args()setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)train_and_validate(args)if __name__ == '__main__':main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、验证评估以及模型保存等功能。

1. 导入必要的库

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer, array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES'] = '0'
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • time:用于时间相关操作。
  • torch:PyTorch核心库。
  • PretrainedBartModel:来自 transformers 库的预训练模型基类。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • create_dataloaders:自定义函数,创建数据加载器。
  • myModel:自定义模型类。
  • CiderD, CE:自定义评分函数,分别用于计算CIDEr-D分数和交叉熵损失。
  • setup_device, setup_seed, setup_logging, build_optimizer, array2str:自定义工具函数,分别用于设置设备、随机种子、日志记录、构建优化器和数组转字符串。
  • autocast:用于混合精度训练。
  • tqdm:用于显示进度条。

2. 定义验证函数

def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n > 0 and tot > n:breaksource = source.cuda()pred = model(source[:, :args.input_l])pred = pred.cpu().detach().numpy()for i in range(pred.shape[0]):res.append({'image_id': tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_score
解释
  • 输入参数

    • model: 需要验证的模型。
    • loader: 数据加载器。
    • args: 命令行参数或配置对象。
    • output_file: 输出文件路径(可选)。
    • beam: 束搜索宽度(可选,默认为1)。
    • n: 验证样本数限制(可选,默认为-1,表示不限制)。
  • 逻辑

    • 初始化结果列表 res 和真实标签字典 gts
    • 使用 tqdm 显示进度条遍历数据加载器中的每个批次 (source, targets)
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果和真实标签转换为字符串格式并添加到 resgts 中。
    • 使用 CiderD 计算预测结果与真实标签之间的 CIDEr-D 分数。
    • 返回 CIDEr-D 分数。

3. 定义训练和验证函数

def train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)model.train()step = 0best_score = args.best_score  # 评估指标 准确率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args.input_l], targets[:, :args.output_l])loss = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')
解释
  • 加载数据

    • 使用 create_dataloaders 函数加载训练和验证数据加载器。
  • 初始化模型和优化器

    • 使用 myModel 类加载模型。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 将模型移动到指定设备(CPU或GPU)。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 每个epoch结束后调用 validate 函数计算验证集上的 CIDEr-D 分数。
    • 如果当前 CIDEr-D 分数优于历史最佳分数,则保存模型。

4. 主函数

def main():args = parse_args()  # 设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)  # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
解释
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

五、inference

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_argsdef inference(args):test_loader = create_dataloaders(args,test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'],strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()if __name__ == '__main__':args = parse_args()inference(args)

实现了一个推理(inference)流程,包括数据加载、模型加载、前向传播以及结果保存等功能。

1. 导入必要的库

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args
  • tqdm:用于显示进度条。
  • csv:用于处理CSV文件的读写操作。
  • to_device:自定义函数,将数据移动到指定设备(CPU或GPU)。
  • array2str:自定义函数,将数组转换为字符串。
  • myModel:自定义模型类。
  • create_dataloaders:自定义函数,创建数据加载器。
  • torch:PyTorch核心库。
  • parse_args:自定义函数,解析命令行参数或配置文件。

2. 定义推理函数

def inference(args):test_loader = create_dataloaders(args, test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'], strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()
解释
  • 加载测试数据

    • 使用 create_dataloaders 函数加载测试数据加载器,设置 test=True 表示加载测试集。
  • 初始化模型并加载权重

    • 使用 myModel 类加载模型。
    • 打印预训练模型路径 args.ckpt_file
    • 使用 torch.load 加载预训练模型的权重,并使用 load_state_dict 方法加载到模型中。
    • 将模型移动到 GPU(cuda:0),并设置为评估模式(model.eval())。
  • 推理过程

    • 打开输出 CSV 文件,并创建 CSV 写入器。
    • 使用 tqdm 显示进度条遍历测试数据加载器中的每个批次 source
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果转换为 NumPy 数组,并逐个样本写入 CSV 文件。

3. 主函数

if __name__ == '__main__':args = parse_args()inference(args)
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 inference 函数开始推理过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/71738.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot API 项目中 HAProxy 与 Nginx 的选择与实践

在开发 Spring Boot 构建的 RESTful API 项目时,负载均衡和反向代理是提升性能与可用性的关键环节。HAProxy 和 Nginx 作为两种流行的工具,经常被用于流量分发,但它们各有侧重。究竟哪一个更适合你的 Spring Boot API 项目?本文将…

Java常用集合与映射的线程安全问题深度解析

Java常用集合与映射的线程安全问题深度解析 一、线程安全基础认知 在并发编程环境下,当多个线程同时操作同一集合对象时,若未采取同步措施,可能导致以下典型问题: 数据竞争:多个线程同时修改数据导致结果不可预测状…

DeepLabv3+改进6:在主干网络中添加SegNext_Attention|助力涨点

🔥【DeepLabv3+改进专栏!探索语义分割新高度】 🌟 你是否在为图像分割的精度与效率发愁? 📢 本专栏重磅推出: ✅ 独家改进策略:融合注意力机制、轻量化设计与多尺度优化 ✅ 即插即用模块:ASPP+升级、解码器 PS:订阅专栏提供完整代码 目录 论文简介 步骤一 步骤二…

使用 Elastic-Agent 或 Beats 将 Journald 中的 syslog 和 auth 日志导入 Elastic Stack

作者:来自 Elastic TiagoQueiroz 我们在 Elastic 一直努力将更多 Linux 发行版添加到我们的支持矩阵中,现在 Elastic-Agent 和 Beats 已正式支持 Debian 12! 本文演示了我们正在开发的功能,以支持使用 Journald 存储系统和身份验…

3.9[A]csd

在传统CPU中心架构中,中央处理器通过内存访问外部存储器,而数据必须经过网络接口卡才能到达外部存储器。这种架构存在集中式计算、DRAM带宽和容量挑战、大量数据移动(服务器内和网络)以及固定计算导致工作负载容量增长等问题。 而…

ESP32S3读取数字麦克风INMP441的音频数据

ESP32S3 与 INMP441 麦克风模块的集成通常涉及使用 I2S 接口进行数字音频数据的传输。INMP441 是一款高性能的数字麦克风,它通过 I2S 接口输出音频数据。在 Arduino 环境中,ESP32S3 的开发通常使用 ESP-IDF(Espressif IoT Development Framew…

DeepSeek大模型 —— 全维度技术解析

DeepSeek大模型 —— 全维度技术解析 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,可以分享一下给大家。点击跳转到网站。 https://www.captainbed.cn/ccc 文章目录 DeepSeek大模型 —— 全维度技术解析一、模型架构全景解析1…

[Kubernetes] 7控制平面组件

1. 调度 kube- scheduler what 负责分配调度pod到集群节点监听kube-apiserver,查询未分配node的pod根据调度策略分配这些pod(更新pod的nodename)需要考虑的因素: 公平调度,资源有效利用,QoS,affinity, an…

PyTorch系列教程:编写高效模型训练流程

当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后…

LeetCode Hot100刷题——反转链表(迭代+递归)

206.反转链表 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1]示例 2: 输入:head [1,2] 输出:[2,1]示例 3&#…

机器学习的发展史

机器学习(Machine Learning, ML)作为人工智能(AI)的一个分支,其发展经历了多个阶段。以下是机器学习的发展史概述: 1. 早期探索(20世纪50年代 - 70年代) 1950年:艾伦图…

Springboot redis bitMap实现用户签到以及统计,保姆级教程

项目架构,这是作为demo展示使用: Redis config: package com.zy.config;import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.Ob…

Ardupilot开源无人机之Geek SDK进展2025Q1

Ardupilot开源无人机之Geek SDK进展2025Q1 1. 源由2. 内容汇总2.1 【jetson-fpv】YOLO INT8 coco8 dataset 精度降级2.2 【OpenIPC-Configurator】OpenIPC Configurator 固件升级失败2.3 【OpenIPC-Adaptive-link】OpenIPC RF信号质量相关显示2.4 【OpenIPC-msposd】.srt/.osd…

《云原生监控体系构建实录:从Prometheus到Grafana的观测革命》

PrometheusGrafana部署配置 Prometheus安装 下载Prometheus服务端 Download | PrometheusAn open-source monitoring system with a dimensional data model, flexible query language, efficient time series database and modern alerting approach.https://prometheus.io/…

SpringMvc与Struts2

一、Spring MVC 1.1 概述 Spring MVC 是 Spring 框架的一部分,是一个基于 MVC 设计模式的轻量级 Web 框架。它提供了灵活的配置和强大的扩展能力,适合构建复杂的 Web 应用程序。 1.2 特点 轻量级:与 Spring 框架无缝集成,依赖…

数据类设计_图片类设计之1_矩阵类设计(前端架构基础)

前言 学的东西多了,要想办法用出来.C和C是偏向底层的语言,直接与数据打交道.尝试做一些和数据方面相关的内容 引入 图形在底层是怎么表示的,用C来表示 认识图片 图片是个风景,动物,还是其他内容,人是可以看出来的.那么计算机是怎么看懂的呢?在有自主意识的人工智能被设计出来…

开发者社区测试报告(功能测试+性能测试)

功能测试 测试相关用例 开发者社区功能背景 在当今数字化时代,编程已经成为一项核心技能,越来越多的人开始学习编程,以适应快速变化的科技 环境。基于这一需求,我设计开发了一个类似博客的论坛系统,专注于方便程序员…

EasyRTC嵌入式音视频通话SDK:基于ICE与STUN/TURN的实时音视频通信解决方案

在当今数字化时代,实时音视频通信技术已成为人们生活和工作中不可或缺的一部分。无论是家庭中的远程看护、办公场景中的远程协作,还是工业领域的远程巡检和智能设备的互联互通,高效、稳定的通信技术都是实现这些功能的核心。 EasyRTC嵌入式音…

【OneAPI】网页截图API-V2

API简介 生成指定URL的网页截图或缩略图。 旧版本请参考:网页截图 V2版本新增全屏截图、带壳截图等功能,并修复了一些已知问题。 全屏截图: 支持全屏截图,通过设置fullscreentrue来支持全屏截图。全屏模式下,系统…

简单的 Python 示例,用于生成电影解说视频的第一人称独白解说文案

以下是一个简单的 Python 示例,用于生成电影解说视频的第一人称独白解说文案。这个示例使用了 OpenAI 的 GPT 模型,因为它在自然语言生成方面表现出色。 实现思路 安装必要的库:使用 openai 库与 OpenAI API 进行交互。设置 API 密钥&#…