NiNNet

目录

一、网络介绍

1、全连接层存在的问题

2、NiN的解决方案(NiN块)

3、NiN架构

4、总结

二、代码实现

1、定义NiN卷积块

2、NiN模型

3、训练模型


一、网络介绍

       NiN(Network in Network)是一种用于图像识别任务的卷积神经网络模型。它由谷歌研究员Min Lin、Qiang Chen和Shouyuan Chen于2013年提出。NiN的设计理念是通过引入“网络中的网络”结构来增强模型的表示能力。

1、全连接层存在的问题

       在之前的网络(比如AlexNet和VGGNet)后面都用了几个比较大的全连接层,全连接层中的参数相比于卷积层多得多,一个网络的参数大多都在全连接层,并且可以认为主要分布在卷积层之后的第一个全连接层。因此全连接层最大的问题是可能造成过拟合。

2、NiN的解决方案(NiN块)

       NiN的核心思想是使用1x1卷积层替代传统的全连接层。传统的卷积神经网络通常使用卷积层提取特征,然后通过全连接层进行分类。而NiN则在卷积层中引入了一种称为“1x1卷积”的操作,这个操作可以看作是在每个像素点上进行的全连接操作。通过使用1x1卷积,NiN能够在卷积层中引入非线性,增加模型的表达能力,并且减少了参数的数量。

       和VGG一样,NiN也有自己的块(NiN块),每一个NiN块其实就相当于一个小的神经网络(因为它具有卷积层和类似于全连接层的 $1 \times 1$ 卷积层),因此叫网络中的网络。NiN块首先有一个卷积层,然后后跟两个 $1 \times 1$ 的卷积层($1 \times 1$ 的卷积层等价于全连接层)。

3、NiN架构

全局池化层:池化层的高和宽等于输入的高和宽,一个通道得出一个值,用这个值当作对类别的预测。

4、总结

二、代码实现

       NiN的想法是将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。下图说明了VGG和NiN及它们的块之间主要架构差异。NiN块以一个普通卷积层开始,后面是两个 $1 \times 1$ 的卷积层。NiN块第一层的卷积窗口形状通常由用户设置。随后的卷积窗口形状固定为 $1 \times 1$

1、定义NiN卷积块

import torch
from torch import nn
from d2l import torch as d2ldef nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

2、NiN模型

       最初的NiN网络是在AlexNet后不久提出的,显然从中得到了一些启示。NiN使用窗口形状为$11\times 11$$5\times 5$ 和 $3\times 3$ 的卷积层,输出通道数量与AlexNet中的相同。每个NiN块后有一个最大池化层,池化窗口形状为 $3\times 3$,步幅为2。

       NiN和AlexNet之间的一个显著区别是NiN完全取消了全连接层。相反,NiN使用一个个NiN块,最后一个NiN块的输出通道数等于标签类别的数量。最后放一个全局平均池化层(global average pooling layer),生成一个对数几率(logits)。NiN设计的一个优点是,它显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。

net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, strides=1, padding=1),    # 通道数先增加后减少:1->96->256->384->10nn.AdaptiveAvgPool2d((1, 1)),   # 注意这里的(1, 1)不是kernel_size,而是output_size# 将四维的输出转成二维的输出,其形状为(批量大小, 10)nn.Flatten())   # Flatten会把channel、height和width展平成一行

       我们创建一个数据样本来查看每个块的输出形状。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

3、训练模型

       我们使用Fashion-MNIST来训练模型。训练NiN与训练AlexNet、VGG时相似。

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224) # 调节图片尺寸为224
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.563, train acc 0.786, test acc 0.790
3087.6 examples/sec on cuda:0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240909.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【电路笔记】-串联电容器

串联电容器 文章目录 串联电容器1、概述2、示例13、示例34、总结 当电容器以菊花链方式连接在一条线上时,它们就串联在一起。 1、概述 对于串联电容器,流过电容器的充电电流 ( i C i_C iC​ ) 对于所有电容器来说都是相同的,因为它只有一条…

matlab实践(十一):导弹追踪

1.题目 a9.94,x062.06 2.方程 我们有: ( d x d t ) 2 ( d y d t ) 2 w 2 (\frac{\mathrm d\mathrm x}{\mathrm d\mathrm t})^2(\frac{\mathrm d\mathrm y}{\mathrm d\mathrm t})^2\mathrm w^2 (dtdx​)2(dtdy​)2w2 还有导弹始终指向船 ( d x d t d y d t ) …

【华为OD题库-106】全排列-java

题目 给定一个只包含大写英文字母的字符串S,要求你给出对S重新排列的所有不相同的排列数。如:S为ABA,则不同的排列有ABA、AAB、BAA三种。 解答要求 时间限制:5000ms,内存限制:100MB 输入描述 输入一个长度不超过10的字符串S,确保都是大写的。…

【快速开发】使用SvelteKit

自我介绍 做一个简单介绍,酒架年近48 ,有20多年IT工作经历,目前在一家500强做企业架构.因为工作需要,另外也因为兴趣涉猎比较广,为了自己学习建立了三个博客,分别是【全球IT瞭望】,【…

【Mode Management】CanSM详细介绍

目录 1. Introduction and functional overview 2.Dependencies to other modules 3.Functional specification 3.1General requirements 3.2State machine for each CAN network 3.2.1Trigger: PowerOn 3.2.2Trigger: CanSM_Init 3.2.3 Trigger: CanSM_DeInit 3.2.4 …

机器学习 | 概率图模型

见微知著,睹始知终。 见到细微的苗头就能预知事物的发展方向,能透过微小的现象看到事物的本质,推断结论或者结果。 概率模型为机器学习打开了一扇新的大门,将学习的任务转变为计算变量的概率分布。 实际情况中,各个变量…

