【DKN】(四)train.py

内容

try:  #不用多言, 获得该模块下的model_name函数Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:print(f"{model_name} not included!")exit()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EarlyStopping

class EarlyStopping:def __init__(self, patience=5):self.patience = patience   self.counter = 0self.best_loss = np.Infdef __call__(self, val_loss):"""if you use other metrics where a higher value is better, e.g. accuracy,call this with its corresponding negative value"""# 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它if val_loss < self.best_loss:   #如果评测的损失小于最好的损失,那么就是最好的损失early_stop = Falseget_better = Trueself.counter = 0self.best_loss = val_loss  # 最好的损失 else:get_better = False         #  self.counter += 1if self.counter >= self.patience:early_stop = Trueelse:early_stop = Falsereturn early_stop, get_better  

def latest_checkpoint(directory):

看一看存储的模型路径名称:
在这里插入图片描述

def latest_checkpoint(directory):   #最新的检查点! if not os.path.exists(directory):  #该路径在不在return Noneall_checkpoints = {   #{10000 : ckpt-10000.pth, 11000: ckpt-11000.pth}  这就是最终的结果int(x.split('.')[-2].split('-')[-1]): xfor x in os.listdir(directory)}if not all_checkpoints:   #如果没有checkpoint,就返回空return Nonereturn os.path.join(directory,   #我们选择keys最大的选择all_checkpoints[max(all_checkpoints.keys())])

def train()

log_dir:
在这里插入图片描述

