从代码学习深度学习 - 使用块的网络(VGG)PyTorch版

文章目录

  • 前言
  • 一、VGG网络简介
    • 1.1 VGG的核心特点
    • 1.2 VGG的典型结构
    • 1.3 优点与局限性
    • 1.4 本文的实现目标
  • 二、搭建VGG网络
    • 2.1 数据准备
    • 2.2 定义VGG块
    • 2.3 构建VGG网络
    • 2.4 辅助工具
      • 2.4.1 计时器和累加器
      • 2.4.2 准确率计算
      • 2.4.3 可视化工具
    • 2.5 训练模型
    • 2.6 运行实验
  • 总结


前言

深度学习是近年来人工智能领域的重要突破,而卷积神经网络(CNN)作为其核心技术之一,在图像分类、目标检测等领域展现了强大的能力。VGG(Visual Geometry Group)网络是CNN中的经典模型之一,以其模块化的“块”设计和深层结构而闻名。本篇博客将通过PyTorch实现一个简化的VGG网络,并结合代码逐步解析其构建、训练和可视化过程,帮助读者从代码层面理解深度学习的基本原理和实践方法。我们将使用Fashion-MNIST数据集进行实验,展示如何从零开始搭建并训练一个VGG模型。

本文的目标读者是对深度学习有基本了解、希望通过代码实践加深理解的初学者或中级开发者。以下是博客的完整内容,包括代码实现和详细说明。


一、VGG网络简介

VGG网络(Visual Geometry Group Network)是由牛津大学视觉几何组在2014年提出的深度卷积神经网络(CNN)模型,因其在ImageNet图像分类竞赛中的优异表现而广为人知。VGG的设计理念是通过堆叠多个小卷积核(通常为3×3)和池化层,构建一个深层网络,从而提取图像中的复杂特征。与之前的模型(如AlexNet)相比,VGG显著增加了网络深度(常见版本包括VGG-16和VGG-19,分别有16层和19层),并采用统一的模块化结构,使其易于理解和实现。

1.1 VGG的核心特点

  1. 小卷积核:VGG使用3×3的小卷积核替代传统的大卷积核(如5×5或7×7)。两个3×3卷积核的堆叠可以达到5×5的感受野,而参数量更少,计算效率更高,同时增加了非线性(通过更多ReLU激活)。
  2. 模块化设计:网络由多个“块”(block)组成,每个块包含若干卷积层和一个最大池化层。这种设计使得网络结构清晰,便于扩展或调整。
  3. 深度增加:VGG通过加深网络层数(从11层到19层不等)提升性能,证明了深度对特征提取的重要性。
  4. 全连接层:在卷积层之后,VGG使用多个全连接层(通常为4096、4096和1000神经元)进行分类,输出对应ImageNet的1000个类别。

1.2 VGG的典型结构

以下是VGG-16的结构示意图,展示了其卷积块和全连接层的组织方式:

在这里插入图片描述

上图中:

  • 绿色方框表示卷积层(3×3卷积核,步幅1,padding=1),对应图中的“convolution+ReLU”部分(以立方体表示)。这些卷积层负责提取图像特征,padding=1确保特征图尺寸在卷积后保持不变。
  • 红色方框表示最大池化层(2×2,步幅2),对应图中的“max pooling”部分(以红色立方体表示)。池化层将特征图尺寸减半(例如从224×224到112×112),同时保留重要特征。
  • 蓝色部分为全连接层,最终输出分类结果,对应图中的“fully connected+ReLU”和“softmax”部分(以蓝色线条表示)。全连接层将卷积特征展平后进行分类,输出对应ImageNet的1000个类别。

VGG-16包含13个卷积层和3个全连接层,总计16层(池化层不计入层数)。每个卷积块的通道数逐渐增加(从64到512),而池化层将特征图尺寸逐步减半(从224×224到7×7)。

1.3 优点与局限性

优点

  • 结构简单,易于实现和理解。
  • 小卷积核和深层设计提高了特征提取能力。
  • 在多种视觉任务中表现出色,可作为预训练模型迁移学习。

局限性

  • 参数量巨大(VGG-16约有1.38亿个参数),训练和推理耗时。
  • 深层网络可能导致梯度消失问题(尽管ReLU和适当初始化缓解了部分问题)。
  • 对内存和计算资源要求较高,不适合资源受限的设备。

1.4 本文的实现目标

在本文中,我们将基于PyTorch实现一个简化的VGG网络,针对Fashion-MNIST数据集(28×28灰度图像,10个类别)进行调整。我们保留VGG的模块化思想,但适当减少层数和参数量,以适应较小规模的数据和计算资源。通过代码实践,读者可以深入理解VGG的设计原理及其在实际任务中的应用。