单词接龙[中等]

一、题目 字典wordList中从单词beginWord和endWord的 转换序列 是一个按下述规格形成的序列beginWord -> s1 -> s2 -> ... -> sk&#xff1a; 1、每一对相邻的单词只差一个字母。 2、对于1 < i < k时&#xff0c;每个si都在wordList中。注意&#xff0c;beg…

Midjourney V6版本的5大新特性,掌握了,想法和实现信手拈来

Midjourney v6已推出&#xff1a;更简单的提示、增强的文本集成和更高水平的照片真实感&#xff01;以下是每个创意人员都需要了解的 5 个重要见解。 一、产品文字整合 使用简单风格提示向您的产品添加文本提示&#xff1a;带有文字“SALMA”的白色健身瓶 Midjourney v5.2&am…

Git安装和使用教程,并以gitee为例实现远程连接远程仓库

文章目录 1、Git简介及安装2、使用方法2.1、Git的启动与配置2.2、基本操作2.2.1、搭建自己的workspace2.2.2、git add2.2.3、git commit2.2.4、忽略某些文件不予提交2.2.5、以gitee为例实现git连接gitee远程仓库来托管代码 1、Git简介及安装 版本控制&#xff08;Revision cont…

只用10分钟,ChatGPT就帮我写了一篇2000字文章

有了ChatGPT之后&#xff0c;于我来说&#xff0c;有两个十分明显的变化&#xff1a; 1. 人变的更懒 因为生活、工作中遇到大大小小的事情&#xff0c;都可以直接找ChatGPT来寻求答案。 2. 工作产出量更大 之前花一天&#xff0c;甚至更久才能写一篇原创内容&#xff0c;现…

qt简单连接摄像头

要使用摄像头&#xff0c;就需要链接多媒体模块以及多媒体工具模块 需要在.pro文件中添加QT multimedia multimediawidgets 是用的库文件 QCamera 类用于打开系统的摄像头设备&#xff0c; QCameraViewfinder 用于显示捕获的视频&#xff0c; QCameraImageCapt…

Java并发工具类---ForkJoin、countDownlatch、CyclicBarrier、Semaphore

一、Fork Join fork join是JDK7引入的一种并发框架&#xff0c;采用分而治之的思想来处理并发任务 ForkJoin框架底层实现了工作窃取&#xff0c;当一个线程完成任务处于空闲状态时&#xff0c;会窃取其他工作线程的任务来做&#xff0c;这样可以充分利用线程来进行并行计算&a…

2023.11.24 信息学日志

2023.11.24 信息学日志 1. CF1700D River Locks题目描述题目概况思路点拨 1. CF1700D River Locks 题目描述 https://www.luogu.com.cn/problem/CF1700D 题目概况 来源&#xff1a;Codeforces 洛谷难度&#xff1a; 绿题 \color{green}绿题 绿题 CF难度&#xff1a; 1900…

系列十四、SpringBoot + JVM参数配置实战调优

一、SpringBoot JVM参数配置实战调优 1.1、概述 前面的系列文章大篇幅的讲述了JVM的内存结构以及各种参数&#xff0c;今天就使用SpringBoot项目实战演示一下&#xff0c;如何进行JVM参数调优&#xff0c;如果没有阅读过前面系列文章的朋友&#xff0c;建议先阅读后再看本篇文…

Java整合APNS推送消息-IOS-APP(基于.p12推送证书)

推送整体流程 1.在开发者中心申请对应的证书&#xff08;我用的是.p12文件&#xff09; 2.苹果手机用户注册到APNS&#xff0c;APNS将注册的token返回给APP&#xff08;服务端接收使用&#xff09;。 3.后台服务连接APNS&#xff0c;获取连接对象 4.后台服务构建消息载体 5.后台…

C++面向对象(OOP)编程-位运算详解

本文主要介绍原码、位运算的种类&#xff0c;以及常用的位运算的使用场景。 目录 1 原码、反码、补码 2 有符号和无符号数 3 位运算 4 位运算符使用规则 4.1 逻辑移位和算术移位 4.1.1 逻辑左移和算法左移 4.1.2 逻辑右移和算术右移 4.1.3 总结 4.2 位运算的应用场景 …

Searching for MobileNetV3(2019)

文章目录 Abstract主要内容实验结果 IntroductionRelated WorkEfficient Mobile Building BlocksNetwork SearchPlatform-Aware NAS for Block-wise SearchNetAdapt for Layer-wise Search Network ImprovementsRedesigning Expensive LayersNonlinearitiesLarge squeeze-and-e…

JAVA设计模式(三)-原型

JAVA设计模式(三)-原型 本篇文章主要讲下java 创建型设计模式中的原型模式. 何谓原型模式: 简单来说就是 将一个对象作为原型&#xff0c;通过对其进行复制而克隆出多个和原型类似的新实例。 使用原型模式,就可以简化实例化的过程, 不必依赖于构造函数或者new关键字. 由于j…

蓝桥杯-每日刷题-027

出租汽车计费器 一、题目要求 题目描述 有一个城市出租汽车的计费规则是3公里内&#xff08;含3公里&#xff09;基本费6元&#xff0c;超过3公里&#xff0c;每一公里1.4元。 现在对于输入具体的公里数x&#xff08;0<x<1000&#xff09;&#xff0c;编程计算x公里所需…

PHP-Xlswriter高性能导出Excel

使用背景 使用传统的PHPExcel导出效率太慢&#xff0c;并且资源占用高&#xff0c;数据量大的情况&#xff0c;会导致服务占用大量的资源&#xff0c;从而导致生产意味&#xff0c;再三思索后&#xff0c;决定使用其他高效率的导出方式 PHP-Xlswriter PHPExcel 因为内存消耗过…