使用make_grid多批次显示网格图像(使用CIFAR数据集介绍)

背景介绍

在机器学习的训练数据集中,我们经常使用多批次的训练来实现更好的训练效果,具体到cv领域,我们的训练数据集通常是[B,C,W,H]格式,其中,B是每个训练批次的大小,C是图片的通道数,如果是1则为灰度图像,如果是3则为彩色图像,W,H分别是图像的像素宽和像素高,在torchvision中,为我们提供了方便的方法显示多通道的图像显示成网格的格式

数据集介绍

这里使用机器学习中经典的CIFAR10数据集,具体可以参考博客CIFAR-10数据集详解与可视化_cifar10数据集可视化-CSDN博客

数据集读取

我们假设已经下载好CIFAR数据集保存在本地计算机的路径中,可以通过CIFAR函数进行读取

# 依赖的库环境
import torchvision
import torch
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor,Compose,Resize

读取CIFAR数据集中的训练数据集

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())

这里的转换方式是使用简单的ToTensor()将图片格式转换成经典的[C,W,H]格式,方便后续的可视化操作

此时我们可以简单地对数据集中的第一张图片进行可视化

img,label = train_dataset[0]
plt.imshow(img.permute(1,2,0))
plt.show()

构造批次数据集

如何构造批次的训练数据集呢?可以通过DataLoader的方式获得批次生成器,也可以通过torch.stack函数自定义地构成

cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)

这里使用列表推导式获得前4张图片组成的数据列表,通过torch.stack指定dim=0进行多个数据的堆加,这里需要注意的是,stack是在指定的维度新增一个维度进行多矩阵的合并,cat是在指定的维度上合并多个矩阵而不增加新的维度

cat与stack的区别

我们来具体看看两者的区别

cat_img = torch.cat([train_dataset[i][0] for i in range(4)],dim=0)
stack_img = torch.stack([train_dataset[i][0] for i in range(4)],dim=0)
print(f'cat_shape:{cat_img.shape}')
print(f'stack_shape:{stack_img.shape}')
cat_shape:torch.Size([12, 32, 32])
stack_shape:torch.Size([4, 3, 32, 32])

train_dataset[i][0]的形状为[3,32,32],当使用cat时,直接在第一维度上进行累加获得[12,32,32];使用stack时,在指定的第一维度上新增一个维度进行累加,有[4,3,32,32]

进行网格化显示

使用torchvision.utils.make_grid函数进行网格格式转换

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=4,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

nrow是指定每一行的图片的数量,这里只有四张图片,所以是4,默认nrow=8

normalize是对图片数据进行标准化

pad_value是对图片间隔之间的像素进行填充的像素值

padding是指定图片之间的像素间隔数量

同时显示100张图片

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(100)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=10,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

批次图片可视化

我们对使用DataLoader生成的批次数据进行可视化

if __name__=='__main__':train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())trainloader = DataLoader(train_dataset,shuffle=True,batch_size=128,num_workers=8)trainloader = iter(trainloader)trainloader_first_batch = next(trainloader)imgs,labels = trainloader_first_batchbatch_grid = torchvision.utils.make_grid(imgs)plt.imshow(batch_grid.permute(1,2,0))plt.show()

对训练数据集更好的了解是为了在训练的时候获得更好的模型性能,欢迎大家讨论交流~


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656880.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口请求,上传文件报500异常

异常响应 {"timestamp": "2024-01-29T06:39:28.82000:00","status": 500,"error": "Internal Server Error","path": "/test/upload" }服务端日志 服务端无日志打印 分析方向 nginx配置 nginx配置…

如何多个excel中的数据分发到多个excel中去

这个问题之前有一个文章我写了这个方法,但是后来发现效率太低了,于是再次更新一下对应的技术方案,提速5000倍。 一下代码主要实现的功能: 我有5000多个excel文件,每个文件是一只股票从上市至今的日K交易数据&#xff0…

Python网络拓扑库之mininet使用详解

概要 网络工程师、研究人员和开发人员需要进行各种网络实验和测试,以评估网络应用和协议的性能,以及解决网络问题。Python Mininet是一个功能强大的工具,它允许用户创建、配置和仿真复杂的网络拓扑,以满足各种实际应用场景。本文…

计算机二级Python选择题考点——Python语言程序设计Ⅰ

在Python中,变量名的命名规则:以字母或下划线开头,后面跟字母、下划线和数字,不能以数字开头。在Python语言中,可以作为源文件后缀名的是py。chr(x)和ord(x)函数用于在单字符和Unicode编码值之间进行转换。Python语言中用来表示代…

运行yolo v8 YOLOv8-CPP-Inference C++部署遇到的问题

环境: openCv:4.8.0 torch: 2.0.0 cuda:cuda_11.7.r11.7 遇到问题1: (tools) rogi7:~/my_file/obj/ultralytics/examples/YOLOv8-CPP-Inference/build$ ./Yolov8CPPInference Running on CUDA [ WARN:00.039] global net_impl.cpp:178 setUpNet DNN mo…

