pytorch实现自己的深度神经网络(公共数据集)

一、训练文件——train.py

  注意:在运行此代码之前,需要配置好pytorch-GPU版本的环境,具体再次不谈。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8,shuffle=True, num_workers=0)# 定义神经网络模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型,并将其移动到可用设备上
model = CNN().to(device)# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)if __name__ == '__main__':# 训练神经网络for epoch in range(5):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 梯度清零optimizer.zero_grad()# 正向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播 + 优化loss.backward()optimizer.step()# 打印统计信息running_loss += loss.item()if i % 200 == 199:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')# 保存模型至文件torch.save(model.state_dict(), 'cifar10_cnn_model.pth')

二、测试文件——val.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10测试数据集
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)# 创建测试数据加载器
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8,shuffle=False, num_workers=0)# 加载模型并将其移动到可用设备上
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x
# 显示函数
def imshow(img):img = img / 2 + 0.5npimg = img.numpy()# 坐标转换plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()model = CNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn_model.pth'))
model.eval()if __name__ == '__main__':# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)# 预测值的最大值以及最大值的类别索引_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on the test images: %d %%' % (100 * correct / total))# 显示测试集中的一些图片及其预测结果# 生成一个迭代器,从数据加载器中取出数据dataiter = iter(test_loader)# 从迭代器中获取下一个批次的数据images, labels = dataiter.next()# 将获取到的批次数据移动到device上,在这里也就是GPU上images, labels = images.to(device), labels.to(device)dip_flag = Falseif dip_flag == True:# -------------------------------------------# 可以选择 使用opencv显示# -------------------------------------------np_images = images.cpu().numpy()# 循环遍历并显示所有测试集图片for i in range(len(np_images)):# 从归一化中还原图像数据np_image = np.transpose(np_images[i], (1, 2, 0))   # 从CHW转换为HWCnp_image = np_image * 0.5 + 0.5# 将图像数据从float类型转换为unit8类型np_image = (np_image * 255).astype(np.uint8)# 使用opencv显示图像cv2.imshow("Image {}".format(i+1), np_image)cv2.waitKey(0)# 等待用户按下任意键继续显示下一张图像cv2.destroyAllWindows()imshow(torchvision.utils.make_grid(images.cpu()))print('GroundTruth: ', ' '.join('%5s' % test_dataset.classes[labels[j]] for j in range(8)))outputs = model(images)_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % test_dataset.classes[predicted[j]]for j in range(8)))


直接运行即可,亲测可以运行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney 实现角色一致性的新方法

AI 绘画的奇妙之处,实乃令人叹为观止!就像大千世界中,寻不见两片完全相同的树叶一般,AI 绘画亦复如是。同一提示之词,竟能催生出千变万化的图像,使得AI所绘之作,宛如自然之物般独特,…

项目7-音乐播放器2(上传音乐+查询音乐+拦截器)

