【PyTorch与深度学习】5、深入剖析PyTorch DataLoader源码

课程地址
最近做实验发现自己还是基础框架上掌握得不好,于是开始重学一遍PyTorch框架,这个是课程笔记,此节课很详细,笔记记的比较粗

1. DataLoader

1.1 DataLoader类实现

1.1.1 构造函数__init__实现

构造函数有如下参数:

  • dataset:传入自己定义好的数据集类Dataset
  • batch_size:默认值为1,它代表着每批次训练的样本的个数
  • shuffle:布尔类型,True为打乱数据集,False为不打乱数据集
  • sampler:决定以何种方式对数据进行采样,可以不用shuffle随机打乱样本,可以用自己编写的函数去决定如何取样本,比如:你想让你的样本以一种有序的方式来组织成mini-batch,比如把长度比较接近的样本放入到一个mini-batch中,这个时候就不能用shuffle,因为一打乱,这些样本的长度就是乱的。如果传入该参数,则shuffle就没有意义。
  • batch_sampler:可以用自己编写的函数成批次地取样本。如果传入该参数,则shuffle就没有意义。
  • num_workers:默认值为0,它是指数据加载的子进程数量,以加快数据加载的速度,提高训练效率。一般数值设定取决于CPU的核心数,通常数字大到一定程度,其加载速度也不会再提高了。
  • collate_fn:聚集函数,它是对一个批次batch进行后处理,比如:我们通过shuffle打乱后得到一个批次batch,然后对这个batch我们希望对它进行一个pad,但是这个pad的长度只能通过batch去算出来,而不是预先能计算出长度,这个时候我们就要用到collate_fn参数,对之前的shuffle后的mini-batch再处理一下,把这个批次batch给它pad成一样的长度,然后再返回一个新的批次batch。

【注】在深度学习和自然语言处理(NLP)等领域中,pad(填充)是一个常见的预处理步骤,特别是在处理变长序列(如文本、时间序列等)时。当使用DataLoader从数据集中批量提取数据时,如果每个数据项(例如,句子或时间序列)的长度不同,那么为了能够在同一批次中进行高效计算(例如,通过矩阵运算),我们通常需要将这些数据项填充(或截断)到相同的长度。
这就是collate_fn参数发挥作用的地方。默认情况下,DataLoader使用了一个内置的collate_fn来将一批数据项组合成一个张量(tensor),但这个默认函数并不进行填充。为了进行填充,你需要提供一个自定义的collate_fn。

  • pin_memory:布尔类型,默认值为False,用于指定是否将数据加载到固定的内存区域(pinned memory)中。固定内存区域是指一块被操作系统锁定的内存,这样可以防止它被移动,从而提高数据传输的效率。当pin_memory参数设置为True时,PyTorch会尝试将从数据集加载的数据存储在固定的内存中,这对于GPU加速的情况下可以提高数据传输效率,因为GPU可以直接从固定内存中访问数据,而不需要进行额外的内存拷贝操作。需要注意的是,只有当你使用GPU进行训练时,才会考虑使用pin_memory参数。对于CPU训练来说,pin_memory参数的影响通常不太明显。而且这个东西对训练速度的影响还有待考究。
  • drop_last:布尔类型,默认为False,如果你的总样本数目不是每个批次batch的整数倍的话,这时候我们可以将drop_last设置为True,让最后那个小批次(样本数没达到batch-size的批次)丢掉。

