使用 PyTorch 实现标准卷积神经网络(CNN)

卷积神经网络(CNN)是深度学习中的重要组成部分,广泛应用于图像处理、语音识别、视频分析等任务。在这篇博客中,我们将使用 PyTorch 实现一个标准的卷积神经网络(CNN),并介绍各个部分的作用。

什么是卷积神经网络(CNN)?

卷积神经网络(CNN)是一种专门用于处理图像数据的深度学习模型,它通过卷积层提取图像的特征。CNN 由多个层次组成,其中包括卷积层(Conv2d)、池化层(MaxPool2d)、全连接层(Linear)、激活函数(ReLU)等。这些层级合作,使得模型能够从原始图像中自动学习到重要特征。

CNN 的核心组成部分

  1. 卷积层(Conv2d):用于提取输入图像的局部特征,通过多个卷积核对图像进行卷积运算。
  2. 激活函数(ReLU):增加非线性,使得模型能够学习更复杂的特征。
  3. 池化层(MaxPool2d):通过对特征图进行下采样来减少空间尺寸,降低计算复杂度,同时保留重要的特征。
  4. 全连接层(Linear):将卷积和池化后得到的特征图展平,送入全连接层进行分类或回归预测。

PyTorch 实现 CNN

下面是我们实现的标准卷积神经网络模型。它包含三个卷积层和两个全连接层,适用于图像分类任务,如 MNIST 数据集。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入1个通道(灰度图像),输出32个通道self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)# 卷积层2: 输入32个通道,输出64个通道self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)# 卷积层3: 输入64个通道,输出128个通道self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)# 全连接层1: 输入128*7*7,输出1024个节点self.fc1 = nn.Linear(128 * 7 * 7, 1024)# 全连接层2: 输入1024个节点,输出10个节点(假设是10分类问题)self.fc2 = nn.Linear(1024, 10)# Dropout层: 避免过拟合self.dropout = nn.Dropout(0.5)def forward(self, x):# 第一层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)  # 使用2x2的最大池化# 第二层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)# 第三层卷积 + ReLU 激活 + 最大池化x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)# 展平层(将卷积后的特征图展平成1D向量)x = x.view(-1, 128 * 7 * 7)  # -1代表自动推算batch size# 第一个全连接层 + ReLU 激活 + Dropoutx = F.relu(self.fc1(x))x = self.dropout(x)# 第二个全连接层(输出最终分类结果)x = self.fc2(x)return x# 创建CNN模型
model = CNN()# 打印模型架构
print(model)

代码解析

  1. 卷积层(Conv2d)

    • self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1):该层的输入为 1 个通道(灰度图像),输出 32 个通道,卷积核大小为 3x3,步幅为 1,填充为 1,保持输出特征图的大小与输入相同。
    • 后续的卷积层类似,只是输出通道数量逐渐增多。
  2. 激活函数(ReLU)

    • F.relu(self.conv1(x)):ReLU 激活函数将输入的负值转为 0,并保留正值,增加了模型的非线性。
  3. 池化层(MaxPool2d)

    • F.max_pool2d(x, 2, 2):使用 2x2 的池化窗口和步幅为 2 进行池化,将特征图尺寸缩小一半,减少计算复杂度。
  4. 展平(Flatten)

    • x = x.view(-1, 128 * 7 * 7):在经过卷积和池化操作后,我们将多维的特征图展平成一维向量,供全连接层输入。
  5. Dropout

    • self.dropout = nn.Dropout(0.5):Dropout 正则化技术在训练时随机丢弃一些神经元,防止过拟合。
  6. 全连接层(Linear)

    • self.fc1 = nn.Linear(128 * 7 * 7, 1024):第一个全连接层的输入是卷积后得到的特征,输出 1024 个节点。
    • self.fc2 = nn.Linear(1024, 10):最后的全连接层将 1024 个节点压缩为 10 个输出,代表分类结果。

训练 CNN 模型