0.加入拦截器 之后就不用对用户是否登录进行判断了 0.1 定义拦截器 0.2 注册拦截器 生效 1.上传音乐的接口设计 请求: { post, /music/upload {singer,MultipartFile file}, } 响应: { "status": 0, "message&…

单链表经典算法题分析

目录 一、链表的中间节点 1.1 题目 1.2 题解 1.3 收获 二、移除链表元素 2.1 题目 2.2 题解 2.3 收获 2.4递归详解 三、反转链表 3.1 题目 3.2 题解 3.3 解释 四、合并两个有序列表 4.1 题目 4.2 题解 4.3 递归详解 声明:本文所有题目均摘自leetco…

《手把手教你》系列基础篇(九十二)-java+ selenium自动化测试-框架设计基础-POM设计模式简介(详解教程)

1.简介 页面对象模型(Page Object Model)在Selenium Webdriver自动化测试中使用非常流行和受欢迎,作为自动化测试工程师应该至少听说过POM这个概念。本篇介绍POM的简介,接下来宏哥一步一步告诉你如何在你JavaSelenium3自动化测试…

13个Java基础面试题

Hi,大家好,我是王二蛋。 金三银四求职季,特地为大家整理出13个 Java 基础面试题,希望能为正在准备或即将参与面试的小伙伴们提供些许帮助。 后续还会整理关于线程、IO、JUC等Java相关面试题,敬请各位持续关注。 这1…

【ROS2】搭建ROS2-Humble + Vscode开发流程

【ROS2】搭建ROS2-Humble Vscode开发流程 文章目录 【ROS2】搭建ROS2-Humble Vscode开发流程1.基本环境配置2.搭建Vscode开发环境 1.基本环境配置 基本的环境配置包括以下步骤: 安装ROS2-Humble,可以参考这里安装一些基本的工具,可以参考…

nuxt3项目使用swiper11插件实现点击‘’返回顶部按钮‘’返回到第一屏

该案例主要实现点击返回顶部按钮返回至swiper第一个slide。 版本: "nuxt": "^3.10.3", "pinia": "^2.1.7", "swiper": "^11.0.7", 官方说明 swiper.slideTo(index, speed, runCallbacks) Run transit…

浅析MySQL 8忘记密码处理方式

对MySQL有研究的读者,可能会发现MySQL更新很快,在安装方式上,MySQL提供了两种经典安装方式:解压式和一键式,虽然是两种安装方式,但我更提倡选择解压式安装,不仅快,还干净。在操作系统…

【数据结构1-基本概念和术语】

这里写自定义目录标题 0.数据,数据元素,数据项,数据对项,数据结构,逻辑结构,存储结构1.结构1.1逻辑结构1.2存储结构1.2.1 顺序结构1.2.2链式结构 1.3数据结构1.3.1基本数据类型1.3.2抽象数据类型1.3.2.1一个…

Java SpringBoot基于微信小程序的高速公路服务区充电桩在线预定系统,附源码

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

05节-51单片机-模块化编程

1.两种编程方式的对比 传统方式编程: 所有的函数均放在main.c里,若使用的模块比较多,则一个文件内会有很多的代码,不利于代码的组织和管理,而且很影响编程者的思路 模块化编程: 把各个模块的代码放在不同的…

STM32外设配置以及一些小bug总结

USART RX的DMA配置 这里以UART串口1为例,首先点ADD添加RX和TX配置DMA,然后模式一般会选择是normal,这个模式是当DMA的计数器减到0的时候就不做任何动作了,还有一种循环模式,是计数器减到0之后,计数器自动重…

Echats 引入地图(二) 之中国地图省份高亮

效果图: 代码: series: [{type: map,map: china,zoom: 1.2, // 地图放大aspectScale: 0.8, //地图宽高比例roam: true, //地图缩放、平移// 滚轮缩放的极限控制scaleLimit: {min: 0.5, //缩放最小大小max: 6, //缩放最大大小},itemStyle…

使用Android studio,安卓手机编译安装yolov8部署ncnn,频繁出现编译错误

从编译开始就开始出现错误,解决步骤: 1.降低graddle版本,7.2-bin --->>> 降低为 6.1.1-all #distributionUrlhttps\://services.gradle.org/distributions/gradle-7.2-bin.zip distributionUrlhttps\://services.gradle.org/di…

5.HC-05蓝牙模块

配置蓝牙模块 注意需要将蓝牙模块接5v,实测接3.3v好像不太好使的样子 首先需要把蓝牙模块通过TTL串口模块接到我们的电脑,然后打开我们的串口助手 注意,我们现在是配置蓝牙模块,所以需要进入AT模式,需要按着蓝牙模块上的黑色小按钮再上电,这时候模块上的LED灯以一秒慢闪一次…

性能工具之emqtt-bench BenchMark 测试示例

文章目录 一、前言二、典型压测场景三、机器准备四、典型压测场景1、并发连接2、消息吞吐量测试2.1 1 对 1(示例)2.2 多对1(示例)2.3 1对多(示例) 五、遇到的问题client(): EXIT for {shutdown,eaddrnotava…

IDM2024破解版 IDM软件破解注册序列号 idm教程 idm序列激活永久授权 Internet Download Manager网络下载加速神器

你是不是感觉下载东西资源的时候,下载的非常慢,即便是五十兆的光纤依旧慢、是不是想下载网页上的视频但不知如何进行下载……这些问题是否一直在困扰着您,今日小编特意我大家带来了这款IDM 2024破解版。 众所周知,IDM是一款功能强…

ChatGPT实用指南2024

随着ChatGPT技术的演进,越来越多的人开始在工作中利用此工具。以下是关于ChatGPT的实用指南,适合不太熟悉此技术的朋友参考。 一、ChatGPT概述 1. ChatGPT是什么? ChatGPT是基于OpenAI开发的GPT大型语言模型的智能对话工具。它能够通过自然语…

1.8.5 卷积神经网络近年来在结构设计上的主要发展和变迁——Inception-v4 和 Inception-ResNet

1.8.5 卷积神经网络近年来在结构设计上的主要发展和变迁——Inception-v4 和 Inception-ResNet 前情回顾: 1.8.1 卷积神经网络近年来在结构设计上的主要发展和变迁——AlexNet 1.8.2 卷积神经网络近年来在结构设计上的主要发展和变迁——VGGNet 1.8.3 卷积神经网络近…

oracle安装后报错ORA-01031: insufficient privileges

1.管理员身份打开CMD,输入net user查看计算机用户 2.键入"net localgroup ora_dba"查看ora_dba下的具体用户 3.键入"net localgroup ora_dba administrator /add"把本计算机用户都添加进ora_dba组下