CNN 卷积神经网络处理图片任务 | PyTorch 深度学习实战

前一篇文章,学习率调整策略 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

CNN 卷积神经网络

  • CNN
    • 什么是卷积
    • 工作原理
      • 深度学习的卷积运算
      • 提取特征
      • 不同特征核的效果比较
      • 卷积核
      • 感受野
      • 共享权重
      • 池化
    • 示例源码
  • Links

CNN

什么是卷积

【通信原理 入坑之路】——深入、详细地理解通信里面“卷积”概念

卷积,首先是一种数学运算。两个多项式通过滑动,求解多项式参数。

在这里插入图片描述
深度学习的卷积概念,就是借鉴了通信领域使用了卷积。跨学科运用知识,一直是大牛们的惯用手段。掌握人类已经精通的领域的经验,然后推广到前沿领域。

工作原理

利用卷积操作实现平移、扭曲情况下,依然能识别特征

图片是一个二维数据,如果只是利用全连接网络,那么数据的二维特征就丢失了,原始的物理信息丢失了。比如,同一个人出现在不同的照片中,很可能是在不同的位置,作为同样的一张人脸,当其出现在图片中的不同位置1,都可以正确的识别和分类呢?

深度学习的卷积运算

深度学习领域的卷积,参考文章。

卷积核是一个小矩阵,在输入矩阵上,滑动。
在这里插入图片描述

最终得到一个新的 output 矩阵。
在这里插入图片描述

提取特征

因为这种运算,Output 实际上代表了卷积核 Kernel 作用于 Input 后过滤出来的特征。每一个卷积核,就是一个过滤器,从源图片中,提取特定的形状。为了理解这一点,看下面这张图。

在这里插入图片描述

以黑白两个颜色,实现卷积运算,最终输入图片里和特征核(Single filter)重叠的部分得到了加强,和特征核不一致的部分得到了抑制。

不同特征核的效果比较

当特征核变大,增加多个特征提取器,那么就可以识别一张图片上的特征组,从而判定图片中包含的物体的分类。

  • 左侧是运算符,中间是对应的特征核,右侧是输出的图片

在这里插入图片描述
在这里插入图片描述
当然,计算机不是【看图】,而是通过卷积后的矩阵,从数字上去检查分类。当输出的矩阵组成一个全连接,使用目标的标注数据,计算出损失,就可以学习分类的权重,实现分类的效果。

卷积核

卷积核,也称为特征提取器,后者的名字更加的形象,特征提取器类似于通信领域的滤波器。

感受野

感受野(Receptive Field)的定义是卷积神经网络每一层输出的特征图(feature map)上的像素点在输入图片上映射的区域大小。参考文章

在这里插入图片描述

共享权重

使用同一个特征核过滤图片,也就是一个特征核对于一个图片上的多个感受野,特征核的矩阵不变。

使用梯度下降原理更新参数时,参数包括了每个卷积核,虽然一个卷积核是滑动在多个感受野得到输出矩阵的,但是特征核更新时,不会针对单独的某个感受野。

对于一个卷积神经网络,都包括哪些参数,参考文章。

池化

经过多个卷积核以后,维度更多,虽然因为保留了重要的特征信息,但是会远远的大于分类信息,在加入最后的全连接层之前,还需要浓缩一下信息,类似于结晶。

这个操作就是池化,比如常用的最大池化,方法如下:

在这里插入图片描述

示例源码

下面以一段 PyTorch 代码为例,使用卷积神经网络完成图片分类任务。

