GoogleNet网络训练集和测试集搭建

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。

还是使用的花分类的数据集进行训练和测试的。

一、训练集

1、搭建网络

设置参数:使用辅助分类器,采用权重初始化

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)

2、参数输出

之前的模型只有 1 个输出,但由于GoogleNet使用了两个辅助分类器,所以会有 3 个输出。

定义三个输出,分别计算主分类器、辅助分类器1、辅助分类器2的损失函数并相加,最后将损失函数反向传播,使用优化器更新参数模型。 

不单独放代码了,不知道哪里是改动的。图片中红色框中是改动的

整个训练集的代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import GoogleNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,shuffle=False, num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0003)best_acc = 0.0
save_path = './GoogleNet.pth'
# best_acc = 0.0
for epoch in range(2):# trainnet.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits, aux_logits2, aux_logits1 = net(images.to(device))loss0 = loss_function(logits, labels.to(device))loss1 = loss_function(aux_logits1, labels.to(device))loss2 = loss_function(aux_logits2, labels.to(device))loss = loss0 + loss1 * 0.3 + loss2 * 0.3loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)net.eval()acc = 0.0with torch.no_grad():for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_numif accurate_test > best_acc:best_acc = accurate_testtorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

训练完成 

 中间有几次报错,不过在看懂报错后很快改过来了。

二、测试集

载入模型

在创建模型的时候,aux_logits不会构建辅助分类器,但是之前训练的参数会保存。

所以,在载入模型的时候,要设置参数strict=False, 它可以精准匹配当前模型与所需要载入的权重模型的结构。

辅助分类器中的参数全部存放在unexpecte_keys中。

测试集全部代码

 可以自己找图片进行预测看准确率。