构造函数的具体代码和注释如下:

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,pin_memory: bool = False, drop_last: bool = False,timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,multiprocessing_context=None, generator=None,*, prefetch_factor: int = 2,persistent_workers: bool = False):torch._C._log_api_usage_once("python.data_loader")if num_workers < 0:raise ValueError('num_workers option should be non-negative; ''use num_workers=0 to disable multiprocessing.')if timeout < 0:raise ValueError('timeout option should be non-negative')if num_workers == 0 and prefetch_factor != 2:raise ValueError('prefetch_factor option could only be specified in multiprocessing.''let num_workers > 0 to enable multiprocessing.')assert prefetch_factor > 0if persistent_workers and num_workers == 0:raise ValueError('persistent_workers option needs num_workers > 0')# 设置成员函数self.dataset = datasetself.num_workers = num_workersself.prefetch_factor = prefetch_factorself.pin_memory = pin_memoryself.timeout = timeoutself.worker_init_fn = worker_init_fnself.multiprocessing_context = multiprocessing_context# 这里不用看,一般我们都是用Dataset类,而不是IterableDataset,所以直接看这个if条件后面对应的else条件if isinstance(dataset, IterableDataset):self._dataset_kind = _DatasetKind.Iterableif isinstance(dataset, IterDataPipe):torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)elif shuffle is not False:raise ValueError("DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))if sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))elif batch_sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""batch_sampler option, but got batch_sampler={}".format(batch_sampler))# 直接跳到else条件else:# 设置数据集的种类是DatasetKind.Map类型self._dataset_kind = _DatasetKind.Map# 如果你设置了sampler(默认为None),如果你传入了自定义的sampler且shuffle设置为True的话,这种情况是没有意义的,shuffle是官方提供的一种随机采用党的sampler,你都自定义sampler了,就不需要shuffle来随机打乱。所以shuffle和sampler是互斥的,不能同时去设置if sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')# batch_sampler是批次级别的采样,sampler是样本级的采样,if batch_sampler is not None:# 如果你设置了batch_size不是1,或者你设置了shuffle或者你设置了sampler,或者你设置了drop_last,这些都与batch_sampler是互斥的,总结一句话就是:你只要设置了batch_sampler就不需要设置batch_size了,因为你设置了batch_sampler就已经告诉PyTorch框架你的batch_size和以什么样的方式去构成mini-batchif batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')batch_size = Nonedrop_last = False# 如果batch_size是None,同时如果有drop_last,这时候会报错elif batch_size is None:# no auto_collationif drop_last:raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with drop_last')# 如果你没有设置sampler的话if sampler is None:  # give default samplersif self._dataset_kind == _DatasetKind.Iterable:# See NOTE [ Custom Samplers and IterableDataset ]sampler = _InfiniteConstantSampler()else:  # map-style(常用的),如果你设置了shuffle的话,它就会用内置的一个叫random sample的类来去对我们这个Dataset进行一个随机的打乱。具体实现在下面的章节if shuffle:sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]# 如果没有设置shuffle为True的话,它就用SequentialSampler即按原本的顺序来采样else:sampler = SequentialSampler(dataset)  # type: ignore[arg-type]# 如果你的batch_size不是None并且batch_sampler也不是None# 它就默认给你构造一个batch_sampler# BatchSampler源码实现见下面的章节if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.batch_size = batch_sizeself.drop_last = drop_lastself.sampler = samplerself.batch_sampler = batch_samplerself.generator = generator# 如果collate_fn参数为None,则如果设置了auto_collatoion,就调用默认的default_collateif collate_fn is None:# _auto_collation是根据batch_sampler是否为None来去设置的,如果batch_sampler不是None,_auto_collation设置为True,如果batch_sampler是None的话,它就会调用_utils.collate.default_convert这个函数,否则调用_utils.collate.default_collate函数。# _utils.collate.default_collate函数是以batch作为输入,它相当于什么都没做,最后返回了个batch,如果自己要实现这个collate_fn,要以batch做输入,然后再做处理。if self._auto_collation:collate_fn = _utils.collate.default_collateelse:collate_fn = _utils.collate.default_convertself.collate_fn = collate_fnself.persistent_workers = persistent_workersself.__initialized = Trueself._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]self._iterator = Noneself.check_worker_number_rationality()torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

1.1.2 _get_iterator函数

    def _get_iterator(self) -> '_BaseDataLoaderIter':# 如果设置num_workers为0的话,它就走单个样本处理过程if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:# 如果num_workers不为0,说明是多进程读取样本self.check_worker_number_rationality()return _MultiProcessingDataLoaderIter(self)