def train():writer = SummaryWriter(  #这里的路径!  runs/DKN/.....log_dir=f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}")if not os.path.exists('checkpoint'):  #如果没有checkpoint,那么就需要在当前目录下创建checkpointos.makedirs('checkpoint')try:pretrained_word_embedding = torch.from_numpy(  #读入预训练单词嵌入np.load('./data/train/pretrained_word_embedding.npy')).float()except FileNotFoundError:pretrained_word_embedding = Noneif model_name == 'DKN':   #如果是DKN模型try:pretrained_entity_embedding = torch.from_numpy(   #如果是DKN,嵌入实体np.load('./data/train/pretrained_entity_embedding.npy')).float()except FileNotFoundError:pretrained_entity_embedding = Nonetry:pretrained_context_embedding = torch.from_numpy(  #预训练上下文嵌入  但是numpy是在CPU上的! np.load('./data/train/pretrained_context_embedding.npy')).float()except FileNotFoundError:pretrained_context_embedding = Nonemodel = Model(config, pretrained_word_embedding,   #创建模型pretrained_entity_embedding,pretrained_context_embedding)print(torch.cuda.device_count())   #这里是自己加的,想要实现并行操作! if torch.cuda.device_count() > 1:   #如果设备数目大于1,那么就并行操作# model.to(device)device_ids = [0, 1]model = torch.nn.DataParallel(model, device_ids=device_ids)model.to(device)# for param in next(model.parameters()):#     print(param, param.device)# print(next(model.parameters()).device)if model_name != 'Exp1':print(model)else:print(models[0])dataset = BaseDataset('data/train/behaviors_parsed.tsv','data/train/news_parsed.tsv', 'data/train/roberta')#获得原数据集print(f"Load training dataset with size {len(dataset)}.")dataloader = iter(   #改成dataloader,并被迭代器包装,使得每次访问只需要next()即可  DataLoader(dataset,   #由于自己原来接触过dataloader所以这里是懂点的,不再解释batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))if model_name != 'Exp1':criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(),lr=config.learning_rate)else:criterion = nn.NLLLoss()  #最大似然函数optimizers = [             #定义优化器torch.optim.Adam(model.parameters(), lr=config.learning_rate)for model in models]start_time = time.time()    #定义开始的时间loss_full = []       #全部损失exhaustion_count = 0  #竭尽全力_count???step = 0   early_stopping = EarlyStopping()  #早点结束,看上面的函数定义checkpoint_dir = os.path.join('./checkpoint', model_name)  #检查点/model_namePath(checkpoint_dir).mkdir(parents=True, exist_ok=True)    #创建checkpoint目录checkpoint_path = latest_checkpoint(checkpoint_dir)  #获得最新的检查点if checkpoint_path is not None:          #开始带入checkpointprint(f"Load saved parameters in {checkpoint_path}")checkpoint = torch.load(checkpoint_path)   #加载检查点,里面的格式是字典类型的early_stopping(checkpoint['early_stop_value'])   #step = checkpoint['step']     #if model_name != 'Exp1':model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])model.train()else:for model in models:   model.load_state_dict(checkpoint['model_state_dict'])  #直接加载模型参数model.train()  for optimizer in optimizers:   #直接加载优化器参数optimizer.load_state_dict(checkpoint['optimizer_state_dict'])for i in tqdm(range(    # epochs * (len(dataset) // config.batch_size + 1)这么多次迭代1,config.num_epochs * len(dataset) // config.batch_size + 1),desc="Training"):try:   #获取小dataloader中的batchminibatch = next(dataloader)# if torch.cuda.device_count() > 1:#     minibatch = torch.nn.DataParallel(minibatch)#     minibatch.to(device)# minibatch.to(device)except StopIteration:  #如果迭代出问题了exhaustion_count += 1tqdm.write(f"Training data exhausted for {exhaustion_count} times after {i} batches, reuse the dataset.")dataloader = iter(DataLoader(dataset,batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))minibatch = next(dataloader)step += 1y_pred = model(minibatch["candidate_news"],  #结算损失, 候选新闻是预测得到的!minibatch["clicked_news"])y = torch.zeros(len(y_pred)).long().to(device)loss = criterion(y_pred, y)loss_full.append(loss.item())  #要保存损失的if model_name != 'Exp1':optimizer.zero_grad()else:for optimizer in optimizers:  #优化器更新权重optimizer.zero_grad()loss.backward()if model_name != 'Exp1':optimizer.step()else:for optimizer in optimizers:optimizer.step()if i % 10 == 0:   #如果10次计算了,那么就写入我们的损失writer.add_scalar('Train/Loss', loss.item(), step)if i % config.num_batches_show_loss == 0:  #写出结果tqdm.write(f"Time {time_since(start_time)}, batches {i}, current loss {loss.item():.4f}, average loss: {np.mean(loss_full):.4f}, latest average loss: {np.mean(loss_full[-256:]):.4f}")if i % config.num_batches_validate == 0:   #(model if model_name != 'Exp1' else models[0]).eval()val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(model if model_name != 'Exp1' else models[0], './data/val',200000)(model if model_name != 'Exp1' else models[0]).train()writer.add_scalar('Validation/AUC', val_auc, step)writer.add_scalar('Validation/MRR', val_mrr, step)writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)tqdm.write(f"Time {time_since(start_time)}, batches {i}, validation AUC: {val_auc:.4f}, validation MRR: {val_mrr:.4f}, validation nDCG@5: {val_ndcg5:.4f}, validation nDCG@10: {val_ndcg10:.4f}, ")#后面的都是如果是最好的效果,就保存模型参数early_stop, get_better = early_stopping(-val_auc)if early_stop:tqdm.write('Early stop.')breakelif get_better:try:torch.save({'model_state_dict': (model if model_name != 'Exp1'else models[0]).state_dict(),'optimizer_state_dict':(optimizer if model_name != 'Exp1' elseoptimizers[0]).state_dict(),'step':step,'early_stop_value':-val_auc}, f"./checkpoint/{model_name}/ckpt-{step}.pth")except OSError as error:print(f"OS error: {error}")

def time_since(since)

def time_since(since):   #运行了多长时间"""Format elapsed time string."""now = time.time()elapsed_time = now - since  #return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))if __name__ == '__main__':# print('Using device:', device)print(f'Training model {model_name}')train()

补充

1. os.listdir() 方法

概述

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)

它不包括 . 和 … 即使它在文件夹中。

只支持在 Unix, Windows 下使用。

语法

listdir()方法语法格式如下:

os.listdir(path)

参数

path – 需要列出的目录路径

返回值

返回指定路径下的文件和文件夹列表。

实例

#!/usr/bin/python
# -*- coding: UTF-8 -*-import os, sys# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )# 输出所有文件和文件夹
for file in dirs:print (file)

在这里插入图片描述

2. Python replace()方法

描述

Python replace() 方法把字符串中的 old(旧字符串) 替换成 new(新字符串),如果指定第三个参数max,则替换不超过 max 次。

语法

replace()方法语法:

str.replace(old, new[, max])

参数

  • old – 将被替换的子字符串。
  • new – 新字符串,用于替换old子字符串。
  • max – 可选字符串, 替换不超过 max 次

