Pytorch cifar10离线加载二进制文件

 

说明直接离线加载cifar10到Pytorch

'''
直接加载6个文件到pytorchdata_batch_1data_batch_2data_batch_3data_batch_4data_batch_5test_batch'''import os
import cv2
import pickle
import numpy as np
import matplotlib.pyplot as pltimport torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms#加载cifar10的数据
def load_CIFAR_batch(filename):""" load single batch of cifar """with open(filename, 'rb') as f:datadict = pickle.load(f,encoding='latin1')X = datadict['data']Y = datadict['labels']# X = X.reshape(10000, 3, 32,32).transpose(0,2,3,1).astype("float")X = X.reshape(10000, 3, 32,32).transpose(0,2,3,1)Y = np.array(Y)return X, Ydef load_CIFAR10(ROOT):""" load all of cifar """xs = []ys = []for b in range(1,6):filename = os.path.join(ROOT, 'data_batch_%d' % (b))X, Y = load_CIFAR_batch(filename)xs.append(X)ys.append(Y)Xtrain = np.concatenate(xs)#使变成行向量Ytrain = np.concatenate(ys)del X, YXtest, Ytest = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))return Xtrain, Ytrain, Xtest, Ytestclass DealDataset(Data.Dataset):"""读取数据、初始化数据"""def __init__(self, root, train=True, transform=None):if train:# 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式(train_set, train_labels, _, _) = load_CIFAR10(root)self.train_set = train_setself.train_labels = train_labelselse:(_, _, test_set, test_labels) = load_CIFAR10(root)self.test_set = test_setself.test_labels = test_labelsself.transform = transformself.train = traindef __getitem__(self, index):if self.train:img, target = self.train_set[index], int(self.train_labels[index])else:img, target = self.test_set[index], int(self.test_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):if  self.train:return len(self.train_set)else:return len(self.test_set)root = r'E:\cifar-10-python\cifar-10-batches-py'
batch_size = 8# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset(root, train=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = Data.DataLoader(dataset=trainDataset,batch_size=batch_size, # 一个批次可以认为是一个包,每个包中含有batch_size张图片shuffle=False,
)test_loader = Data.DataLoader(dataset=testDataset,batch_size=batch_size,shuffle=False,
)if __name__ == '__main__':# 这里trainDataset包含:train_labels, train_set等属性;  数据类型均为ndarrayprint(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')# 这里train_loader包含:batch_size、dataset等属性,数据类型分别为int,DealDataset# dataset中又包含train_labels, train_set等属性;  数据类型均为ndarrayprint(f'train_loader.batch_size: {train_loader.batch_size}\n')print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')# # 可视化1,使用OpenCV# images, lables = next(iter(train_loader))# img = torchvision.utils.make_grid(images, nrow = 10)# img = img.numpy().transpose(1, 2, 0)# # OpenCV默认为BGR,这里img为RGB,因此需要对调img[:,:,::-1]# cv2.imshow('img', img[:,:,::-1])# cv2.waitKey(0)# 可视化2,使用pltdataiter = iter(train_loader)images, labels = dataiter.next()images = images.numpy()classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']fig = plt.figure(figsize=(4, 4))for idx in np.arange(batch_size):ax = fig.add_subplot(2, batch_size/2, idx+1, xticks=[], yticks=[])# ax.imshow(np.squeeze(images[idx]), cmap='gray')# a = images[idx]# b = images[idx].transpose(1, 2, 0)# ax.imshow(images[idx].transpose(1, 2, 0), cmap='RGB')ax.imshow(images[idx].transpose(1, 2, 0))ax.set_title(classes[labels[idx]])plt.show()

 

运行结果

显示图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/350206.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring cloud gateway 深入了解 - Predicate

文章来源 spring cloud gateway 通过谓词(Predicate)来匹配来自用户的请求 为了方便,使用postman测试不同的谓词的效果 路径谓词(Predicate)—— 最简单的谓词 配置如下spring:cloud:gateway:routes:# 匹配指定路径的路…

python漏洞检测脚本_URL重定向漏洞,python打造URL重定向漏洞检测脚本

前言:今天学习了重定向漏洞,这个漏洞比较好理解漏洞名:URL重定向漏洞威胁:低漏洞的来源:开发者对head头做好对应的过滤和限制例子:有漏洞的网站:http://a.com/x.php?urlhttp://a.com/login.php…

Pytorch cifar100离线加载二进制文件

说明:直接加载cifar100二进制文件到Pytorch 直接加载文件到pytorchmetatesttrain import os import cv2 import pickle import time import numpy as np import matplotlib.pyplot as pltimport torchvision from torch.autograd import Variable import torch.uti…

为单个Web应用程序配置多个上下文根– JBoss

有时&#xff0c;我们通过在jboss-web.xm l中定义一个来对应用程序进行更改&#xff0c;以支持利用JBoss功能的多个上下文根&#xff0c;如下所示&#xff1a; webapp / WEB-INF / jboss-web.xml&#xff1a; <?xml version"1.0" encoding"UTF-8"?&…

手动升级ubuntu 18.04内核从4.15.0-45-generic到4.15.0-52-generic

1 从下面官网下载相应的包&#xff0c;共3个。 https://kernel.ubuntu.com/~kernel-ppa/mainline/v4.15-rc9/ linux-headers-4.15.0-041500rc9-generic_4.15.0-041500rc9.201801212130_amd64.deb linux-headers-4.15.0-041500rc9_4.15.0-041500rc9.201801212130_all.deb linux-…

cmd52命令发送 mmc_乾坤合一~Linux SD/MMC/SDIO驱动分析(上)

