Content
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.general.attention.additive import AdditiveAttention

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class KCNN(torch.nn.Module):
    """
    Knowledge-aware CNN (KCNN) based on Kim CNN.
    Input a news sentence (e.g. its title), produce its embedding vector.
    """
    def __init__(self, config, pretrained_word_embedding,
                 pretrained_entity_embedding, pretrained_context_embedding):
        # Plain setup: load the pretrained word, entity and context embeddings
        # (shared by clicked news and candidate news alike).
        super(KCNN, self).__init__()
        self.config = config
        if pretrained_word_embedding is None:
            # No pretrained word embedding available, so fall back to a
            # randomly initialized nn.Embedding.
            self.word_embedding = nn.Embedding(config.num_words,
                                               config.word_embedding_dim,
                                               padding_idx=0)
        else:
            self.word_embedding = nn.Embedding.from_pretrained(
                pretrained_word_embedding, freeze=False, padding_idx=0)
        if pretrained_entity_embedding is None:
            self.entity_embedding = nn.Embedding(config.num_entities,
                                                 config.entity_embedding_dim,
                                                 padding_idx=0)
        else:
            self.entity_embedding = nn.Embedding.from_pretrained(
                pretrained_entity_embedding, freeze=False, padding_idx=0)
        if config.use_context:
            if pretrained_context_embedding is None:
                self.context_embedding = nn.Embedding(
                    config.num_entities,
                    config.entity_embedding_dim,
                    padding_idx=0)
            else:
                self.context_embedding = nn.Embedding.from_pretrained(
                    pretrained_context_embedding,
                    freeze=False,
                    padding_idx=0)
        self.transform_matrix = nn.Parameter(
            torch.empty(self.config.entity_embedding_dim,
                        self.config.word_embedding_dim).uniform_(-0.1, 0.1))
        self.transform_bias = nn.Parameter(
            torch.empty(self.config.word_embedding_dim).uniform_(-0.1, 0.1))
        self.conv_filters = nn.ModuleDict({
            str(x): nn.Conv2d(3 if self.config.use_context else 2,
                              self.config.num_filters,
                              (x, self.config.word_embedding_dim))
            for x in self.config.window_sizes
        })
        self.additive_attention = AdditiveAttention(
            self.config.query_vector_dim, self.config.num_filters)

    def forward(self, news):
        """
        Args:
            news:
                {
                    "title": batch_size * num_words_title,
                    "title_entities": batch_size * num_words_title
                }
        Returns:
            final_vector: batch_size, len(window_sizes) * num_filters
        """
        # batch_size, num_words_title, word_embedding_dim
        # Look up word vectors; the index tensor must first be moved to the device.
        word_vector = self.word_embedding(news["title"].to(device))
        # batch_size, num_words_title, entity_embedding_dim
        # Look up entity vectors.
        entity_vector = self.entity_embedding(
            news["title_entities"].to(device))
        if self.config.use_context:
            # When context is enabled, also look up context vectors.
            # batch_size, num_words_title, entity_embedding_dim
            context_vector = self.context_embedding(
                news["title_entities"].to(device))

        # batch_size, num_words_title, word_embedding_dim
        # The transform matrix maps entity embeddings into the word embedding space.
        transformed_entity_vector = torch.tanh(
            torch.add(torch.matmul(entity_vector, self.transform_matrix),
                      self.transform_bias))
        if self.config.use_context:
            # batch_size, num_words_title, word_embedding_dim
            transformed_context_vector = torch.tanh(
                torch.add(torch.matmul(context_vector, self.transform_matrix),
                          self.transform_bias))
            # batch_size, 3, num_words_title, word_embedding_dim
            # Stack the three views into the final multi-channel tensor.
            multi_channel_vector = torch.stack([
                word_vector, transformed_entity_vector,
                transformed_context_vector
            ], dim=1)
        else:
            # batch_size, 2, num_words_title, word_embedding_dim
            multi_channel_vector = torch.stack(
                [word_vector, transformed_entity_vector], dim=1)

        pooled_vectors = []
        for x in self.config.window_sizes:
            # Standard Kim-CNN convolution, one branch per window size x.
            # batch_size, num_filters, num_words_title + 1 - x
            convoluted = self.conv_filters[str(x)](
                multi_channel_vector).squeeze(dim=3)
            # batch_size, num_filters, num_words_title + 1 - x
            activated = F.relu(convoluted)
            # batch_size, num_filters
            # Here we use an additive attention module
            # instead of the max pooling used in the paper.
            pooled = self.additive_attention(activated.transpose(1, 2))
            # pooled = activated.max(dim=-1)[0]
            # # or
            # # pooled = F.max_pool1d(activated, activated.size(2)).squeeze(dim=2)
            pooled_vectors.append(pooled)
        # batch_size, len(window_sizes) * num_filters
        final_vector = torch.cat(pooled_vectors, dim=1)
        return final_vector
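For reference, a minimal usage sketch (not part of the repo): it assumes the AdditiveAttention module above is importable and behaves as the forward comments describe (pooling each sequence of window positions down to a num_filters vector), and it fakes a config with types.SimpleNamespace using made-up dimensions, just to show the expected input dictionary and output shape.

from types import SimpleNamespace

import torch

# Hypothetical config for illustration only; the project supplies its own
# config object carrying these same fields.
config = SimpleNamespace(num_words=1000, word_embedding_dim=100,
                         num_entities=200, entity_embedding_dim=100,
                         use_context=True, num_filters=50,
                         window_sizes=[2, 3, 4], query_vector_dim=200)

# Passing None for all pretrained embeddings makes KCNN fall back to
# randomly initialized nn.Embedding layers.
kcnn = KCNN(config, None, None, None).to(device)

batch_size, num_words_title = 16, 20
news = {
    "title": torch.randint(0, config.num_words, (batch_size, num_words_title)),
    "title_entities": torch.randint(0, config.num_entities,
                                    (batch_size, num_words_title)),
}
final_vector = kcnn(news)
# batch_size, len(window_sizes) * num_filters -> torch.Size([16, 150])
print(final_vector.shape)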
Notes
The final convolution step deserves a closer look; a shape walkthrough follows below.
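Each Conv2d treats the stacked word, entity and (optionally) context vectors as a multi-channel "image" of size num_words_title x word_embedding_dim. Because the filter width equals word_embedding_dim, the filter slides only along the word axis, producing num_words_title + 1 - x positions per filter for window size x; KCNN then pools those positions with additive attention instead of Kim CNN's max-over-time pooling. Below is a minimal shape walkthrough with toy dimensions; a max pool stands in for the repo's AdditiveAttention so the snippet is self-contained.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy dimensions for illustration only.
batch_size, num_words_title, word_embedding_dim = 16, 20, 100
num_channels, num_filters, window_size = 3, 50, 3  # channels: word + entity + context

# The stacked input seen by each convolution branch:
# batch_size, num_channels, num_words_title, word_embedding_dim
multi_channel_vector = torch.rand(batch_size, num_channels,
                                  num_words_title, word_embedding_dim)

# The filter spans the full embedding dimension, so it slides only along
# the word axis and yields one value per window position per filter.
conv = nn.Conv2d(num_channels, num_filters, (window_size, word_embedding_dim))
convoluted = conv(multi_channel_vector).squeeze(dim=3)
print(convoluted.shape)  # torch.Size([16, 50, 18]) = batch, num_filters, num_words_title + 1 - window_size

# KCNN pools the window positions with additive attention; a max pool
# stands in here just to show the resulting shape.
activated = F.relu(convoluted)
pooled = activated.max(dim=-1)[0]
print(pooled.shape)  # torch.Size([16, 50]) = batch, num_filters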