返回值

返回字符串中的 old(旧字符串) 替换成 new(新字符串)后生成的新字符串,如果指定第三个参数max,则替换不超过 max 次。

实例

str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string

3. datetime测试

print(datetime.datetime.now())   #2021-08-27 09:47:48.748545
print(datetime.datetime.now().replace(microsecond=0))  #2021-08-27 09:48:26
print(datetime.datetime.now().replace(microsecond=0).isoformat())  #2021-08-27T09:49:18

4. NLLLoss 和 CrossEntropyLoss

https://blog.csdn.net/qq_22210253/article/details/85229988

NLLLoss的全称是Negative Log Likelihood Loss,也就是最大似然函数

在图片进行单标签分类时,【注意NLLLoss和CrossEntropyLoss都是用于单标签分类,而BCELoss和BECWithLogitsLoss都是使用与多标签分类。这里的多标签是指一个样本对应多个label.】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/476176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用promise封装ajax_ES6-promise封装AJAX请求

【摘要】ES6-promise封装AJAX请求考必过小编为大家整理了关于ES6-promise封装AJAX请求的信息&#xff0c;希望可以帮助到大家&#xff01;ES6-promise封装AJAX请求标签&#xff1a;const状态码setreject对象响应状态ISErequest// 接口地址:https://api.apiopen.top/getJoke// 1…

REST和SOAP Web Service的比较(写得非常清晰易懂,转载于此)

本文转载自他人的博客&#xff0c;ArcGIS Server 推出了 对 SOAP 和 REST两种接口&#xff08;用接口类型也许并不准确&#xff09;类型的支持,本文非常清晰的比较了SOAP和Rest的区别联系&#xff01;REST似乎在一夜间兴起了&#xff0c;这可能引起一些争议&#xff0c;反对者可…

LeetCode 1249. 移除无效的括号(栈+set / deque)

1. 题目 给你一个由 (、) 和小写字母组成的字符串 s。 你需要从字符串中删除最少数目的 ‘(’ 或者 ‘)’ &#xff08;可以删除任意位置的括号)&#xff0c;使得剩下的「括号字符串」有效。 请返回任意一个合法字符串。 有效「括号字符串」应当符合以下 任意一条 要求&…

【DKN】(七)dataset.py【未完】

内容 里面有的函数在这里https://blog.csdn.net/qq_35222729/article/details/119882362 try:config getattr(importlib.import_module(config), f"{model_name}Config") except AttributeError:print(f"{model_name} not included!")exit()class BaseDa…

php raabitmq中间件_rabbitMQ消息中间件环境配置及原理了解

视频教程一、Docker 入门Docker是什么&#xff1f;Docker 是一个开源的应用容器引擎&#xff0c;你可以将其理解为一个轻量级的虚拟机&#xff0c;开发者可以打包他们的应用以及依赖包到一个可移植的容器中&#xff0c;然后发布到任 何流行的 Linux 机器上。为什么要使用 Docke…

CSS 中的定位:relative,absolute

今天碰到一个定位问题&#xff0c;问题解决不好&#xff0c;于是花了大量的时间&#xff0c;调试了好久&#xff0c;得出了一些结果&#xff1a;1、如果有两个不交叉的盒子位于一个大盒子里面&#xff0c;位于上边的盒子的定位为relative&#xff0c;而下边的那个盒子的定位则是…

【DKN】(六)KCNN.py

内容 import torch import torch.nn as nn import torch.nn.functional as F from src.model.general.attention.additive import AdditiveAttentiondevice torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class KCNN(torch.nn.Module):…

北京精雕现状_6秒精密加工,日本走下神坛,北京精雕也做了一个!

各位社友还记得吗&#xff0c;机械社区之前分享过——日本6秒的精密加工火遍制造业圈子~▲点击上图 查看日本怎么用6s让世界惊奇在一阵惊呼赞叹中&#xff0c;一部分人也表示不服&#xff01;比如&#xff0c;国内一位牛人也展示了他的产品。一起看看视频介绍吧——而近日&…

LeetCode 859. 亲密字符串

1. 题目 给定两个由小写字母构成的字符串 A 和 B &#xff0c;只要我们可以通过交换 A 中的两个字母得到与 B 相等的结果&#xff0c;就返回 true &#xff1b;否则返回 false 。 示例 1&#xff1a; 输入&#xff1a; A "ab", B "ba" 输出&#xff1a…

