Contents
try:  # get the class named model_name from the module model.{model_name}, and its config
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    exit()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
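The try block is just a dynamic form of from model.{model_name} import {model_name}. To illustrate the importlib/getattr pattern with something self-contained, here is the same trick on a standard-library module (json is only a stand-in, not part of the project):

import importlib

# equivalent to: from json import dumps
dumps = getattr(importlib.import_module('json'), 'dumps')
print(dumps({'a': 1}))  # {"a": 1}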
class EarlyStopping
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.counter = 0
        self.best_loss = np.inf

    def __call__(self, val_loss):
        """
        If you use other metrics where a higher value is better, e.g. accuracy,
        call this with its corresponding negative value.
        """
        if val_loss < self.best_loss:  # a lower validation loss is a new best
            early_stop = False
            get_better = True
            self.counter = 0
            self.best_loss = val_loss
        else:
            get_better = False
            self.counter += 1
            if self.counter >= self.patience:  # no improvement for `patience` calls in a row
                early_stop = True
            else:
                early_stop = False
        return early_stop, get_better
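A minimal usage sketch of the class above; the list of validation losses is made up for illustration (and numpy is assumed imported as np for the class definition):

early_stopping = EarlyStopping(patience=5)
# simulated validation losses: three improvements, then five stagnant epochs
for val_loss in [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65]:
    early_stop, get_better = early_stopping(val_loss)
    if get_better:
        print(f"new best loss {val_loss}, save a checkpoint here")
    if early_stop:
        print("Early stop.")
        break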
def latest_checkpoint(directory)
A look at the stored checkpoint file names:
def latest_checkpoint(directory):  # find the latest checkpoint in the directory
    if not os.path.exists(directory):  # the directory may not exist yet
        return None
    all_checkpoints = {
        # ends up as e.g. {10000: 'ckpt-10000.pth', 11000: 'ckpt-11000.pth'}
        int(x.split('.')[-2].split('-')[-1]): x
        for x in os.listdir(directory)
    }
    if not all_checkpoints:  # no checkpoint found
        return None
    return os.path.join(directory,  # pick the entry with the largest step key
                        all_checkpoints[max(all_checkpoints.keys())])
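The dict comprehension recovers the step number from each file name; for a name like ckpt-10000.pth the parsing works as follows:

x = 'ckpt-10000.pth'
print(x.split('.')[-2])                      # 'ckpt-10000'
print(int(x.split('.')[-2].split('-')[-1]))  # 10000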
def train()
The log_dir passed to the SummaryWriter, e.g. ./runs/DKN/2021-08-27T09:49:18:
def train():
    writer = SummaryWriter(  # log path, e.g. runs/DKN/2021-08-27T09:49:18
        log_dir=
        f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}"
    )

    if not os.path.exists('checkpoint'):  # create a checkpoint directory under the current directory if missing
        os.makedirs('checkpoint')

    try:
        pretrained_word_embedding = torch.from_numpy(  # load pretrained word embeddings
            np.load('./data/train/pretrained_word_embedding.npy')).float()
    except FileNotFoundError:
        pretrained_word_embedding = None

    if model_name == 'DKN':  # DKN additionally needs entity and context embeddings
        try:
            pretrained_entity_embedding = torch.from_numpy(
                np.load('./data/train/pretrained_entity_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_entity_embedding = None
        try:
            pretrained_context_embedding = torch.from_numpy(  # note: numpy arrays live on the CPU
                np.load('./data/train/pretrained_context_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_context_embedding = None
        model = Model(config, pretrained_word_embedding,  # build the model
                      pretrained_entity_embedding,
                      pretrained_context_embedding)
    # (the branches that build `model` / `models` for other model names are omitted in this excerpt)

    print(torch.cuda.device_count())  # added by the author to enable multi-GPU parallelism
    if torch.cuda.device_count() > 1:  # with more than one GPU, wrap the model in DataParallel
        device_ids = [0, 1]
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    model.to(device)  # outside the if-block, so single-GPU runs also reach the device

    if model_name != 'Exp1':
        print(model)
    else:
        print(models[0])

    dataset = BaseDataset('data/train/behaviors_parsed.tsv',
                          'data/train/news_parsed.tsv',
                          'data/train/roberta')  # the raw training dataset
    print(f"Load training dataset with size {len(dataset)}.")

    dataloader = iter(  # wrap the DataLoader in an iterator so each access is simply next()
        DataLoader(dataset,
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=config.num_workers,
                   drop_last=True,
                   pin_memory=True))

    if model_name != 'Exp1':
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        criterion = nn.NLLLoss()  # negative log likelihood loss
        optimizers = [  # one optimizer per model
            torch.optim.Adam(model.parameters(), lr=config.learning_rate)
            for model in models
        ]

    start_time = time.time()  # training start time
    loss_full = []  # all recorded losses
    exhaustion_count = 0  # number of times the training data has been exhausted
    step = 0
    early_stopping = EarlyStopping()  # see the EarlyStopping class above

    checkpoint_dir = os.path.join('./checkpoint', model_name)  # ./checkpoint/{model_name}
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)  # create the checkpoint directory

    checkpoint_path = latest_checkpoint(checkpoint_dir)  # locate the latest checkpoint
    if checkpoint_path is not None:  # resume from the checkpoint if one exists
        print(f"Load saved parameters in {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)  # the checkpoint is a dict
        early_stopping(checkpoint['early_stop_value'])
        step = checkpoint['step']
        if model_name != 'Exp1':
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            model.train()
        else:
            for model in models:  # load model parameters
                model.load_state_dict(checkpoint['model_state_dict'])
                model.train()
            for optimizer in optimizers:  # load optimizer state
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for i in tqdm(range(
            1, config.num_epochs * len(dataset) // config.batch_size + 1),
                  desc="Training"):  # about num_epochs * len(dataset) / batch_size iterations in total
        try:  # fetch the next minibatch
            minibatch = next(dataloader)
        except StopIteration:  # the dataset has been exhausted; rebuild the dataloader
            exhaustion_count += 1
            tqdm.write(
                f"Training data exhausted for {exhaustion_count} times after {i} batches, reuse the dataset."
            )
            dataloader = iter(
                DataLoader(dataset,
                           batch_size=config.batch_size,
                           shuffle=True,
                           num_workers=config.num_workers,
                           drop_last=True,
                           pin_memory=True))
            minibatch = next(dataloader)

        step += 1
        y_pred = model(minibatch["candidate_news"],  # scores for the candidate news are predicted by the model
                       minibatch["clicked_news"])
        y = torch.zeros(len(y_pred)).long().to(device)
        loss = criterion(y_pred, y)
        loss_full.append(loss.item())  # record the loss

        if model_name != 'Exp1':
            optimizer.zero_grad()
        else:
            for optimizer in optimizers:  # reset gradients for every optimizer
                optimizer.zero_grad()
        loss.backward()
        if model_name != 'Exp1':
            optimizer.step()
        else:
            for optimizer in optimizers:
                optimizer.step()

        if i % 10 == 0:  # log the training loss every 10 batches
            writer.add_scalar('Train/Loss', loss.item(), step)

        if i % config.num_batches_show_loss == 0:  # print progress
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, current loss {loss.item():.4f}, average loss: {np.mean(loss_full):.4f}, latest average loss: {np.mean(loss_full[-256:]):.4f}"
            )

        if i % config.num_batches_validate == 0:
            (model if model_name != 'Exp1' else models[0]).eval()
            val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(
                model if model_name != 'Exp1' else models[0], './data/val',
                200000)
            (model if model_name != 'Exp1' else models[0]).train()
            writer.add_scalar('Validation/AUC', val_auc, step)
            writer.add_scalar('Validation/MRR', val_mrr, step)
            writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)
            writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, validation AUC: {val_auc:.4f}, validation MRR: {val_mrr:.4f}, validation nDCG@5: {val_ndcg5:.4f}, validation nDCG@10: {val_ndcg10:.4f}, "
            )

            # save the parameters whenever the validation result improves
            early_stop, get_better = early_stopping(-val_auc)
            if early_stop:
                tqdm.write('Early stop.')
                break
            elif get_better:
                try:
                    torch.save(
                        {
                            'model_state_dict': (model if model_name != 'Exp1'
                                                 else models[0]).state_dict(),
                            'optimizer_state_dict':
                            (optimizer if model_name != 'Exp1' else
                             optimizers[0]).state_dict(),
                            'step': step,
                            'early_stop_value': -val_auc
                        }, f"./checkpoint/{model_name}/ckpt-{step}.pth")
                except OSError as error:
                    print(f"OS error: {error}")
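A note on the all-zero target y: CrossEntropyLoss expects a class index per sample, and here every sample's correct class is 0, presumably because the dataset places the clicked (positive) news at index 0 of candidate_news with the sampled negatives after it. A self-contained sketch of that loss computation (the shapes are illustrative, not taken from the project config):

import torch
import torch.nn as nn

y_pred = torch.randn(4, 5)           # 4 samples, each scoring 1 positive + 4 negative candidates
y = torch.zeros(len(y_pred)).long()  # the positive candidate is always class 0
loss = nn.CrossEntropyLoss()(y_pred, y)
print(loss.item())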
def time_since(since)
def time_since(since):  # how long training has been running
    """Format elapsed time string."""
    now = time.time()
    elapsed_time = now - since
    return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))


if __name__ == '__main__':
    # print('Using device:', device)
    print(f'Training model {model_name}')
    train()
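time.gmtime interprets its argument as seconds since the epoch, so passing an elapsed duration yields an HH:MM:SS string directly (valid for runs shorter than 24 hours):

import time

print(time.strftime("%H:%M:%S", time.gmtime(3723)))  # 01:02:03, i.e. 1h 2m 3s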
Supplement
1. The os.listdir() method
Overview
The os.listdir() method returns a list of the names of the files and folders contained in the given directory (i.e. all entry names under that directory).
The special entries '.' and '..' are not included, even though they exist in the directory.
It is supported on Unix and Windows.
Syntax
The syntax of listdir() is:
os.listdir(path)
Parameters
path – the path of the directory to list.
Return value
Returns the list of files and folders under the given path.
Example
#!/usr/bin/python
# -*- coding: UTF-8 -*-

import os, sys

# list the directory
path = "/var/www/html/"
dirs = os.listdir(path)

# print all files and folders
for file in dirs:
    print(file)
2. The Python replace() method
Description
The Python replace() method replaces occurrences of old (the old substring) in a string with new (the new substring); if the third argument max is given, at most max replacements are made.
Syntax
Syntax of replace():
str.replace(old, new[, max])
Parameters
- old – the substring to be replaced.
- new – the new substring that replaces old.
- max – optional; replace at most max times.
Return value
Returns a new string in which old has been replaced by new; if max is given, at most max replacements are performed.
Example
str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string
3. datetime test
print(datetime.datetime.now()) #2021-08-27 09:47:48.748545
print(datetime.datetime.now().replace(microsecond=0)) #2021-08-27 09:48:26
print(datetime.datetime.now().replace(microsecond=0).isoformat()) #2021-08-27T09:49:18
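Putting these pieces together, this is how the SummaryWriter's log_dir in train() is composed (REMARK is the optional environment variable checked there; DKN stands in for model_name):

import datetime
import os

timestamp = datetime.datetime.now().replace(microsecond=0).isoformat()
log_dir = f"./runs/DKN/{timestamp}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}"
print(log_dir)  # e.g. ./runs/DKN/2021-08-27T09:49:18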
4. NLLLoss and CrossEntropyLoss
https://blog.csdn.net/qq_22210253/article/details/85229988
NLLLoss stands for Negative Log Likelihood Loss; minimizing it is equivalent to maximum likelihood estimation.
Note that NLLLoss and CrossEntropyLoss are both used for single-label classification (e.g. assigning one class per image), whereas BCELoss and BCEWithLogitsLoss are used for multi-label classification, where one sample can correspond to several labels.
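The practical relationship between the two losses: CrossEntropyLoss is exactly LogSoftmax followed by NLLLoss, which a short check confirms:

import torch
import torch.nn as nn

logits = torch.randn(3, 5)                 # 3 samples, 5 classes
target = torch.tensor([1, 0, 4])

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(torch.allclose(ce, nll))             # True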