要训练该模型,我们需要加载一个数据集、定义损失函数和优化器,然后进行训练。以下是如何使用 MNIST 数据集进行训练的示例。

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 100 == 99:  # 每100个batch输出一次损失print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {running_loss / 100:.4f}')running_loss = 0.0print("Finished Training")

训练过程说明

  • 数据加载器(DataLoader):用于批量加载训练数据,支持数据的随机打乱(shuffle)。
  • 损失函数(CrossEntropyLoss):用于多分类问题,计算预测和真实标签之间的交叉熵损失。
  • 优化器(Adam):Adam 优化器自适应调整学习率,通常在深度学习中表现良好。
  • 训练循环:每个 epoch 处理整个数据集,通过前向传播、计算损失、反向传播和优化步骤,更新网络参数。

总结

在这篇文章中,我们实现了一个标准的卷积神经网络(CNN),并使用 PyTorch 对其进行了定义和训练。通过使用卷积层、池化层和全连接层,模型能够自动学习图像的特征并进行分类。我们还介绍了如何训练模型、加载数据集以及使用常见的优化器和损失函数。希望这篇文章能帮助你理解 CNN 的基本架构及其实现方式!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/70209.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot2.0整合Redis(Lettuce版本)

前言: 目前java操作redis的客户端有jedis跟Lettuce。在springboot1.x系列中,其中使用的是jedis, 但是到了springboot2.x其中使用的是Lettuce。 因为我们的版本是springboot2.x系列,所以今天使用的是Lettuce。关于jedis跟lettuce的区别&#…

qt + opengl 给立方体增加阴影