'''
CNN Model
'''
import torch
import torchvision.datasets as ds
import torchvision.transforms as ts
from torch.utils.data import DataLoader
from torch.autograd import Variable
import randomtorch.manual_seed(777)# reproducibility# parameters
batch_size=100
learning_rate=0.001
epochs=2# MNIST dataset
ds_train=ds.MNIST(root='../../../DATA/MNIST_data',train=True,transform=ts.ToTensor(),download=True)
ds_test=ds.MNIST(root='../../../DATA/MNIST_data',train=False,transform=ts.ToTensor(),download=True)
# dataset loader
dl=DataLoader(dataset=ds_train,batch_size=batch_size,shuffle=True)# CNN Model (2 conv layers)
class CNN(torch.nn.Module):def __init__(self):super(CNN,self).__init__()# L1 ImgIn shape=(?, 28, 28, 1)#    Conv     -> (?, 28, 28, 32)#    Pool     -> (?, 14, 14, 32)self.layer1=torch.nn.Sequential(torch.nn.Conv2d(1,32,kernel_size=3,stride=1,padding=1),#padding=1进行0填充torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2))# L2 ImgIn shape=(?, 14, 14, 32)#    Conv      ->(?, 14, 14, 64)#    Pool      ->(?, 7, 7, 64)self.layer2=torch.nn.Sequential(torch.nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2,stride=2))# Final FC 7x7x64 inputs -> 10 outputsself.fc=torch.nn.Linear(7*7*64,10)torch.nn.init.xavier_uniform(self.fc.weight)def forward(self,x):out=self.layer1(x)out=self.layer2(out)out=out.view(out.size(0),-1)# Flatten them for FCout=self.fc(out)return out# instantiate CNN model
model=CNN()# define cost/loss & optimizer
criterion=torch.nn.CrossEntropyLoss()# Softmax is internally computed.
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)# train my model
print('Learning started. It takes sometime.')
for epoch in range(epochs):avg_cost=0total_batch=len(ds_train)//batch_sizefor step,(batch_xs,batch_ys) in enumerate(dl):x=Variable(batch_xs)#[100, 1, 28, 28] image is already size of (28x28), no reshapey=Variable(batch_ys)#[100] label is not one-hot encodedoptimizer.zero_grad()h=model(x)cost=criterion(h,y)cost.backward()optimizer.step()avg_cost+=cost/total_batchprint(epoch+1,avg_cost.item())
print('Learning Finished!')# Test model and check accuracy
model.eval()#!!将模型设置为评估/测试模式 set the model to evaluation mode (dropout=False)# x_test=ds_test.test_data.view(len(ds_test),1,28,28).float()
x_test=ds_test.test_data.view(-1,1,28,28).float()
y_test=ds_test.test_labelspre=model(x_test)print("pre.data=")
print(pre.data)
print("*"*3)pre=torch.max(pre.data,1)[1].float()
acc=(pre==y_test.data.float()).float().mean()
print("acc", acc)r=random.randint(0,len(x_test)-1)
x_r=x_test[r:r+1]
y_r=y_test[r:r+1]
pre_r=model(x_r)# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
# https://discuss.pytorch.org/t/indexerror-dimension-out-of-range-expected-to-be-in-range-of-1-0-but-got-1/54267/12
print("pre_r.data=")
print(pre_r.data)
print("*"*3)pre_r=torch.max(pre_r.data,-1)[1].float()
print('pre_r')
print(pre_r)acc_r=(pre_r==y_r.data).float().mean()
print(acc_r)

Links

  • 卷积神经网络中感受野的详细介绍
  • 感受野详解
  • 【通信原理 入坑之路】——深入、详细地理解通信里面“卷积”概念
  • How to calculate the number of parameters in CNN?
  • 【深度学习】人人都能看得懂的卷积神经网络——入门篇

  1. 图片相关任务,包括图片分类、物体检测、实例分割、目标跟踪等。这些任务有不同的功能,但是都依赖于图片中包含的特征,这些特征都是可能平移、变幻、扭曲的。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/69382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.1 学习UVM中的uvm_component类分为几步?

文章目录 前言一、定义1.1 角色和功能:1.2 与其他UVM类的区别:1.3 主要属性和方法: 二、使用方法2.1 定义和实例化:2.2 生命周期管理:2.3 组件间通信: 三、何时使用3.1 使用场景3.2 适用组件3.3 与uvm_obje…

谷云科技RestCloud全面接入DeepSeek 开启智能新时代

在数字化转型的浪潮中,谷云科技始终走在数据集成与智能应用领域的前沿。近期,随着 DeepSeek 的火爆出圈,谷云科技紧跟技术趋势,对旗下两大核心产品 —— 数据集成软件 ETLCloud 和 AI Agent 智能体构建平台进行了重大升级&#xf…

Kafka 入门与实战

一、Kafka 基础 1.1 创建topic kafka-topics.bat --bootstrap-server localhost:9092 --topic test --create 1.2 查看消费者偏移量位置 kafka-consumer-groups.bat --bootstrap-server localhost:9092 --describe --group test 1.3 消息的生产与发送 #生产者 kafka-cons…

FFmpeg 与 FFplay 参数详解:-f、-pix_fmt、-pixel_format 和 -video_size 的区别与用法

FFmpeg 与 FFplay 参数详解:-f、-pix_fmt、-pixel_format 和 -video_size 的区别与用法 在使用 FFmpeg 和 FFplay 进行视频处理和播放时,-f、-pix_fmt、-pixel_format 和 -video_size 是常用的参数。这些参数的作用和使用场景略有不同,理解它们的区别和用法对于正确处理和播…

即时通讯开源项目OpenIM配置离线推送全攻略

如何进行二次开发 如果您需要基于 OpenIM 开发新特性,首先要确定是针对业务侧还是即时通讯核心逻辑。 由于 OpenIM 系统本身已经做好了比较多的抽象,大部分聊天的功能已经具备了,不建议修改 IM 本身。 如果需要增加 IM 的能力,可以…

深度解析:网站快速收录与网站内容更新频率的关系

本文转自:百万收录网 原文链接:https://www.baiwanshoulu.com/97.html 网站快速收录与网站内容更新频率之间存在着密切的关系。以下是对这一关系的深度解析: 一、内容更新频率对网站快速收录的影响 提高收录速度 定期发布新内容会促使搜索…

【个人开发】macbook m1 Lora微调qwen大模型

本项目参考网上各类教程整理而成,为个人学习记录。 项目github源码地址:Lora微调大模型 项目中微调模型为:qwen/Qwen1.5-4B-Chat。 去年新发布的Qwen/Qwen2.5-3B-Instruct同样也适用。 微调步骤 step0: 环境准备 conda create --name fin…