下一节将进入具体的代码实现部分,逐步搭建VGG网络并完成训练。

二、搭建VGG网络

2.1 数据准备

在开始构建VGG网络之前,我们需要准备训练和测试数据。这里使用Fashion-MNIST数据集,这是一个包含10类服装图像的灰度图像数据集,每个图像大小为28×28像素。以下是数据加载的代码:

import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessingdef get_dataloader_workers():"""使用电脑支持的最大进程数来读取数据"""return multiprocessing.cpu_count()def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中。参数:batch_size (int): 每个数据批次的大小。resize (int, 可选): 图像的目标尺寸。如果为 None,则不调整大小。返回:tuple: 包含训练 DataLoader 和测试 DataLoader 的元组。"""# 定义变换管道trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)# 加载 Fashion-MNIST 训练和测试数据集mnist_train = torchvision.datasets.FashionMNIST(root="./data",train=True,transform=trans,download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data",train=False,transform=trans,download=True)# 返回 DataLoader 对象return (data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))

这段代码定义了load_data_fashion_mnist函数,用于加载Fashion-MNIST数据集并将其封装成PyTorch的DataLoader对象。transforms.ToTensor()将图像转换为张量格式,batch_size控制每个批次的数据量,shuffle=True确保训练数据随机打乱以提高模型泛化能力。num_workers通过多进程加速数据加载。

2.2 定义VGG块

VGG网络的核心思想是将网络分解为多个“块”(block),每个块包含若干卷积层和一个池化层。以下是VGG块的实现:

import torch
from torch import nndef vgg_block(num_convs, in_channels, out_channels):layers = []                          # 初始化一个空列表,用于存储网络层for _ in range(num_convs):           # 循环 num_convs 次,构建卷积层layers.append(nn.Conv2d(         # 添加一个二维卷积层in_channels,                 # 输入通道数out_channels,                # 输出通道数kernel_size=3,               # 卷积核大小为 3x3padding=1))                  # 填充大小为 1,保持特征图尺寸layers.append(nn.ReLU())         # 添加 ReLU 激活函数in_channels = out_channels       # 更新输入通道数为输出通道数,用于下一次卷积layers.append(nn.MaxPool2d(          # 添加一个最大池化层kernel_size=2,                   # 池化核大小为 2x2stride=2))                       # 步幅为 2,缩小特征图尺寸return nn.Sequential(*layers)        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/75064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Baklib激活企业知识管理新动能

Baklib核心技术架构解析 Baklib的底层架构以模块化设计为核心,融合知识中台的核心理念,通过分布式存储引擎与智能语义分析系统构建三层技术体系。数据层采用多源异构数据接入协议,支持文档、音视频、代码片段等非结构化数据的实时解析与分类…

小智机器人中的部分关键函数,FreeRTOS中`xEventGroupWaitBits`函数的详细解析

以下是对FreeRTOS中xEventGroupWaitBits函数的详细解析: 函数功能 xEventGroupWaitBits用于在事件组中等待指定的位被设置。它可以配置为等待任意一个位或所有位,并支持超时机制。 注意:该函数不能在中断中调用。 函数原型 EventBits_t xEv…

关注分离(Separation of Concerns)在前端开发中的实践演进:从 XMLHttpRequest 到 Fetch API

关注分离(Separation of Concerns)在前端开发中的实践演进:从 XMLHttpRequest 到 Fetch API 一、关注分离的核心价值 关注分离(SoC)是软件工程领域的重要设计原则,强调将系统分解为不同维度的功能模块&am…

C之(16)scan-build与clang-tidy使用

C之(16)scan-build与clang-tidy使用 Author: Once Day Date: 2025年3月29日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章可参考专栏: Linux实践记录_Once_da…

在 Vue 项目中快速集成 Vant 组件库

目录 引言一、找到 src 下的App.js 写入代码。二、安装Vant三、解决 polyfill 问题四、查看依赖五、配置webpack六、引入 Vant七、在组件中使用 Vant八、在浏览器中查看样式总结 引言 在开发移动端 Vue 项目时,选择一个高效、轻量且功能丰富的组件库是提升开发效率…

“GPU 挤不动了?”——聊聊基于 GPU 的计算资源管理

“GPU 挤不动了?”——聊聊基于 GPU 的计算资源管理 作者:Echo_Wish “老板:为什么 GPU 服务器卡得跟 PPT 一样?” “运维:我们任务队列爆炸了,得优化资源管理!” 在 AI 训练、深度学习、科学计算的场景下,GPU 计算资源已经成为香饽饽。但 GPU 服务器贵得离谱,一台 A…

AI渗透测试:网络安全的“黑魔法”还是“白魔法”?

引言:AI渗透测试,安全圈的“新魔法师” 想象一下,你是个网络安全新手,手里攥着一堆工具,正准备硬着头皮上阵。这时,AI蹦出来,拍着胸脯说:“别慌,我3秒扫完漏洞&#xff0…

(二)GEE基础学习初探及案例详解【20250330】

Google Earth Engine(GEE)是由谷歌公司开发的众多应用之一。借助谷歌公司超强的服务器运算能力以及与NASA的合作关系,GEE平台将Landsat、MODIS、Sentinel等可以公开获取的遥感图像数据存储在谷歌的磁盘阵列中,使得GEE用户可以方便的提取、调用和分析海量…

redhat认证是永久的吗

​认证有效期 ​红帽认证一般有效期为3年​(如RHCSA、RHCE、RHCA等),从通过考试之日起计算。 ​例外:部分基础或工程师认证(如Red Hat Certified Engineer)有效期为三年时间,以官方最新政策为准…

git --- cherry pick

git --- cherry pick cherry pick cherry pick Cherry Pick 是 Git 中的一个操作,它允许你选择某个分支的某次(或多次)提交,并将其应用到当前分支,而不会合并整个分支的所有更改。 cherry pick 的作用 只提取某个特定的…

妙用《甄嬛传》中的选妃来记忆概率论中的乘法公式

强烈推荐最近在看的不错的B站概率论课程 《概率统计》正课,零废话,超精讲!【孔祥仁】 《概率统计》正课,零废话,超精讲!【孔祥仁】_哔哩哔哩_bilibili 其中概率论中的乘法公式,老师用了《甄嬛传…

AI 的出现是否能替代 IT 从业者?

AI 的出现是否能替代 IT 从业者? AI 的快速发展正在深刻改变各行各业,IT 行业也不例外。然而,AI 并非完全替代 IT 从业者,而是与其形成互补关系。本文将从 AI 的优势、IT 从业者的不可替代性、未来趋势等方面,探讨 AI…

【leetcode100】有效的括号

1、题目描述 给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的…

为什么使用Flask + uWSGI + Nginx 部署服务?

概述 在Python开发的web应用中,我们通常能够看到flask、uWSGI、Nginx出现在一起,他们之间的关系是什么?为什么总是被应用在一起?  三者共同使用为了实现一个目的:客户端向服务端发送数据请求,服…

接口等幂处理

介绍 ✅ 什么是等幂(Idempotency)? 等幂 无论这个操作被执行多少次,结果都是一样的,不会因为多次执行而产生副作用。 通俗一点说:“点一次和点一百次,效果是一样的。” ✅ 在接口中&#xff0…

P1090合并果子(优先队列)

洛谷题目 这里使用的是优先队列&#xff0c;非常简单 首先让我们一起来学习一下优先队列&#xff08;默认是从大到小来排列&#xff09; 首先要使用头文件 #include<queue> using namespace std; 然后声明有限队列 priority_queue<int> a; priority_queue&…

蓝桥杯备考---->并查集之 Lake Counting

这道题就统计有多少个连通块就行了 这时候我们又需要把二维转成一维了&#xff0c;也就是把每一个格子都给一个编号 当我们合并连通块的时候&#xff0c;其实是只需要四个方向的 因为我们是从上往下遍历的&#xff0c;我们遍历到某个位置的时候&#xff0c;它已经和上面部分…

React受控表单绑定

受控表单绑定 在 React 中&#xff0c;受控组件&#xff08;Controlled Component&#xff09;是指表单元素的值由 React 组件的 state 管理&#xff0c;React 通过 onChange 事件监听输入变化&#xff0c;并实时更新 state&#xff0c;从而控制表单输入值。 为什么要使用受控…

8、linux c 信号机制

一、信号概述 1. 信号概念 信号是一种在软件层次上对中断机制的模拟&#xff0c;是一种异步通信方式。信号的产生和处理都由操作系统内核完成&#xff0c;用于在进程之间传递信息或通知某些事件的发生。 2. 信号的产生 信号可以通过以下方式产生&#xff1a; 按键产生&…

CSP-J 2019 入门级 第一轮(初赛) 完善程序(2)

【题目】 CSP-J 2019 入门级 第一轮&#xff08;初赛&#xff09; 完善程序&#xff08;2&#xff09; &#xff08;计数排序&#xff09;计数排序是一个广泛使用的排序方法。下面的程序使用双关键字计数排序&#xff0c;将n对10000 以内的整数&#xff0c;从小到大排序。 例如…