【深度学习】图像分类数据集

图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()#设置图表大小,具体实现过程及其底层逻辑见微积分一节

读取数据集

我们可以[通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中]。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,# 并除以255使得所有像素的数值均在0~1之间trans = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

这段代码的主要目的是从 torchvision 库中下载并加载 Fashion - MNIST 数据集,同时对数据进行预处理,将图像转换为 PyTorch 张量。
代码主要分为三个部分:定义图像预处理操作、加载训练集数据、加载测试集数据。下面逐行进行详细解释。

1. 定义图像预处理操作

trans = transforms.ToTensor()

  • 功能:创建一个图像预处理的转换对象 transtransforms.ToTensor()torchvision.transforms 模块里的一个类,专门用于将 PIL(Python Imaging Library)图像或者 NumPy 数组(一般是 uint8 类型)转换为 torch.FloatTensor 类型的张量。
  • 转换细节
    - 在转换过程中,会把图像的像素值归一化到 [0.0, 1.0] 范围。例如,原始图像像素值范围是 [0, 255],经过该转换后,像素值会除以 255,变成 [0.0, 1.0] 之间的浮点数。
    - 同时,转换后张量的维度也会发生变化。对于单通道的灰度图像,会从 (H, W)(高度和宽度)变为 (1, H, W);对于三通道的彩色图像,会从 (H, W, C) 变为 (C, H, W),这里 C 代表通道数。

2. 加载训练集数据

mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_train,用于加载 Fashion - MNIST 数据集的训练集部分。
  • 参数解释
    - root="../data":指定数据集的存储路径。若该路径下没有数据集,下载的数据会存于此;若已存在,则直接从该路径加载数据。
    - train=True:表明要加载的是训练集数据。Fashion - MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,通过此参数区分加载的是训练集还是测试集。
    - transform=trans:指定对图像数据进行的预处理操作。这里使用之前创建的 trans 对象,即对每个图像应用 ToTensor() 变换,将其转换为张量
    - download=True:如果指定路径下未找到数据集,会自动从网络下载 Fashion - MNIST 数据集。

3. 加载测试集数据

mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_test,用于加载 Fashion - MNIST 数据集的测试集部分。
  • 参数解释:与加载训练集的代码基本相同,唯一区别在于 train=False,表示加载的是测试集数据。

Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像
测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

len(mnist_train), len(mnist_test)

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。

mnist_train[0][0].shape

在这里插入图片描述
[两个可视化数据集的函数]

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

列表推导式
[expression for item in iterable]

  • expression:对每个 item 进行操作后得到的结果,它将成为新列表中的一个元素。
  • item:从 iterable 中取出的单个元素。
  • iterable:一个可迭代对象,如列表、元组、字符串等。

示例代码

text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
labels = [0, 2, 4]
result = [text_labels[int(i)] for i in labels]
print(result)  # 输出: ['t-shirt', 'pullover', 'coat']

我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

子图坐标轴对象
在 matplotlib 中,一个图形(Figure)可以包含多个子图(Axes),每个子图就是一个独立的绘图区域,子图坐标轴对象(Axes 对象)就代表了这些独立的绘图区域。它可以被看作是一个 “画布”,你可以在这个 “画布” 上进行各种绘图操作,比如绘制线条、散点、柱状图等,还可以设置坐标轴的范围、标签、标题等。

以下是对 show_images 函数的详细解释:

  • def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    • 定义了一个名为 show_images 的函数,用于将一组图像以网格形式展示出来。
    • imgs:是一个包含图像的列表,这些图像可以是 PyTorch 张量,也可以是 PIL(Python Imaging Library)图像对象。
    • num_rows:指定了要展示的图像网格的行数。
    • num_cols:指定了要展示的图像网格的列数。
    • titles:是一个可选参数,类型为列表,用于为每个图像设置对应的标题。如果不提供该参数,则默认不显示标题。
    • scale:同样是可选参数,是一个浮点数,用于调整图像显示的缩放比例,默认值为 1.5。
  • figsize = (num_cols * scale, num_rows * scale):
    • 这行代码根据 num_cols(列数)、num_rows(行数)和 scale(缩放比例)计算出整个图像展示窗口的大小。
    • figsize 是一个元组,第一个元素是窗口的宽度,由列数乘以缩放比例得到;第二个元素是窗口的高度,由行数乘以缩放比例得到。
  • _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    • num_rowsnum_cols 分别指定了子图的行数和列数,也就是图像网格的布局。
    • figsize=figsize 表示使用之前计算好的窗口大小。
    • subplots 函数返回两个值,第一个是 Figure 对象,这里用 _ 占位表示我们不关心这个返回值;第二个是一个包含所有子图坐标轴对象的数组,赋值给 axes
  • axes = axes.flatten()
    • axes 原本是一个二维数组,因为它对应着 num_rows 行和 num_cols 列的子图布局。
    • flatten 方法将这个二维数组转换为一维数组,这样在后续遍历图像和子图时会更加方便。
  • for i, (ax, img) in enumerate(zip(axes, imgs))
    • zip(axes, imgs)axes 数组(包含所有子图坐标轴对象)和 imgs 列表(包含所有要展示的图像)中的元素一一对应地组合起来。
    • enumerate 函数用于为组合后的元素添加索引,i 就是当前元素的索引。
    • 在每次循环中,ax 代表当前子图的坐标轴对象,img 代表当前要展示的图像。
        if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)
  • torch.is_tensor(img) 用于判断当前的 img 是否为 PyTorch 张量。
  • 如果是张量,使用 img.numpy() 将其转换为 NumPy 数组,因为 matplotlibimshow 函数更适合处理 NumPy 数组。然后使用 ax.imshow 函数在当前子图上显示图像。
  • 如果不是张量,说明 img 可能是 PIL 图像对象,直接使用 ax.imshow 函数显示该图像。
        ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)
  • ax.axes.get_xaxis() 获取当前子图的 x 轴对象,set_visible(False) 方法将 x 轴设置为不可见。
  • 同理,ax.axes.get_yaxis() 获取当前子图的 y 轴对象,set_visible(False) 方法将 y 轴设置为不可见。这样可以使图像显示更加简洁,只专注于图像内容。
        if titles:ax.set_title(titles[i])
  • if titles: 检查是否提供了 titles 列表。
  • 如果提供了,使用 ax.set_title 方法为当前子图设置对应的标题,标题从 titles 列表中根据当前索引 i 取出。
    return axes
  • 最后,函数返回 axes 数组,这个数组包含了所有子图的坐标轴对象。返回它的目的是方便在调用该函数后,对图形进行进一步的操作,例如修改坐标轴属性等。