一般迭代用,是在__iter__方法中实现的,使得DataLoader能变成一个可迭代的对象。

1.2 RandomSampler 类的实现

重点看中文注释

class RandomSampler(Sampler[int]):r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.If with replacement, then user can specify :attr:`num_samples` to draw.Args:data_source (Dataset): dataset to sample fromreplacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``num_samples (int): number of samples to draw, default=`len(dataset)`.generator (Generator): Generator used in sampling."""data_source: Sizedreplacement: booldef __init__(self, data_source: Sized, replacement: bool = False,num_samples: Optional[int] = None, generator=None) -> None:self.data_source = data_sourceself.replacement = replacementself._num_samples = num_samplesself.generator = generatorif not isinstance(self.replacement, bool):raise TypeError("replacement should be a boolean value, but got ""replacement={}".format(self.replacement))if not isinstance(self.num_samples, int) or self.num_samples <= 0:raise ValueError("num_samples should be a positive integer ""value, but got num_samples={}".format(self.num_samples))@propertydef num_samples(self) -> int:# dataset size might change at runtimeif self._num_samples is None:return len(self.data_source)return self._num_samples# 首先看__iter__方法def __iter__(self) -> Iterator[int]:# 获取数据集的大小n = len(self.data_source)# 如果没有传入generator的话,他就会随机生成一个种子,去构建一个生成器generatorif self.generator is None:# 设置随机数的种子seed = int(torch.empty((), dtype=torch.int64).random_().item())generator = torch.Generator()generator.manual_seed(seed)else:generator = self.generatorif self.replacement:for _ in range(self.num_samples // 32):yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()# 返回0到n-1的列表的随机组合,n是数据集长度yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()else:for _ in range(self.num_samples // n):yield from torch.randperm(n, generator=generator).tolist()yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]def __len__(self) -> int:return self.num_samples

1.3 SequentialSampler类的实现

class SequentialSampler(Sampler[int]):r"""Samples elements sequentially, always in the same order.Args:data_source (Dataset): dataset to sample from"""data_source: Sizeddef __init__(self, data_source: Sized) -> None:self.data_source = data_source# 如果迭代它,返回的是有序的索引def __iter__(self) -> Iterator[int]:return iter(range(len(self.data_source)))def __len__(self) -> int:return len(self.data_source)

1.4 BatchSampler类的实现

也是直接看__iter__函数

class BatchSampler(Sampler[List[int]]):def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:# Since collections.abc.Iterable does not check for `__getitem__`, which# is one way for an object to be an iterable, we don't do an `isinstance`# check here.if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \batch_size <= 0:raise ValueError("batch_size should be a positive integer value, ""but got batch_size={}".format(batch_size))if not isinstance(drop_last, bool):raise ValueError("drop_last should be a boolean value, but got ""drop_last={}".format(drop_last))self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_last# 先看iter函数def __iter__(self) -> Iterator[List[int]]:# 先创建一个空列表batchbatch = []# 对sampler进行一个迭代,去元素的索引for idx in self.sampler:# 将其索引添加到列表中batch.append(idx)# 如果列表长度等于batch_size,这时候就返回列表,相当于返回一个批次batch,然后把batch置为空if len(batch) == self.batch_size:yield batchbatch = []# 如果drop_last(是否丢弃最后的不够一个批次数量的元素)设置为False,那我们就把最后这个不够数量的批次也返回if len(batch) > 0 and not self.drop_last:yield batchdef __len__(self) -> int:# Can only be called if self.sampler has __len__ implemented# We cannot enforce this condition, so we turn off typechecking for the# implementation below.# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]if self.drop_last:return len(self.sampler) // self.batch_size  # type: ignore[arg-type]else:return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]

1.5 其他

这个UP讲的太详细了,没全记录,部分细节可以看看视频

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/7944.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3第二十六节(Hooks 封装注意事项)

1、什么是 Hooks Hooks 最先提出的是React&#xff0c;在 React 16 之后提出了所有以use 开头定义的函数&#xff0c;用于复杂功能编写、函数组件中状态管理共用、副作用处理而抽离的共用的单一功能可复用的函数&#xff1b; 2、Hooks 与 mixins Class 在应用中的差异 在vue…

网络安全的未来:挑战、策略与创新

引言&#xff1a; 在数字化时代&#xff0c;网络安全已成为个人和企业不可忽视的议题。随着网络攻击的日益频繁和复杂化&#xff0c;如何有效保护数据和隐私成为了一个全球性的挑战。 一、网络安全的现状与挑战 网络安全面临的挑战多种多样&#xff0c;包括但不限于恶意软件、…

学习通下载PDF资源

今天突然发现&#xff0c;学习通的pdf资源居然是没有下载入口的&#xff0c;这整的我想cv一下我的作业都搞不了&#xff0c;于是我一怒之下&#xff0c;怒了一下。 可以看到学习通的pdf资源是内嵌在网页的&#xff0c;阅读起来很不方便&#xff0c;虽然他内置了阅读器&#xf…

Liunx查找过滤

目录 一.查找 基本用法 常见条件 搭配逻辑逻辑运算符&#xff08;综合&#xff09; 1.根据文件名 2.根据文件类型 3.根据文件大小 4.根据时间 5.根据权限 基本权限格式 找到后处理的动作与处理 1. -exec 2.xargs -exec和xargs的区别 扩展查找 1.which 应用…

SpringBoot中这样用ObjectMapper

每次new一个单例化个性化配置小结 你要说他有问题吧&#xff0c;确实能正常执行&#xff1b;可你要说没问题吧&#xff0c;在追求性能的同学眼里&#xff0c;这属实算是十恶不赦的代码了。 首先&#xff0c;让我们用JMH对这段代码做一个基准测试&#xff0c;让大家对其性能有个…

蓝桥杯 试题 算法训练 找数2

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 找数2 【问题描述】在一个小到大的有序序列中&#xff08;不存在重复数字&#xff09;&#xff0c;查找某个数所在的位置。如果该数不…

使用 Kubeadm 搭建个公网 k8s 集群(单控制平面集群)

前言 YY&#xff1a;国庆的时候趁着阿里云和腾讯云的轻量级服务器做促销一不小心剁了个手&#x1f60e;&#x1f622;&#xff0c;2 Cores&#xff0c;4G RAM 还是阔以的&#xff0c;既然买了&#xff0c;那不能不用呀&#x1f6a9;&#xff0c;之前一直想着搭建个 k8s 集群玩…

【Git】Git学习-14:VSCode中使用git

学习视频链接&#xff1a;【GeekHour】一小时Git教程_哔哩哔哩_bilibili​编辑https://www.bilibili.com/video/BV1HM411377j/?vd_source95dda35ac10d1ae6785cc7006f365780 在vscode中打开文件 code . 自行修改内容&#xff0c;在源代码管理器中测试下

百度大模型文心一言api 请求错误码 一览表

错误码说明 千帆大模型平台API包含两类&#xff0c;分别为大模型能力API和大模型平台管控API&#xff0c;具体细分如下&#xff1a; 大模型能力API 对话Chat续写Completions向量Embeddings图像Images 大模型平台管控API 模型管理Prompt工程服务管理模型精调数据管理TPM&RP…

卷价格不如卷工艺降本增效狠抓模块规范化设计

俗话说&#xff0c;“卷价格不如卷工艺”&#xff0c;这意味着在追求成本控制和效率提升的过程中&#xff0c;蓝鹏的领导认为蓝鹏应该更注重工艺的优化和创新&#xff0c;而不仅仅是价格的竞争。而模块规范化设计正是实现这一目标的有效途径。 模块规范化设计可以提高生产效率…

小红书高级电商运营课,从0开始做小红书电商(18节课)

详情介绍 课程内容&#xff1a; 第1节课:学习流程以及后续实操流程注意事项,mp4 第2节课:小红书店铺类型解析以及开店细节.mp4 第3节课:小红书电商运营两种玩法之多品店铺解析,mp4 第4节课:小红书电商运营两种玩法之单品店铺解析,mp4 第5节课:选品课(多品类类目推荐).mp4 …

如何获得一个Oracle 23ai数据库(RPM安装)

准确的说&#xff0c;是Oracle 23ai Free Developer版&#xff0c;因为企业版目前只在云上&#xff08;OCI和Azure&#xff09;和ECC上提供。 方法包括3种&#xff0c;本文介绍第2种&#xff1a; Virtual ApplianceRPM安装Docker RPM安装支持Linux 8和Linux 9。由于官方的Vi…

Study--Oracle-02-单实例部署Oracle19C

一、CentOS 7 环境准备 1、软件准备 操作系统&#xff1a;CentOS 7 数据库版本: Oracle19C 2、操作系统环境配置 关闭selinux &#xff0c;编辑 /etc/selinux/config文件&#xff0c;设置SELINUX enforcing 为SELINUXdisabled [rootoracle ~]# grep SELINUX /etc/seli…

基于AbstractRoutingDataSource的mybatis动态多数据源切换

1.pom mybatis-starter版本只能选2开头的版本&#xff0c;选3开头的就报错 <!--druid连接池--> <dependency><groupId>com.alibaba</groupId><artifactId>druid-spring-boot-starter</artifactId><version>1.2.3</version> …

0基础学汽车 丝滑炫酷摄影教学,手机剪映特效剪辑(66节也就是)

详情介绍 0基础学汽车丝滑炫酷摄影教学,手机剪映特效剪辑 课程内容:0基础学汽车 丝滑炫酷摄影教学,手机剪映特效剪辑(66节也就是) - 百创网-源码交易平台_网站源码_商城源码_小程序源码 01 AE课前基础知识(必看).mp4 02 怎样制作德关一样丝滑的作品第二种方法(苹果mac专…

【busybox记录】【shell指令】paste

目录 内容来源&#xff1a; 【GUN】【paste】指令介绍 【busybox】【paste】指令介绍 【linux】【paste】指令介绍 使用示例&#xff1a; 合并文件的行 - 默认输出&#xff08;默认是行合并&#xff09; 合并文件的行 - 一个文件占一行 合并文件的行 - 使用指定的间隔符…

Python Web框架Django项目开发实战:多用户内容发布系统

注意:本文的下载教程,与以下文章的思路有相同点,也有不同点,最终目标只是让读者从多维度去熟练掌握本知识点。 下载教程:Python项目开发Django实战-多用户内容发布系统-编程案例解析实例详解课程教程.pdf 一、引言 在Web应用开发中,内容发布系统是一个常见的需求。这类系…

5000A信号发生器使用方法

背景 gnss工作需要使用的5000A&#xff0c;所以做成文档&#xff0c;用于其他员工学习。 下载星历数据 https://cddis.nasa.gov/archive/gnss/data/daily/2024/brdc/ 修改daily中的年份&#xff0c;就可以获取相关截至时间的星历数据 brcd数据格式 第一行记录了卫星的PRN号&a…

yolo-world:”目标检测届大模型“

AI应用开发相关目录 本专栏包括AI应用开发相关内容分享&#xff0c;包括不限于AI算法部署实施细节、AI应用后端分析服务相关概念及开发技巧、AI应用后端应用服务相关概念及开发技巧、AI应用前端实现路径及开发技巧 适用于具备一定算法及Python使用基础的人群 AI应用开发流程概…

科技早报 | 微软将推出自研AI大模型;苹果折叠屏iPhone新专利获批 | 最新快讯

微软将推出自研AI大模型 5月6日消息&#xff0c;据The Information报道&#xff0c;微软正在公司内部训练一个新的人工智能模型&#xff0c;规模足以与谷歌、Anthropic&#xff0c;乃至OpenAI 自身的先进大模型相抗衡。 报道称&#xff0c;这个新模型内部代号为“MAI-1”&…