c++计算机教程

目的 做出-*/%计算机 要求 做出可以计算-*/%的计算机 实现 完整代码 #include<bits/stdc.h> int main() {std::cout<<"加 减- 乘* 除/ 取余% \没有了|(因为可以算三位)"<<"\n"<<"提示:每打完一个符号或打完一个数,\…

了解Linux 中 make 与 Makefile

目录 一、为什么开发者需要构建工具&#xff1f; 二、make/Makefile 1. Makefile基本规则 2.清理项目 三、make的工作原理 一、为什么开发者需要构建工具&#xff1f; 在软件开发中&#xff0c;我们经常面临这样的场景&#xff1a;一个项目包含数十个源代码文件&#xff…

RK3568中,使用cmake搭建C++工程进行RGA开发

在 RK3568 平台上使用 C 配合 RGA (Raster Graphics Acceleration) 进行图像加速开发&#xff0c;以下是详细的配置步骤和示例&#xff1a; 1. 环境准备 安装 RK3568 SDK 确保已安装 Rockchip 官方提供的 SDK&#xff08;如 Linux SDK&#xff09;&#xff0c;RGA 头文件和库通…

win11右击显示全部

正常&#xff1a; 输入&#xff1a; reg.exe add "HKCU\Software\Classes\CLSID\{86ca1aa0-34aa-4e8b-a509-50c905bae2a2}\InprocServer32" /f /ve 重启或刷新进程 刷新&#xff1a; taskkill /f /im explorer.exe & start explorer.exe 成功&#xff1a;

Redis基础--常用数据结构的命令及底层编码

零.前置知识 关于时间复杂度,按照以下视角看待. redis整体key的个数 -- O(N)当前key对应的value中的元素个数 -- O(N)当前命令行中key的个数 -- O(1) 一.string 1.1string类型常用命令 1.2string类型内部编码 二.Hash 哈希 2.1hash类型常用命令 2.2hash类型内部编码 2.3ha…

React 设计模式:实用指南

React 提供了众多出色的特性以及丰富的设计模式&#xff0c;用于简化开发流程。开发者能够借助 React 组件设计模式&#xff0c;降低开发时间以及编码的工作量。此外&#xff0c;这些模式让 React 开发者能够构建出成果更显著、性能更优越的各类应用程序。 本文将会为您介绍五…

SpringBoo项目标准测试样例

文章目录 概要Controller Api 测试源码单元测试集成测试 概要 Spring Boot项目测试用例 测试方式是否调用数据库使用的注解特点单元测试&#xff08;Mock Service&#xff09;❌ 不调用数据库WebMvcTest MockBean只测试 Controller 逻辑&#xff0c;速度快集成测试&#xff0…

Unity扩展编辑器使用整理(一)

准备工作 在Unity工程中新建Editor文件夹存放编辑器脚本&#xff0c; Unity中其他的特殊文件夹可以参考官方文档链接&#xff0c;如下&#xff1a; Unity - 手册&#xff1a;保留文件夹名称参考 (unity3d.com) 一、菜单栏扩展 1.增加顶部菜单栏选项 使用MenuItem&#xff…

Vue3+codemirror6实现公式(规则)编辑器

实现截图 实现/带实现功能 插入标签 插入公式 提示补全 公式验证 公式计算 需要的依赖 "codemirror/autocomplete": "^6.18.4","codemirror/lang-javascript": "^6.2.2","codemirror/state": "^6.5.2","cod…

K8S QoS等级

在 Kubernetes (K8S) 中&#xff0c;QoS&#xff08;Quality of Service&#xff0c;服务质量&#xff09;等级用于定义 Pod 在资源调度和管理过程中的优先级&#xff0c;确保在资源紧张时能够更好地管理和分配资源。Kubernetes 根据 Pod 的资源请求和限制将 Pod 分为三种 QoS …

4.PPT:日月潭景点介绍【18】

目录 NO1、2、3、4​ NO5、6、7、8 ​ ​NO9、10、11、12 ​ 表居中或者水平/垂直居中单元格内容居中或者水平/垂直居中 NO1、2、3、4 新建一个空白演示文稿&#xff0c;命名为“PPT.pptx”&#xff08;“.pptx”为扩展名&#xff09;新建幻灯片 开始→版式“PPT_素材.doc…

如何在macOS上安装Ollama

安装Ollama 安装Ollama的步骤相对简单&#xff0c;以下是基本的安装指南&#xff1a; 访问官方网站&#xff1a;打开浏览器&#xff0c;访问Ollama的官方网站。 下载安装包&#xff1a;根据你的操作系统&#xff0c;选择相应的安装包进行下载。 运行安装程序&#xff1a;下载完…

开源项目介绍-词云生成

开源词云项目是一个利用开源技术生成和展示词云的工具或框架&#xff0c;广泛应用于文本分析、数据可视化等领域。以下是几个与开源词云相关的项目及其特点&#xff1a; Stylecloud Stylecloud 是一个由 Maximilianinir 创建和维护的开源项目&#xff0c;旨在通过扩展 wordclou…