以下是训练数据集中前[几个样本的图像及其相应的标签]。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
回顾一下,在每次迭代中,数据加载器每次都会[读取一小批量数据,大小为batch_size]。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4
#shuffle表示在每个训练周期开始时,对数据集进行随机打乱
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

我们看一下读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:continue
f'{timer.stop():.2f} sec'

整合所有组件

现在我们[定义load_data_fashion_mnist函数],用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]#trans初始化为一个包含transforms.ToTensor()的列表if resize:trans.insert(0, transforms.Resize(resize))#在 trans 列表的开头插入 transforms.Resize(resize) 操作trans = transforms.Compose(trans)#将 trans 列表中的所有变换操作组合成一个完整的变换序列 transmnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)#X.shape表示张量 X 的形状,X.dtype表示张量 X 中元素的数据类型break

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68580.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DeepSeek-R1:多阶段训练提升推理能力

标题:DeepSeek-R1:多阶段训练提升推理能力 文章信息摘要: DeepSeek-R1通过结合监督学习与强化学习的多阶段训练方法,显著提升了大型语言模型的推理能力,尤其在处理复杂数学问题时表现优异。该方法克服了纯强化学习模型…

以创新芯片技术助力科技发展

在当今数字化与智能化浪潮中,芯片作为现代科技的核心,正悄然推动着各个行业的变革。厦门国科安芯科技有限公司专注于高性能芯片的研发与创新,致力于为工业、汽车和商业航天等领域提供高效、可靠的解决方案。以下是国科安芯推出的几款具有代表…

【MySQL — 数据库增删改查操作】深入解析MySQL的 Retrieve 检索操作

Retrieve 检索 示例 1. 构造数据 创建表结构 create table exam1(id bigint, name varchar(20) comment同学姓名, Chinesedecimal(3,1) comment 语文成绩, Math decimal(3,1) comment 数学成绩, English decimal(3,1) comment 英语成绩 ); 插入测试数据 insert into ex…

Ansible自动化运维实战--通过role远程部署nginx并配置(8/8)

文章目录 1、准备工作2、创建角色结构3、编写任务4、准备配置文件(金甲模板)5、编写变量6、编写处理程序7、编写剧本8、执行剧本Playbook9、验证-游览器访问每台主机的nginx页面 在 Ansible 中,使用角色(Role)来远程部…

RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章 原作者:K同学啊 一 导入数据 import torch.nn as nn import torch.nn.functional as F import torchvision,torch from sklearn.preprocessing import StandardScaler from torch.utils.data import TensorDatase…

【新春特辑】2025年春节技术展望:蛇年里的科技创新与趋势预测

🔥【新春特辑】2025年春节技术展望:蛇年里的科技创新与趋势预测 📅 发布日期:2025年01月29日(大年初一) 在这个辞旧迎新的美好时刻,我们迎来了充满希望的2025年,也是十二生肖中的蛇…

使用 Docker + Nginx + Certbot 实现自动化管理 SSL 证书

使用 Docker Nginx Certbot 实现自动化管理 SSL 证书 在互联网安全环境日益重要的今天,为站点或应用部署 HTTPS 已经成为一种常态。然而,手动申请并续期证书既繁琐又容易出错。本文将以 Nginx Certbot 为示例,基于 Docker 容器来搭建一个…

C++11新特性之使用using(代替typedef)定义别名