Java 面向对象进阶 01(黑马)

static案例代码: 代码: public class Student {private String gender;private String name;private int age;public static String teacherName ;public Student() {}public Student(String gender, String name, int age) {this.gender gender;this.…

[晓理紫]每日论文分享(有中文摘要,源码或项目地址)--大模型、扩散模型、视觉语言导航

专属领域论文订阅 VX 关注{晓理紫},每日更新论文,如感兴趣,请转发给有需要的同学,谢谢支持 如果你感觉对你有所帮助,请关注我,每日准时为你推送最新论文。 为了答谢各位网友的支持,从今日起免费…

费一凡:土木博士的自我救赎之道 | 提升之路系列(五)

导读 为了发挥清华大学多学科优势,搭建跨学科交叉融合平台,创新跨学科交叉培养模式,培养具有大数据思维和应用创新的“π”型人才,由清华大学研究生院、清华大学大数据研究中心及相关院系共同设计组织的“清华大学大数据能力提升项…

Linux true/false区分

bash的数值代表和其它代表相反:0表示true;非0代表false。 #!/bin/sh PIDFILE"pid"# truenginx进程运行 falsenginx进程未运行 checkRunning(){# -f true表示普通文件if [ -f "$PIDFILE" ]; then# -z 字符串长度为0trueif [ -z &qu…

时序数据库 Tdengine 执行命令能够查看执行的sql语句

curl是 访问6041端口,在windows系统里没有linux里的curl命令,需要用别的工具实现。我在cmd里是访问6030端口 第一步 在安装是时序数据库的服务器上也就是数据库服务端 进入命令窗口 执行 taos 第二步 执行 show queries\G;

jsjiami.v6加解密教学

1. 优点 a. 安全性提升 JavaScript 加密可以有效保护源代码,减少恶意用户的攻击风险。 b. 代码混淆 通过混淆技术,可以使代码变得难以阅读和理解,增加破解的难度。 c. 知识产权保护 对于商业项目,JavaScript 加密有助于保护…

Abp 创建一个WPF的项目

开发环境:VS2022、.NET6 1、创建项目:MyWpfApp,这里不再废话了。 2、NuGet添加: 2.1、Volo.Abp.Autofac 2.2、Serilog.Sinks.File 2.3、Serilog.Sinks.Async 2.4、Serilog.Extensions.Logging 2.5、Serilog.Extensions.Hos…

java spring boot 导入bean 的四种方式

1 Import导入bean的四种方式 2 代码 2.1 要导入的bean package com.example.demo;public class MyUser { }package com.example.demo;public class MyRow { }2.2 各种方式的代码 2.2.1 Import(MyUser.class) package com.example.demo;import org.springframework.boot.Sp…

低功耗蓝牙(BLE) 和 经典蓝牙(SPP) 的区别

低功耗蓝牙(BLE) vs 经典蓝牙(SPP) 区别项低功耗蓝牙(BLE)经典蓝牙(SPP 串行端口协议)蓝牙版本蓝牙版本 > 4.0,又称蓝牙低功耗、蓝牙智能经典蓝牙2.0 或更早版本,经典配对模式在两台蓝牙设备之间建立虚拟串口数据连接,提供一种简单而直接…

DML的原理:一篇文章让你豁然开朗

推荐阅读 给软件行业带来了春天——揭秘Spring究竟是何方神圣(一) 给软件行业带来了春天——揭秘Spring究竟是何方神圣(二) 文章目录 推荐阅读DML 数据操纵语言INSERT语句UPDATE语句DELETE语句SELECT语句 DML 数据操纵语言 DML是…

【前端】防抖和节流

防抖 防抖用于限制连续触发的事件的执行频率。当一个事件被触发时,防抖会延迟一定的时间执行对应的处理函数。如果在延迟时间内再次触发了同样的事件,那么之前的延迟执行将被取消,重新开始计时。 总结:在单位时间内频繁触发事件,只有最后一次生效 场景 :用户在输入框输…

消息中间件RabbitMQ介绍

一、基础知识 1. 什么是RabbitMQ RabbitMQ是2007年发布,是一个在AMQP(高级消息队列协议)基础上完成的,简称MQ全称为Message Queue, 消息队列(MQ)是一种应用程序对应用程序的通信方法,由Erlang(专门针对于大…

sqli-labs部署及sqli-labs靶场第一关

部署 一、环境安装 1.下载phpstudy,下载链接:小皮面板(phpstudy) - 让天下没有难配的服务器环境! ,傻瓜式的安装过后打开软件进入如下界面,我们开启nginx和mysql !!!&#xff0…

金蝶云星空AppDesigner.AppDesignerService.RecordCurDevCodeInfo RCE漏洞

免责声明:文章来源互联网收集整理,请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该…

第38期 | GPTSecurity周报

GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区,集成了生成预训练Transformer(GPT)、人工智能生成内容(AIGC)以及大型语言模型(LLM)等安全领域应用的知识。在这里,您可以…