ASP.Net快速开发新闻系统 在线播放

http://www.so138.com/sov/d19a5913-88cf-4abf-a487-69293bb0c403.html转载于:https://www.cnblogs.com/freedom831215/archive/2009/10/03/1577631.html

【DKN】(五)attention.py

感觉还是挺简单&#xff0c;这里只是方便之后回来瞅瞅 import torch import torch.nn as nn import torch.nn.functional as Fclass Attention(torch.nn.Module):"""Attention Net.Input embedding vectors (produced by KCNN) of a candidate news and all of…

小米扫地机器人充电座指示灯不亮_小米扫地机器人常见问题处理 充电后无法取电怎么办?...

与其他科技领域一样&#xff0c;人工智能领域也得到蓬勃发展。如今人工智能已经无处不在。专家把人工智能比作电力&#xff0c;因为它是一种可能改变各行各业的资源。诚然&#xff0c;每个领域都有一些特别重要的技术&#xff0c;例如随着生活的水平的提高&#xff0c;扫地机器…

Enterprise Library 4.1 快速上手(图)

简介&#xff1a; 关于Enterprise Library 的概念&#xff0c;网上可以很容易的找到&#xff0c;在这里要做的是如何快速的打通Enterprise Library 4.1的使用&#xff0c; 让咱们可以用最短的时间使用起来&#xff0c;并且在需要的时候在此基础上再花时间延伸&#xff0c;这是学…

知识图谱源码详解【八】__init__.py

import torch from src.model.DKN.KCNN import KCNN from src.model.DKN.attention import Attention from src.model.general.click_predictor.DNN import DNNClickPredictor# 就是把整个模型框架梳理到一块了&#xff01; class DKN(torch.nn.Module):"""Deep…

python complex函数def_【Python3】Python函数

1. 函数对象函数是第一类对象&#xff0c;即函数可以当做数据传递可以被引用可以当做参数传递返回值可以是函数可以当做容器类型的元素def foo():print(from foo)def index():print(from index)dic {foo:foo,index:index,}while True:choice input(">>>>>…

追MM与设计模式的有趣见解

Posted on 2007-01-18 12:53 东人EP 阅读(383) 评论(0) 编辑 收藏 引用 所属分类: Design Pattern 追MM与设计模式的有趣见解 创建型模式 1、FACTORY —追MM少不了请吃饭了&#xff0c;麦当劳的鸡翅和肯德基的鸡翅都是MM爱吃的东西&#xff0c;虽然口味有所不同&#xff0c;…

LeetCode 872. 叶子相似的树

1. 题目 请考虑一颗二叉树上所有的叶子&#xff0c;这些叶子的值按从左到右的顺序排列形成一个 叶值序列 。 举个例子&#xff0c;如上图所示&#xff0c;给定一颗叶值序列为 (6, 7, 4, 9, 8) 的树。 如果有两颗二叉树的叶值序列是相同&#xff0c;那么我们就认为它们是 叶…

【十】推荐系统遇到知识图谱RippleNet

RippleNet: Propagating User Preferences on the Knowledge Graph for Recommender Systems 代码&#xff1a; https://github.com/hwwang55/RippleNet 心得 &#xff08;1&#xff09;你需要知道Kg是如何起到作用的&#xff01; KG的形式是什么&#xff01; &#xff08;2&…

桩筏有限元中的弹性板计算_采用PKPM系列JCCAD软件桩筏筏板有限元方法计算的模型参数 -...

*****采用PKPM系列JCCAD软件桩筏筏板有限元方法计算的模型参数******计算模型:弹性地基梁板模型 (桩和土按WINKLER模型)地基基础形式及参照规范:天然地基(地基规范)、常规桩基(桩基规范)上部结构影响(共同作用计算): 网格划分依据:所有底层网格线有限元网格控制边长(m): 2.0 采…

[VC]旋转位图图片的算法函数

网上有很多关于位图旋转的资料,但是讲得很清楚的不多(我没有仔细查找).于是我也写了一个,希望能给向我这样的初学者一点帮助. 第一步,你必须知道位图即BMP格式的文件的结构. 位图(bmp)文件由以下几个部分组成: 1.BITMAPFILEHEADER,它的定义如下: typedef struct tagBITMAPFILEH…