1.介绍 传统的C使用typedef重定义一个类型存在一些限制&#xff0c;例如无法直接重定义一个模版。如下所示。 template <typename Val> struct str_map {typedef std::map<std::string, Val> type; };str_map<int>::type map1; 需要添加额外的类来实现&…

编程题-最长的回文子串(中等)

题目&#xff1a; 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 示例 1&#xff1a; 输入&#xff1a;s "babad" 输出&#xff1a;"bab" 解释&#xff1a;"aba" 同样是符合题意的答案。示例 2&#xff1a; 输入&#xff1a;s &…

maven、npm、pip、yum官方镜像修改文档

文章目录 Maven阿里云网易华为腾讯云 Npm淘宝腾讯云 pip清华源阿里中科大华科 Yum 由于各博客繁杂&#xff0c;本文旨在记录各常见镜像官网&#xff0c;及其配置文档。常用镜像及配置可评论后加入 Maven 阿里云 官方文档 setting.xml <mirror><id>aliyunmaven&l…

CNN-GRU卷积门控循环单元时间序列预测(Matlab完整源码和数据)

CNN-GRU卷积门控循环单元时间序列预测&#xff08;Matlab完整源码和数据&#xff09; 目录 CNN-GRU卷积门控循环单元时间序列预测&#xff08;Matlab完整源码和数据&#xff09;预测效果基本介绍CNN-GRU卷积门控循环单元时间序列预测一、引言1.1、研究背景与意义1.2、研究现状1…

HTML-新浪新闻-实现标题-样式1

用css进行样式控制 css引入方式&#xff1a; --行内样式&#xff1a;写在标签的style属性中&#xff08;不推荐&#xff09; --内嵌样式&#xff1a;写在style标签中&#xff08;可以写在页面任何位置&#xff0c;但通常约定写在head标签中&#xff09; --外联样式&#xf…

搜索引擎友好:设计快速收录的网站架构

本文来自&#xff1a;百万收录网 原文链接&#xff1a;https://www.baiwanshoulu.com/14.html 为了设计一个搜索引擎友好的网站架构&#xff0c;以实现快速收录&#xff0c;可以从以下几个方面入手&#xff1a; 一、清晰的目录结构与层级 合理划分内容&#xff1a;目录结构应…

CF1098F Ж-function

【题意】 给你一个字符串 s s s&#xff0c;每次询问给你 l , r l, r l,r&#xff0c;让你输出 s s s l , r sss_{l,r} sssl,r​中 ∑ i 1 r − l 1 L C P ( s s i , s s 1 ) \sum_{i1}^{r-l1}LCP(ss_i,ss_1) ∑i1r−l1​LCP(ssi​,ss1​)。 【思路】 和前一道题一样&#…

C++ 拷贝构造

拷贝构造函数会在以下几种场景中被调用: 1. 用一个对象显式初始化另一个对象。 2. 对象按值传递给函数。 3. 函数按值返回对象。 4. 将对象插入到容器中。 5. 明确调用拷贝构造函数。 1. 当用一个对象显式初始化另一个对象时 MyClass obj1("Hello"); MyClass obj2…

2024年终总结

回顾 今年过年没回老家&#xff0c;趁着有时间&#xff0c;总结一下24年吧。 我把23年看做是打基础的一年&#xff0c;而24年主要是忙于项目的一年&#xff0c;基本上大部分时间都是忙着交付软件&#xff0c;写的一些文章也大部分都是项目中遇到的问题和解决方案&#xff0c;虽…

《哈佛家训》

《哈佛家训》是一本以教育为主题的书籍&#xff0c;旨在通过一系列富有哲理的故事和案例&#xff0c;传递积极的人生观、价值观和教育理念。虽然它并非直接由哈佛大学官方出版&#xff0c;但其内容深受读者喜爱&#xff0c;尤其是在家庭教育和个人成长领域。 以下是《哈佛家训…

[c语言日寄]越界访问:意外的死循环

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋&#xff1a;这是一个专注于C语言刷题的专栏&#xff0c;精选题目&#xff0c;搭配详细题解、拓展算法。从基础语法到复杂算法&#xff0c;题目涉及的知识点全面覆盖&#xff0c;助力你系统提升。无论你是初学者&#xff0c;还是…

使用 KNN 搜索和 CLIP 嵌入构建多模态图像检索系统

作者&#xff1a;来自 Elastic James Gallagher 了解如何使用 Roboflow Inference 和 Elasticsearch 构建强大的语义图像搜索引擎。 在本指南中&#xff0c;我们将介绍如何使用 Elasticsearch 中的 KNN 聚类和使用计算机视觉推理服务器 Roboflow Inference 计算的 CLIP 嵌入构建…

深入理解三高架构:高可用性、高性能、高扩展性的最佳实践

引言 在现代互联网环境下&#xff0c;随着用户规模和业务需求的快速增长&#xff0c;系统架构的设计变得尤为重要。为了确保系统能够在高负载和复杂场景下稳定运行&#xff0c;"三高架构"&#xff08;高可用性、高性能、高扩展性&#xff09;成为技术架构设计中的核…