对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型QWen-VL的一条示例数据可能如下所示:
{"input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?","output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}
由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用IterableDataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
基于IterableDataset的数据加载,代码实现如下:
import torch
from torch.utils.data import IterableDatasetclass MyIterableDataset(IterableDataset):def __init__(self, data_file):self.data_file = data_filedef __iter__(self):return iter(self._load_data())def _load_data(self):with open(self.data_file, 'r') as file:for line in file:sample = process_line(line)yield sampledef process_line(self, line):# Process the line to convert it to a sample...return sample# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)for batch in dataloader:# Train your model using the batch of datapass
在实际训练中还会遇到两个问题:
- 大模型一般需要使用多机多卡训练,需要避免多个进程中dataloader读取数据的竞争,并保证不同进程之间不会重复读取数据;
- 数据文件中某些行无法正确被解析,或者引用的外部资源找不到,导致process_line成员函数报错。数据集需要handle这类错误,防止因为报错中断训练。
以上问题对策如下:
- 在多机多卡的DDP训练中,可以使用DistributedSampler来处理多进程读数据的情形。DistributedSampler可以确保不同进程之间不会重复读取数据。具体的代码实现如下:
# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)# Create a DistributedSampler
sampler = DistributedSampler(dataset)# Create a DataLoader using the DistributedSampler
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)for batch in dataloader:# Train your model using the batch of datapass
- 可以在调用process_line的时候试图handle一个错误,如果出错就跳过这条数据,改为(试图)获取下一条数据。具体的代码实现如下:
import torch
import logger
from torch.utils.data import IterableDatasetclass MyIterableDataset(IterableDataset):def __init__(self, data_file):self.data_file = data_filedef __iter__(self):return iter(self._load_data())def _load_data(self):with open(self.data_file, 'r') as file:for line in file:try:sample = process_line(line)yield sampleexcept Exception as e:# Print the detailed error informationlogger.error(line)logger.error(e)passdef process_line(self, line):# Process the line to convert it to a sample...return sample
如果使用的是普通的Dataset,则参考以下代码,在__getitem__里面加入报错逻辑:
class MyDataset(Dataset):def __init__(self, file_path):self.data = []with open(file_path, 'r') as file:for line in file:self.data.append(line)def __len__(self):return len(self.data)def __getitem__(self, index):line = self.data[index]try:sample = self.process_line(line)return sampleexcept Exception as e:# Print the detailed error informationlogger.error(line)logger.error(e)return self.__getitem__((index+1) % self.__len__())def process_line(self, line):# Process the line to convert it to a sample...return sample