使用GoogleNet实现对花数据集的分类预测

使用GoogleNet实现对花数据集的分类预测

  • 1.作者介绍
  • 2.关于理论方面的知识介绍
    • 2.1GooLeNet的知识介绍
    • 2.2CNN发展阶段
    • 2.2GooLeNet创新模块
  • 3.关于实验过程的介绍,完整实验代码,测试结果
    • 3.1数据集介绍
    • 3.2实验过程
    • 3.3实验结果

1.作者介绍

王海博, 男 , 西安工程大学电子信息学院, 2024级研究生, 张宏伟人工智能课题组
研究方向:模式识别与人工智能
电子邮件:1137460680@qq.com

2.关于理论方面的知识介绍

2.1GooLeNet的知识介绍

GooLeNet(Google LeNet)是 Google 在 2014 年 ILSVRC(ImageNet Large Scale Visual Recognition Challenge) 竞赛中提出的深度卷积神经网络(CNN),以其创新的 Inception 结构 取得了 分类任务第一名,显著降低了计算量,同时提高了准确率。它的核心创新在于:
(1)Inception 模块:利用多种卷积核并行处理,提高特征表达能力。
(2)参数量优化:通过 1×1 卷积降维,减少计算量。
(3)深度结构设计:使用 22 层深度,但参数量远小于 VGGNet 和 AlexNet。
(4)全局平均池化:减少全连接层的参数,提高泛化能力。

2.2CNN发展阶段

在 GooLeNet 之前,深度学习领域的 CNN 主要经历了以下几个阶段:
(1)LeNet-5(1998 年):由 Yann LeCun 提出,使用 5 层网络进行手写数字识别,包含卷积、池化和全连接层。层数较少,计算量小,适用于小型数据集。
(2)AlexNet(2012 年):由 Alex Krizhevsky 提出,8 层深度,首次使用 ReLU 激活函数 和 Dropout,避免梯度消失问题。在 ILSVRC 2012 竞赛中大幅提升分类精度(top-5 错误率 16.4%)。
但存在参数量过多(60M 参数),计算复杂度高的问题。
(3)VGGNet(2014 年):由 Oxford 大学提出,使用 多个 3×3 小卷积核 堆叠,提高模型深度(VGG-16 具有 138M 参数)。计算复杂度更高,训练和推理成本较大。
但也存在问题,过深的网络计算量巨大,容易导致梯度消失和过拟合;传统 CNN 结构使用单一卷积核,难以同时捕捉不同尺度的特征。

2.2GooLeNet创新模块

GooLeNet 的最大创新是 Inception 模块,它可以 并行提取不同尺度的特征,同时降低计算复杂度。
Inception 结构的核心思想是:在同一层使用多种不同大小的卷积核进行特征提取,然后拼接输出,增强模型的表达能力。
一个 Inception 模块通常包含 4 个不同的分支:1×1 卷积(降维);3×3 卷积(中等感受野);5×5 卷积(大感受野);3×3 最大池化(增强局部特征)。最终,所有分支的输出通道(feature maps)被拼接(concatenation),形成最终输出。
在这里插入图片描述
在这里插入图片描述

GooLeNet 由多个 Inception 模块组成,总共有 22 层深度(如果包括池化层则为 27 层)。
在这里插入图片描述
在这里插入图片描述
(1)前处理层(Feature Extractor):7×7 卷积 + 3×3 最大池化;1×1 降维 + 3×3 卷积。
(2)9 个 Inception 模块:逐层提取多尺度特征。
(3)全局平均池化(Global Average Pooling):取代全连接层,减少参数量,提高泛化能力。
(4)Softmax 分类:最终输出类别概率。

3.关于实验过程的介绍,完整实验代码,测试结果

3.1数据集介绍

Oxford 102 Flowers 数据集是由牛津大学视觉几何组(Visual Geometry Group, VGG)发布的一个细粒度花卉图像分类数据集,主要用于计算机视觉、图像分类、深度学习研究。该数据集包含102 种英国常见花卉,每个类别包含至少 40 张图像,总共收录了8189 张花卉图片。由于数据集中存在高度相似的花卉类别(如不同品种的百合或玫瑰),因此它被认为是一个细粒度分类问题(Fine-Grained Classification),比普通的图像分类任务更加复杂和具有挑战性。

在这里插入图片描述

