厦门网站建设开发如何调整
厦门网站建设开发,如何调整,桥东区网站建设,嵌入式软件开发是什么专业背景介绍
在机器学习的训练数据集中#xff0c;我们经常使用多批次的训练来实现更好的训练效果#xff0c;具体到cv领域#xff0c;我们的训练数据集通常是[B,C,W,H]格式#xff0c;其中#xff0c;B是每个训练批次的大小#xff0c;C是图片的通道数#xff0c;如果是1…背景介绍
在机器学习的训练数据集中我们经常使用多批次的训练来实现更好的训练效果具体到cv领域我们的训练数据集通常是[B,C,W,H]格式其中B是每个训练批次的大小C是图片的通道数如果是1则为灰度图像如果是3则为彩色图像W,H分别是图像的像素宽和像素高在torchvision中为我们提供了方便的方法显示多通道的图像显示成网格的格式
数据集介绍
这里使用机器学习中经典的CIFAR10数据集具体可以参考博客CIFAR-10数据集详解与可视化_cifar10数据集可视化-CSDN博客
数据集读取
我们假设已经下载好CIFAR数据集保存在本地计算机的路径中可以通过CIFAR函数进行读取
# 依赖的库环境
import torchvision
import torch
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor,Compose,Resize
读取CIFAR数据集中的训练数据集
train_dataset CIFAR10(rD:\deep_learning\12_16\data, trainTrue, downloadFalse,transformToTensor())
这里的转换方式是使用简单的ToTensor()将图片格式转换成经典的[C,W,H]格式方便后续的可视化操作
此时我们可以简单地对数据集中的第一张图片进行可视化
img,label train_dataset[0]
plt.imshow(img.permute(1,2,0))
plt.show()构造批次数据集
如何构造批次的训练数据集呢可以通过DataLoader的方式获得批次生成器也可以通过torch.stack函数自定义地构成
cifar_img torch.stack([train_dataset[i][0] for i in range(4)], dim0)
这里使用列表推导式获得前4张图片组成的数据列表通过torch.stack指定dim0进行多个数据的堆加这里需要注意的是stack是在指定的维度新增一个维度进行多矩阵的合并cat是在指定的维度上合并多个矩阵而不增加新的维度
cat与stack的区别
我们来具体看看两者的区别
cat_img torch.cat([train_dataset[i][0] for i in range(4)],dim0)
stack_img torch.stack([train_dataset[i][0] for i in range(4)],dim0)
print(fcat_shape:{cat_img.shape})
print(fstack_shape:{stack_img.shape})
cat_shape:torch.Size([12, 32, 32])
stack_shape:torch.Size([4, 3, 32, 32])
train_dataset[i][0]的形状为[3,32,32]当使用cat时直接在第一维度上进行累加获得[12,32,32]使用stack时在指定的第一维度上新增一个维度进行累加有[4,3,32,32]
进行网格化显示
使用torchvision.utils.make_grid函数进行网格格式转换
train_dataset CIFAR10(rD:\deep_learning\12_16\data, trainTrue, downloadFalse,transformToTensor())
cifar_img torch.stack([train_dataset[i][0] for i in range(4)], dim0)
img_grid torchvision.utils.make_grid(cifar_img,nrow4,normalizeTrue,pad_value0.9,padding1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()
nrow是指定每一行的图片的数量这里只有四张图片所以是4默认nrow8
normalize是对图片数据进行标准化
pad_value是对图片间隔之间的像素进行填充的像素值
padding是指定图片之间的像素间隔数量 同时显示100张图片
train_dataset CIFAR10(rD:\deep_learning\12_16\data, trainTrue, downloadFalse,transformToTensor())
cifar_img torch.stack([train_dataset[i][0] for i in range(100)], dim0)
img_grid torchvision.utils.make_grid(cifar_img,nrow10,normalizeTrue,pad_value0.9,padding1)
plt.imshow(img_grid.permute(1,2,0))
plt.show() 批次图片可视化
我们对使用DataLoader生成的批次数据进行可视化
if __name____main__:train_dataset CIFAR10(rD:\deep_learning\12_16\data, trainTrue, downloadFalse,transformToTensor())trainloader DataLoader(train_dataset,shuffleTrue,batch_size128,num_workers8)trainloader iter(trainloader)trainloader_first_batch next(trainloader)imgs,labels trainloader_first_batchbatch_grid torchvision.utils.make_grid(imgs)plt.imshow(batch_grid.permute(1,2,0))plt.show() 对训练数据集更好的了解是为了在训练的时候获得更好的模型性能欢迎大家讨论交流~
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/90694.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!