一、SD/MMC/SDIO概念区分SD(SecureDigital)与 MMC(MultimediaCard)SD 是一种 flash memory card 的标准&#xff0c;也就是一般常见的 SD 记忆卡&#xff0c;而 MMC 则是较早的一种记忆卡标准&#xff0c;目前已经被 SD 标准所取代。在维基百科上有相当详细的 SD/MMC 规格说明&…

Pytorch MNIST直接离线加载二进制文件到pytorch

说明&#xff1a;MNIST直接离线加载二进制文件到pytorch 直接以下4个文件读入数据到pytorch中t10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gztrain-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gz import os import numpy as np import gzipimport torch.utils.data a…

爱与愁的心痛

爱与愁的心痛 题目链接 题意 这道题的题意是&#xff0c;给定一个整数数组&#xff0c;数组中的每个元素代表一个不爽的事情的刺痛值。现在需要找出连续m个刺痛值的和的最小值。 思路 读取输入和初始化遍历数组并计算窗口和输出最小和 坑点 数组越界重复计算窗口和 算法一&am…

begintrans返回值_SQL事务回滚 ADO BeginTrans, CommitTran 以及 RollbackTrans 方法

定义和用法这三个方法与 Connection 对象使用&#xff0c;来保存或取消对数据源所做的更改。注释&#xff1a;并非所有提供者都支持事务。注释&#xff1a;BeginTrans、CommitTrans 和 RollbackTrans 方法在客户端 Connection 对象上无效。那客户端不能支持事务? 这是什么意思…

Pytorch Fashion_MNIST直接离线加载二进制文件到pytorch

说明&#xff1a;Fashion_MNIST直接离线加载二进制文件到pytorch 将4个gz直接加载到pytoch用来训练t10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gztrain-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gz import os import numpy as np import gzip import matplotlib.p…

jQuery选择器种类整理

选择器概念 jQuery选择器是通过标签、属性或者内容对HTML内容进行选择&#xff0c;选择器运行对HTML元素组或者单个元素进行操作。 jQuery选择器使用$符号&#xff0c;等同于jquery&#xff0c;例如&#xff1a; $(“li”) jquery(“li”) 同样等同于javascript中的&#xff1…

jee过滤器应用场景_将涡轮增压器添加到JEE Apps

jee过滤器应用场景我扮演的关键角色之一是在本地社区中传播Akka。 作为讨论的一部分&#xff0c;人们通常会想到的问题/疑问是Akka如何针对编写良好的Java / JEE应用程序提供更好的可伸缩性和并发性。 由于底层硬件/ JVM保持不变&#xff0c;因此参与者模型如何比传统的JEE应用…

mysql 列 随机数_mysql mmp 某字段插入随机数!(说不定那天就忘记了,存下来再说)...

UPDATE 表名 SET 字段名ceiling(rand()*500000500000) WHERE (条件);原文链接&#xff1a;http://blog.csdn.net/bobay/article/details/24797525MMP 上面的只能更新一条UPDATE 表名 SET 字段名cast(rand(checksum(newid()))*(24)1 as int) WHERE (条件);上面的就是每条都更新的…

适用于Java开发人员的Elasticsearch:Elasticsearch生态系统

本文是我们学院课程的一部分&#xff0c;该课程的标题为Java开发人员的Elasticsearch教程 。 在本课程中&#xff0c;我们提供了一系列教程&#xff0c;以便您可以开发自己的基于Elasticsearch的应用程序。 我们涵盖了从安装和操作到Java API集成和报告的广泛主题。 通过我们简…

matplotlib markers的类型

https://matplotlib.org/api/markers_api.html matplotlib markers 所有可能的markers定义如下: marker symbol description "." point "," pixel "o" circle "v" triangle_down "^" triangle_up &…

android实时声音信号波形_Android输出正弦波音频信号(左右声道对称)-阿里云开发者社区...

转载请说明出处&#xff01;作者&#xff1a;kqw攻城狮出处&#xff1a;个人站 | CSDN需求&#xff1a;左右声道分别输出不同的音频数据&#xff0c;波形要是一个正弦波&#xff0c;左右声道还要对称&#xff01;对硬件不是很了解&#xff0c;说是要通过音波避障。效果图之前已…

matplotlib color可选

matplotlib color matplotlib中color可用的颜色&#xff1a; cnames { aliceblue: #F0F8FF, antiquewhite: #FAEBD7, aqua: #00FFFF, aquamarine: #7FFFD4, azure: #F0FFFF, beige: #F5F5…

python之scrapy爬取jd和qq招聘信息

1、settings.py文件 # -*- coding: utf-8 -*-# Scrapy settings for jd project # # For simplicity, this file contains only settings considered important or # commonly used. You can find more settings consulting the documentation: # # https://doc.scrapy.org…

opencl 加速 c语言程序_Win10应用获得面向OpenCL和OpenGL的兼容层

今年早些时候&#xff0c;微软宣布正在努力在Windows 10的Direct3D 12(D3D12)中启用对OpenCL和OpenGL映射层的支持。为了启用映射层&#xff0c;解决设备上没有OpenCL和OpenGL硬件驱动时的兼容性问题&#xff0c;公司目前已经在微软商店中发布了兼容性包。该兼容性包的标题为 &…

matplotlib plt.subplot

matplotlib plt.subplot 用于在一个Figure对象里画多个子图(Axes)。 其调用格式&#xff1a;subplot(numRows, numCols, plotNum)&#xff0c;即&#xff08;行、列、序号&#xff09;。 图表的整个绘图区域被分成numRows行和numCols列&#xff0c;plotNum参数指定创建的Axes…