import torch
import matplotlib.pyplot as plt
import json
from model import GoogleNet
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("8.jpeg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)# read class_indent
try:json_file = open('./class_indices,json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = GoogleNet(num_classes=5, aux_logits=False)
model_weight_path = "./GoogleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

准确率好低,可能是模型训练的还不够吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/823779.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney常见玩法及prompt关键词技巧

今天系统给大家讲讲Midjourney的常见玩法和prompt关键词的一些注意事项,带大家入门~(多图预警,建议收藏~) 一、入门及常见玩法 1、注册并添加服务器(会的童鞋可跳过~) …

实在智能携AI Agent智能体亮相2024年度QCon全球软件开发大会

4月11日-13日,以“全面进化”作为2024年度主题的「QCon全球软件开发大会暨智能软件开发生态展」在北京隆重举行。作为AI准独角兽和超自动化头部企业,实在智能应邀出席发表《面向办公自动化领域的AI Agent建设思考与分享》演讲及圆桌交流,展示…

反转二叉树(力扣226)

解题思路:用队列进行前序遍历的同时把节点的左节点和右节点交换 具体代码如下: class Solution { public:TreeNode* invertTree(TreeNode* root) {if (root NULL) return root;swap(root->left, root->right); // 中invertTree(root->left)…

OpenHarmony南向开发案例【智慧中控面板(基于 Bearpi-Micro)】

1 开发环境搭建 【从0开始搭建开发环境】【快速搭建开发环境】 参考鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或复制转到。 【注意】:快速上手教程第六步出拉取代码时需要修改代码仓库地址 在MobaXterm中输入…

【模拟】Leetcode 提莫攻击

题目讲解 495. 提莫攻击 算法讲解 前后的两个数字之间的关系要么是相减之差 > 中毒时间 &#xff0c;要么反之 那即可通过示例&#xff0c;进行算法的模拟&#xff0c;得出上图的计算公式 class Solution { public:int findPoisonedDuration(vector<int>& time…

至少需要[XXXXMB]内存才能安装(宝塔导入数据库提示)

①我的2g内存腾讯云服务器想安装mysql8.0 ②宝塔提示“至少需要[3700MB]内存才能安装” 将数据库部署到宝塔上的时候提示-----》至少需要[XXXXMB]内存才能安装&#xff0c;解决的方法其实也很简单。 首先&#xff0c;进入文件夹/www/server/panel/class&#xff0c;找到找到…

PostgreSQL的学习心得和知识总结(一百三十八)|深入理解PostgreSQL数据库之Protocol message构造和解析逻辑

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《PostgreSQL数据库内核分析》 2、参考书籍&#xff1a;《数据库事务处理的艺术&#xff1a;事务管理与并发控制》 3、PostgreSQL数据库仓库…

【Java开发指南 | 第十一篇】Java运算符

读者可订阅专栏&#xff1a;Java开发指南 |【CSDN秋说】 文章目录 算术运算符关系运算符位运算符逻辑运算符赋值运算符条件运算符&#xff08;?:&#xff09;instanceof 运算符Java运算符优先级 Java运算符包括&#xff1a;算术运算符、关系运算符、位运算符、逻辑运算符、赋值…

Postman之版本信息查看

Postman之版本信息查看 一、为何需要查看版本信息&#xff1f;二、查看Postman的版本信息的步骤 一、为何需要查看版本信息&#xff1f; 不同的版本之间可能存在功能和界面的差异。 二、查看Postman的版本信息的步骤 1、打开 Postman 2、打开设置项 点击页面右上角的 “Set…

Java中的容器

Java中的容器主要包括以下几类&#xff1a; Collection接口及其子接口/实现类&#xff1a; List 接口及其实现类&#xff1a; ArrayList&#xff1a;基于动态数组实现的列表&#xff0c;支持随机访问&#xff0c;插入和删除元素可能导致大量元素移动。LinkedList&#xff1a;基…

只用键盘的技巧

技巧一&#xff1a;将常用软件固定在任务栏使用winnum/winT(shift)打开 技巧二&#xff1a;winX快捷键&#xff08;显示快捷键的快捷键&#xff09; ALT F4    关闭当前应用程序 技巧三&#xff1a;使用好Chrome快键键 ctrl h&#xff1b;历史纪录。 ctrl shift esc&am…

充电桩--OCPP 充电通讯协议介绍

一、OCPP协议介绍 OCPP的全称是 Open Charge Point Protocol 即开放充电点协议&#xff0c; 它是免费开放的协议&#xff0c;该协议由位于荷兰的组织 OCA&#xff08;开放充电联盟&#xff09;进行制定。Open Charge Point Protocol (OCPP) 开放充电点协议用于充电站(CS)和任何…

大模型开发轻松入门——(1)从搭建自己的环境开始

pip install openai import openai import osfrom dotenv import load_dotenv, find_dotenv _ load_dotenv(find_dotenv())openai.api_key os.getenv(OPENAI_API_KEY)

mac IDEA激活 亲测有效

1、官网下载mac版本IDEA并安装 2、打开激活页面 3、下载脚本文件 链接: https://pan.baidu.com/s/1I2BqdfxSJv1A96422rflnA?pwdm494 提取码: m494 4、命令行到该界面&#xff0c;执行 sudo bash idea.sh 可能出现的问题&#xff1a; 查看sh文件&#xff0c;targetFilePath…

vue快速入门(十六)事件修饰符

注释很详细&#xff0c;直接上代码 上一篇 新增内容 事件修饰符之阻止冒泡事件修饰符之阻止默认行为 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdev…

重定向原理和缓冲区

文章目录 重定向缓冲区 正文开始前给大家推荐个网站&#xff0c;前些天发现了一个巨牛的 人工智能学习网站&#xff0c; 通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。 点击跳转到网站。 重定向 内核中为了管理被打开的文件&#xff0c;一定会存在描述一…

ESP8266闪存文件系统(SPIFFS)

开发环境&#xff1a; 1、安装ESP8266的开发环境&#xff0c;如Arduino IDE。 2、下载并安装ESP8266的相关开发库和工具。 我们使用的是Arduino IDE。 基本介绍&#xff1a; 每一个ESP8266都配有一个闪存&#xff0c;这个闪存很像是一个小硬盘&#xff0c;我们上传的文件就被…

MCU最小系统晶振模块设计

单片机的心脏&#xff1a;晶振 晶振模块 单片机有两个心脏&#xff0c;一个是8M的心脏&#xff0c;一个是32.768的心脏 8M的精度较低&#xff0c;所以需要外接一个32.768khz 为什么是8MHZ呢&#xff0c;因为内部自带的 频率越高&#xff0c;精度越高&#xff0c;功耗越大&am…

[Java EE] 多线程(二): 线程的创建与常用方法(下)

2.3 启动一个线程–>start() 之前我们已经看到了如何通过重写run()方法来创建一个线程对象,但是线程对象被创建出来并不意味着线程就开始运行了. 覆写run方法是给线程提供了所要做的事情的指令清单创建线程对象就是把干活的人叫了过来.而调用start方法,就是喊一声"行…

贪心法确定补水地点

贪心算法是一个简单有趣的算法&#xff0c;它总是做出当前看来最好的选择&#xff0c;每次的局部最优选择最终可以产生整体最优解或整体最优解的近似。本文将介绍如何用贪心法解决补水问题。 1. 补水问题 2升水可以走 k k k英里&#xff0c;水站可以把水补满为2升&#xff0c…