在前几篇文章里面学会了通过opengl实现一个立方体,那么这篇我们来学习光照。 风氏光照模型的主要结构由3个分量组成:环境(Ambient)、漫反射(Diffuse)和镜面(Specular)光照。下面这张图展示了这些光照分量看起来的样子: 1 环境光照(Ambient …

大模型工具大比拼:SGLang、Ollama、VLLM、LLaMA.cpp 如何选择?

简介:在人工智能飞速发展的今天,大模型已经成为推动技术革新的核心力量。无论是智能客服、内容创作,还是科研辅助、代码生成,大模型的身影无处不在。然而,面对市场上琳琅满目的工具,如何挑选最适合自己的那…

stream流常用方法

1.reduce 在Java中,可以使用Stream API的reduce方法来计算一个整数列表的乘积。reduce方法是一种累积操作,它可以将流中的元素组合起来,返回单个结果。对于计算乘积,你需要提供一个初始值(通常是1,因为乘法…

pgAdmin4在mac m1上面简单使用(Docker)

问题 想要在本地简单了解一下pgAdmin4一些简单功能。故需要在本机先安装看一看。 安装步骤 拉取docker镜像 docker pull dpage/pgadmin4直接简单运行pgAdmin4 docker run --name pgAdmin4 -p 5050:80 \-e "PGADMIN_DEFAULT_EMAILuserdomain.com" \-e "PGAD…

ubuntu下安装TFTP服务器

在 Ubuntu 系统下安装和配置 TFTP(Trivial File Transfer Protocol)服务器可以按照以下步骤进行: 1. 安装 TFTP 服务器软件包 TFTP 服务器通常使用 tftpd-hpa 软件包,你可以使用以下命令进行安装: sudo apt update …

Softing线上研讨会 | 自研还是购买——用于自动化产品的工业以太网

| 线上研讨会时间:2025年1月27日 16:00~16:30 / 23:00~23:30 基于以太网的通信在工业自动化网络中的重要性日益增加。设备制造商正面临着一大挑战——如何快速、有效且经济地将工业以太网协议集成到其产品中。其中的关键问题包括:是否只需集成单一的工…

vscode创建java web项目

一.项目部署 1.shiftctrlp,选择java项目 2.选择maven create from arcetype 3.选择webapp 4.目录结构如下,其中index.jsp是首页 5.找到左下角的servers,添加tomcat服务器 选择 再选择: 找到你下载的tomcat 的bin目录的上一级目录&#x…

C语言指针学习笔记

1. 指针的定义 指针(Pointer)是存储变量地址的变量。在C语言中,指针是一种非常重要的数据类型,通过指针可以直接访问和操作内存。 2. 指针的声明与初始化 2.1 指针声明 指针变量的声明格式为:数据类型 *指针变量名…

DeepSeek R1生成图片总结2(虽然本身是不能直接生成图片,但是可以想办法利用别的工具一起实现)

DeepSeek官网 目前阶段,DeepSeek R1是不能直接生成图片的,但可以通过优化文本后转换为SVG或HTML代码,再保存为图片。另外,Janus-Pro是DeepSeek的多模态模型,支持文生图,但需要本地部署或者使用第三方工具。…

什么是Dubbo?Dubbo框架知识点,面试题总结

本篇包含什么是Dubbo,Dubbo的实现原理,节点角色说明,调用关系说明,在实际开发的场景中应该如何选择RPC框架,Dubbo的核心架构,Dubbo的整体架构设计及分层。 主页还有其他的面试资料,有需要的可以…

kafka消费能力压测:使用官方工具

背景 在之前的业务场景中,我们发现Kafka的实际消费能力远低于预期。尽管我们使用了kafka-go组件并进行了相关测试,测试情况见《kafka-go:性能测试》这篇文章。但并未能准确找出消费能力低下的原因。 我们曾怀疑这可能是由我的电脑网络带宽问题或Kafka部…

【大学生职业规划大赛备赛PPT资料PDF | 免费共享】

自取链接: 链接:https://pan.quark.cn/s/4fa45515325e 📢 同学,你是不是正在为职业规划大赛发愁? 想展示独特思路却不知如何下手? 想用专业模板却找不到资源? 别担心!我整理了全网…

ubuntu20动态修改ip,springboot中yaml的内容的读取,修改,写入

文章目录 前言引入包yaml原始内容操作目标具体代码执行查看结果总结: 前言 之前有个需求,动态修改ubuntu20的ip,看了下: 本质上是修改01-netcfg.yaml文件,然后执行netplan apply就可以了。 所以,需求就变成了 如何对ya…

【算法】双指针(下)

目录 查找总价格为目标值的两个商品 暴力解题 双指针解题 三数之和 双指针解题(左右指针) 四数之和 双指针解题 双指针关键点 注意事项 查找总价格为目标值的两个商品 题目链接:LCR 179. 查找总价格为目标值的两个商品 - 力扣(LeetCode&#x…

Windows 图形显示驱动开发-IoMmu 模型

输入输出内存管理单元 (IOMMU) 是一个硬件组件,它将支持具有 DMA 功能的 I/O 总线连接到系统内存。 它将设备可见的虚拟地址映射到物理地址,使其在虚拟化中很有用。 在 WDDM 2.0 IoMmu 模型中,每个进程都有一个虚拟地址空间,即&a…

软件测评报告包括哪些内容?第三方软件测评机构推荐

在当今信息技术飞速发展的时代,软件的品质与性能直接影响到企业的运营效率和市场竞争力。为了确保软件的可用性和可靠性,软件测评成为一个不可或缺的环节,软件测评报告也是对软件产品进行全面评估后形成的一份文档,旨在系统地纪录…

深浅拷贝区别,怎么区别使用

在 JavaScript 中,深拷贝(Deep Copy) 和 浅拷贝(Shallow Copy) 是两种不同的对象复制方式,它们的区别主要体现在对嵌套对象的处理上。以下是它们的详细对比及使用场景: 1. 浅拷贝(Sh…

tailscale + derp中继 + 阿里云服务器 (无域名版)

使用tailscale默认的中转节点延迟很高,因为服务器都在国外。 感谢大佬提供的方案:Tailscale 搭建derp中继节点,不需要域名,不需要备案,不需要申请证书(最新) - yafeng - 博客园 基于这个方案&…

【异常错误】pycharm debug view变量的时候显示不全,中间会以...显示

异常问题: 这个是在新版的pycharm中出现的,出现的问题,点击view后不全部显示,而是以...折叠显示 在setting中这么设置一下就好了: 解决办法: https://youtrack.jetbrains.com/issue/PY-75568/Large-stri…