(1)数据集划分:
训练集(Training Set):1020 张(每类 10 张),每个类别有 10 张图片,适用于训练模型。
验证集(Validation Set):1020 张(每类 10 张),每个类别 10 张图片,用于模型超参数调优和验证性能。
测试集(Test Set):6149 张(剩余所有图片),用于最终评估模型的分类效果。

(2)数据集特点:
Oxford 102 Flowers 数据集之所以被广泛用于计算机视觉和深度学习研究,主要有以下几个特点:
a.细粒度分类任务:
该数据集的 102 类花卉中,许多花卉类别的外观相似,区分难度较大。
这使得 Oxford 102 Flowers 成为**细粒度分类(Fine-Grained Classification)**的典型案例。
细粒度分类相比一般的图像分类任务更具有挑战性,因为它要求模型能够捕捉到微小的视觉差异,比如花瓣形状、颜色模式、叶子分布等。
b.真实世界的复杂性:
数据集中的图片是从真实世界环境中收集的,背景复杂,不是理想的实验室环境。
花卉在不同的光照、角度、遮挡情况下都会表现出不同的视觉特征,使得模型必须具备较强的泛化能力。
c.高质量图像:
与某些低分辨率或噪声较多的数据集不同,Oxford 102 Flowers 提供了高分辨率的高清图片,适用于深度学习任务。
这使得研究人员可以使用CNN(卷积神经网络)、**ViT(视觉变换器)等先进模型进行训练。
d.适用于迁移学习:
由于该数据集类别较多、样本数量相对有限,因此可以作为
迁移学习(Transfer Learning)**的目标数据集。许多研究使用预训练的 ResNet、VGG、EfficientNet、ViT 等模型进行迁移学习,以提高分类效果。

(3)数据集获取:
Oxford 102 Flowers 数据集可以从以下渠道下载:
官方链接: http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
Kaggle 数据集链接: https://www.kaggle.com/datasets/

3.2实验过程

对比对比两种深度学习模型(EfficientNetV2-M 和 GoogLeNet)在 Flowers102 数据集上的图像分类性能。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 增强的数据预处理
train_transform = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.RandomRotation(30),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])val_test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset = datasets.Flowers102(root='./data', split='train', transform=train_transform, download=True)
val_dataset = datasets.Flowers102(root='./data', split='val', transform=val_test_transform, download=True)
test_dataset = datasets.Flowers102(root='./data', split='test', transform=val_test_transform, download=True)# DataLoader配置
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 模型定义函数
def create_model(model_name):if model_name == 'efficientnet':model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1)model.classifier[1] = nn.Sequential(nn.Dropout(0.3),nn.Linear(model.classifier[1].in_features, 102))elif model_name == 'googlenet':model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)model.fc = nn.Linear(model.fc.in_features, 102)else:raise ValueError("不支持的模型类型")return model.to(device)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练函数(支持多模型对比)
def train_compare_models(model_dict, train_loader, val_loader, epochs=10):history = {name: {'train_loss': [], 'val_acc': []} for name in model_dict}for name, model in model_dict.items():print(f"\n{'='*30} 训练 {name} 模型 {'='*30}")# 优化器配置optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)  # 统一学习率scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)best_acc = 0.0for epoch in range(epochs):model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * images.size(0)scheduler.step()# 验证阶段val_loss, val_acc = evaluate(model, val_loader, criterion)train_loss = running_loss / len(train_loader.dataset)history[name]['train_loss'].append(train_loss)history[name]['val_acc'].append(val_acc)print(f"[{name}] Epoch {epoch+1}/{epochs}")print(f"Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%")# 保存最佳模型if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), f'best_{name}.pth')return history# 验证函数
def evaluate(model, loader, criterion):model.eval()correct = 0total = 0running_loss = 0.0with torch.no_grad():for images, labels in loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()running_loss += loss.item() * images.size(0)return running_loss / len(loader.dataset), 100 * correct / total# 测试函数
def test_model(model, test_loader):model.eval()correct, total = 0, 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"测试集准确率: {accuracy:.2f}%")# 图片显示函数
def imshow(img, title):img = img.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])img = std * img + meanimg = np.clip(img, 0, 1)plt.imshow(img)plt.title(title)plt.axis('off')# 对比可视化函数
def plot_comparison(history):plt.figure(figsize=(12, 5))# 训练损失对比plt.subplot(1, 2, 1)for name, data in history.items():plt.plot(data['train_loss'], label=f'{name} 训练损失')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('训练损失对比')plt.legend()# 验证准确率对比plt.subplot(1, 2, 2)for name, data in history.items():plt.plot(data['val_acc'], label=f'{name} 验证准确率')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('验证准确率对比')plt.legend()plt.tight_layout()plt.show()if __name__ == "__main__":# 定义模型models_dict = {'EfficientNetV2-M': create_model('efficientnet'),'GoogleNet': create_model('googlenet')}# 训练并对比模型history = train_compare_models(models_dict, train_loader, val_loader, epochs=10)# 绘制对比图plot_comparison(history)# 测试集性能对比print("\n测试集性能对比:")for name, model in models_dict.items():model.load_state_dict(torch.load(f'best_{name}.pth'))print(f"\n{name} 模型测试结果:")test_model(model, test_loader)# 随机抽取5张图片进行预测对比class_names = [f"Class {i}" for i in range(102)]random_indices = random.sample(range(len(test_dataset)), 5)images_list = []labels_list = []for idx in random_indices:img, lbl = test_dataset[idx]images_list.append(img)labels_list.append(lbl)images_tensor = torch.stack(images_list).to(device)labels_tensor = torch.tensor(labels_list).to(device)# 显示预测结果plt.figure(figsize=(15, 5))for i, (name, model) in enumerate(models_dict.items()):model.eval()outputs = model(images_tensor)_, predicted = torch.max(outputs, 1)for j in range(5):plt.subplot(2, 5, i * 5 + j + 1)imshow(images_tensor[j].cpu(), f"{name}\n预测: {class_names[predicted[j].item()]}")plt.tight_layout()plt.show()

