内容
代码中用到的部分函数的说明见：https://blog.csdn.net/qq_35222729/article/details/119882362
try:
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    # SystemExit with a message (what sys.exit(msg) raises) prints the message
    # to stderr and exits with status 1. The site-module `exit()` helper is
    # meant for the interactive shell and exits with status 0, which would
    # report a missing config class as success.
    raise SystemExit(f"{model_name} not included!") from None


class BaseDataset(Dataset):
    """Map-style dataset that yields one behaviors-table row per item.

    Both tables are read with ``pd.read_table`` (tab-separated by default).
    Every news row is converted into a dict of tensors, and each behaviors row
    is turned into candidate / clicked news tensors in ``__getitem__``. Which
    attributes are kept is controlled by ``config.dataset_attributes['news']``
    and ``config.dataset_attributes['record']``.

    Args:
        behaviors_path: path to the behaviors table. ``__getitem__`` reads its
            ``clicked``, ``candidate_news`` and ``clicked_news`` columns, plus
            ``user`` when 'user' is a configured record attribute.
        news_path: path to the news table. It must have an ``id`` column plus
            one column per configured news attribute.
        roberta_embedding_dir: not used by this class (see the NOTE in
            ``__init__``).
    """

    # Every attribute that config.dataset_attributes['news'] may request.
    _SUPPORTED_NEWS_ATTRIBUTES = frozenset({
        'category', 'subcategory', 'title', 'abstract', 'title_entities',
        'abstract_entities', 'title_roberta', 'title_mask_roberta',
        'abstract_roberta', 'abstract_mask_roberta',
    })
    # News columns stored as the text form of a Python literal (e.g. "[1, 2]").
    # They are decoded with ast.literal_eval, which only accepts literals and,
    # unlike eval, cannot execute arbitrary code from the data file.
    _LITERAL_NEWS_ATTRIBUTES = frozenset({
        'title', 'abstract', 'title_entities', 'abstract_entities',
        'title_roberta', 'title_mask_roberta', 'abstract_roberta',
        'abstract_mask_roberta',
    })
    # Every attribute that config.dataset_attributes['record'] may request.
    _SUPPORTED_RECORD_ATTRIBUTES = frozenset({'user', 'clicked_news_length'})

    def __init__(self, behaviors_path, news_path, roberta_embedding_dir):
        super().__init__()
        news_attributes = config.dataset_attributes['news']
        record_attributes = config.dataset_attributes['record']
        # The configured attributes must be a subset of the supported ones.
        assert all(attribute in self._SUPPORTED_NEWS_ATTRIBUTES
                   for attribute in news_attributes), \
            f"unsupported news attribute in {news_attributes}"
        assert all(attribute in self._SUPPORTED_RECORD_ATTRIBUTES
                   for attribute in record_attributes), \
            f"unsupported record attribute in {record_attributes}"

        self.behaviors_parsed = pd.read_table(behaviors_path)
        self.news_parsed = pd.read_table(
            news_path,
            index_col='id',
            usecols=['id'] + news_attributes,
            # Turn list-valued columns back from their string form into
            # Python objects.
            converters={
                attribute: literal_eval
                for attribute in set(news_attributes)
                & self._LITERAL_NEWS_ATTRIBUTES
            })

        # news id -> row position in the news table.
        self.news_id2int = {
            news_id: position
            for position, news_id in enumerate(self.news_parsed.index)
        }
        # news id -> {attribute: tensor}
        self.news2dict = {
            news_id: {key: torch.tensor(value)
                      for key, value in attributes.items()}
            for news_id, attributes in self.news_parsed.to_dict('index').items()
        }

        # Placeholder news item used to left-pad short click histories; it
        # holds only the configured attributes, as tensors.
        padding_all = {
            'category': 0,
            'subcategory': 0,
            'title': [0] * config.num_words_title,
            'abstract': [0] * config.num_words_abstract,
            'title_entities': [0] * config.num_words_title,
            'abstract_entities': [0] * config.num_words_abstract,
            'title_roberta': [0] * config.num_words_title,
            'title_mask_roberta': [0] * config.num_words_title,
            'abstract_roberta': [0] * config.num_words_abstract,
            'abstract_mask_roberta': [0] * config.num_words_abstract,
        }
        self.padding = {
            key: torch.tensor(value)
            for key, value in padding_all.items()
            if key in news_attributes
        }

        # NOTE(review): `roberta_embedding_dir` is unused. `self.roberta_embedding`,
        # which _news2dict reads when model_name == 'Exp2' and not
        # config.fine_tune, is never assigned in this class, so that
        # configuration raises AttributeError. The loading code appears to
        # have been omitted -- TODO restore it.

    def _news2dict(self, id):
        """Return a fresh ``{attribute: tensor}`` dict for news item ``id``.

        When model_name is 'Exp2' and config.fine_tune is false, the configured
        'title' / 'abstract' entries are replaced by
        ``self.roberta_embedding[k][row]``, where ``row`` is the item's position
        in the news table.
        """
        # Shallow copy: the substitution below must not overwrite the cached
        # tensors in self.news2dict, and callers must not share the cache dict.
        ret = dict(self.news2dict[id])
        if model_name == 'Exp2' and not config.fine_tune:
            for k in set(config.dataset_attributes['news']) & {'title', 'abstract'}:
                ret[k] = self.roberta_embedding[k][self.news_id2int[id]]
        return ret

    def __len__(self):
        """Return the number of behaviors rows."""
        return len(self.behaviors_parsed)

    def __getitem__(self, idx):
        """Return one behaviors row as a dict.

        Keys:
            'user': the row's ``user`` value (only if configured).
            'clicked': list of ints parsed from the whitespace-separated
                ``clicked`` column.
            'candidate_news': one news dict per id in ``candidate_news``.
            'clicked_news': exactly ``config.num_clicked_news_a_user`` news
                dicts, i.e. the first ids of ``clicked_news``, left-padded with
                ``self.padding``.
            'clicked_news_length': number of real (unpadded) clicked news
                (only if configured).
        """
        item = {}
        row = self.behaviors_parsed.iloc[idx]
        record_attributes = config.dataset_attributes['record']
        if 'user' in record_attributes:
            item['user'] = row.user
        item["clicked"] = list(map(int, row.clicked.split()))
        item["candidate_news"] = [
            self._news2dict(x) for x in row.candidate_news.split()
        ]
        # pandas reads an empty field as NaN (a float), so a user without any
        # history would crash on .split(); treat it as an empty history.
        clicked_ids = (row.clicked_news.split()
                       if isinstance(row.clicked_news, str) else [])
        item["clicked_news"] = [
            self._news2dict(x)
            for x in clicked_ids[:config.num_clicked_news_a_user]
        ]
        if 'clicked_news_length' in record_attributes:
            item['clicked_news_length'] = len(item["clicked_news"])
        repeated_times = (config.num_clicked_news_a_user
                          - len(item["clicked_news"]))
        assert repeated_times >= 0
        # Left-pad so every item has the same fixed-length click history.
        item["clicked_news"] = ([self.padding] * repeated_times
                                + item["clicked_news"])
        return item
补充
1. ast.literal_eval
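在 Python 中，如果要把字符串形式的 list、tuple、dict 转换回原来的类型，该怎么办？这时自然会想到 eval。eval 函数在 Python 中做数据类型转换很有用，它的作用是把字符串还原成它所表示的数据，或者能够转换成的数据类型。但 eval 会执行任意表达式，用在不可信的输入上并不安全；ast.literal_eval 只解析 Python 字面量（字符串、数字、元组、列表、字典、集合、布尔值、None 等），遇到其他表达式会抛出异常，因此上面的代码在读取新闻文件时用的是 literal_eval 而不是 eval。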
string <=> list
In [1]: s = '[1, 2, 3, 4]'
In [2]: l = eval(s)
In [3]: s
Out[3]: '[1, 2, 3, 4]'
In [4]: l
Out[4]: [1, 2, 3, 4]
In [5]: type(s)
Out[5]: str
In [6]: type(l)
Out[6]: list