3.3实验结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
对5张随机抽取的图片进行预测,两种模型输出的结果:
在这里插入图片描述
EfficientNetv2和Googlenet模型进行对比结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/73747.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

流量密码破解:eBay店铺首页改版后的黄金展示位

流量密码破解:eBay店铺首页改版后的黄金展示位 近年来,跨境电商行业竞争日趋激烈,流量分配机制的调整成为平台卖家最为关注的核心议题之一。作为全球领先的在线电商平台,eBay的每一次页面优化都可能对卖家的经营策略产生深远影响…

从0到1彻底掌握Trae:手把手带你实战开发AI Chatbot,提升开发效率的必备指南!

我正在参加Trae「超级体验官」创意实践征文, 本文所使用的 Trae 免费下载链接: www.trae.ai/?utm_source… 前言 大家好,我是小Q,字节跳动近期推出了一款 AI IDE—— Trae,由国人团队开发,并且限时免费体…

Netty:java高性能网络编程的基石(下)

一、Netty是什么?为什么需要它? Netty是一个异步事件驱动的网络应用框架,基于Java NIO技术封装,用于快速开发高性能、高可靠性的网络服务器和客户端程序。作为当前最流行的NIO框架之一,支撑着Dubbo、RocketMQ、Elasti…

leetcode-47.全排列II

如何在有重复值的时候节省时间是优化重点。 基础写法肯定是按无重复值时的全排列写,在其中要加上防止走重复路径的分支。 能防止的也只有同层,如果同层走一个值,但是该值重复,且走过了,则放弃走该分支。所以设layer_u…

函数(函数的概念、库函数、自定义函数、形参和实参、return语句、数组做函数参数、嵌套调用和链式访问、函数的声明和定义、static和extern)

一、函数的概念 •C语⾔中的函数:⼀个完成某项特定的任务的⼀⼩段代码 •函数又被翻译为子函数(更准确) •在C语⾔中我们⼀般会⻅到两类函数:库函数 ⾃定义函数 二、库函数 1 .标准库和头文件 •C语⾔的国际标准ANSIC规定了⼀…

孜然SEO静态页面生成系统V1.0

孜然SEO静态页面生成系统,1秒生成上万个不同的静态单页系统,支持URL裂变采集,采集的内容不会重复,因为程序系统自带AI重写算法,AI扩写算法,可视化的蜘蛛池系统让您更清楚的获取到信息! 可插入二…

Secs/Gem第一讲 · 总结精华版(基于secs4net项目的ChatGpt介绍)

好的!这就是《第一讲 总结精华版》——为背诵准备的口述速成稿,适合面试前复习答题用。我们会分为两个部分: 第一部分:一整段口述稿,可以当成面试时开口自我介绍用;第二部分:要点清单关键词串…

预处理指令中#if 和 #endif的用法

在 C 语言中,#if 和 #endif 是预处理指令,用于条件编译。它们的核心作用是:根据预处理器能够识别的条件(通常是宏定义或常量表达式),决定某段代码是否参与编译。 — 基本功能 #if 用于开启一个条件编译块…

【数据库】掌握MySQL事务与锁机制-数据一致性的关键

在数据库的世界里,数据就是一切。而确保数据的准确性和一致性,则是数据库系统的核心任务之一。想象一下,如果没有合适的机制,当多个用户同时试图修改同一条数据时,会发生什么? chaos(混乱&#…

linux 基础网络配置文件

使用“ifconfig”命令查看网络接口地址 直接执行“iconfg”命令后可以看到ens33、10、virbr0这3个网络接口的信息,具体命令如下 ifconfig ##查看网络接口地址 ens33:第一块以太网卡的名称 lo:“回环”网络接口 virbr0:虚拟网桥的连接接口 查看指…

OpenCV特征提取与深度学习CNN特征提取差异

一、特征生成方式 ‌OpenCV传统方法‌ ‌手工设计特征‌:依赖人工设计的算法(如SIFT、FAST、BRIEF)提取图像中的角点、边缘等低层次特征,需手动调整参数以适应不同场景‌。‌数学驱动‌:基于梯度变化、几何变换等数学规…

五种方案实现双链路可靠数据传输

本文介绍五种双链路数据传输方案,目标是利用设备的多个传输通道,(如双有线网口,网口+wifi, 网口+5G等场景 , 网口+ 自组网, 自组网 + 5G等),将数据复制后分流、分路同时传输,以期提高数据传输可靠性,满足高可靠性传输的应用场景需求。部分方案给出了实际验证结果 。 …

【备赛】遇到的小问题-1

问题描述-1 想实现的功能是,通过ADC实时测量某引脚的电压及其占空比。 可以通过旋转电位器,更改其电压。 首先我定义了这几个变量 uint32_t adc_value;//HAL库函数里面得出的采样值(实时更新) uint32_t percentage6;//占空比,随着adc_val…

最大公约数

4.最大公约数 - 蓝桥云课 最大公约数 题目描述 给定两个正整数 A,B,求它们的最大公约数。 输入描述 第1行为一个整数 T,表示测试数据数量。 接下来的 T 行每行包含两个正整数 A,B。 1≤T≤105,1≤A,B≤109。 输出描述 输出共 T 行&…

TMHMM2.0-蛋白跨膜螺旋预测工具-centos-安装+配置+排错

参考: A. Krogh, B. Larsson, G. von Heijne, and E. L. L. Sonnhammer. Predicting transmembrane protein topology with a hidden Markov model: Application to complete genomes. Journal of Molecular Biology, 305(3):567-580, January 2001. centos&#x…

docker run 命令常用参数

docker run 命令 用于从镜像创建并启动一个新的容器。 基本语法: docker run [OPTIONS] IMAGE [COMMAND] [ARG...]常用选项分类说明 容器配置 --name 为容器指定名称(默认随机生成)。 示例: docker run --name my_container …

Zbrush插件安装

安装目录在: ...\Zbrush2022\ZStartup\ZPlugs64

pandas中excel自定义单元格颜色

writerpd.ExcelWriter(filepathf05教师固定学生占比1月{today}.xlsx,engineopenpyxl) df.to_excel(writer,sheet_name明细) piv1.to_excel(writer,sheet_name1月分布) wswriter.book.create_sheet(口径) ws.cell(1,1).value综合占比: ws.cell(1,2).value固定学生占比…

整体二分算法讲解及例题

算法思想 整体二分,带有二分二字那么就一定和二分脱不了干系。 整体二分算法常用来解决询问区间的第 k k k小值的问题,思路如下: 我们二分的对象是这道题目给定的值域,及最小值与最大值之间的区间,在题目给定的数组中…

python+flask实现360全景图和stl等多种格式模型浏览

1. 安装依赖 pip install flask 2. 创建Flask应用 创建一个基本的Flask应用,并设置路由来处理不同的文件类型。 from flask import Flask, render_template, send_from_directory app Flask(__name__) # 设置静态文件路径 app.static_folder static app.r…