分类神经网络3:DenseNet模型复现

目录

DenseNet网络架构

DenseNet部分实现代码


DenseNet网络架构

论文原址:https://arxiv.org/pdf/1608.06993.pdf

稠密连接神经网络(DenseNet)实质上是ResNet的进阶模型(了解ResNet模型请点击),二者均是通过建立前面层与后面层之间的“短路连接”,但不同的是,DenseNet建立的是前面所有层与后面层的密集连接,其一大特点是通过特征在通道上的连接来实现特征重用,这让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。DenseNet 网络的模型结构如下:

DenseNet 的网络结构主要由DenseBlockTransition Layer组成。

DenseBlock:密集连接机制。互相连接所有的层,即每一层的输入都来自于它前面所有层的特征图,每一层的输出均会直接连接到它后面所有层的输入,这可以实现特征重用(即对不同“级别”的特征——不同表征进行总体性地再探索),提升效率。具体的连接方式如下图示:

在同一个DenseBlock当中,特征层的高宽不会发生改变,但是通道数会发生改变可以看出DenseBlock中采用了BN+ReLU+Conv的结构,然而一般网络是用Conv+BN+ReLU的结构。这是由于卷积层的输入包含了它前面所有层的输出特征,它们来自不同层的输出,因此数值分布差异比较大,所以它们在输入到下一个卷积层时,必须先经过BN层将其数值进行标准化,然后再进行卷积操作。通常为了减少参数,一般还会先加一个1x1 卷积来减少参数量。所以DenseBlock中的每一层采用BN+ReLU+1x1Conv 、Conv+BN+ReLU+3x3 Conv的结构。

Transition Layer:用于将不同DenseBlock之间进行连接,整合上一个DenseBlock获得的特征,并且缩小上一个DenseBlock的宽高,达到下采样的效果,实质上起到压缩模型的作用。Transition Layer中一般包含一个1x1卷积(用于调整通道数)和2x2平均池化(用于降低特征图大小),结构为BN+ReLU+1x1 Conv+2x2 AvgPooling

DenseNet网络的具体配置信息如下:

可以看出,一个DenseNet中一般有3个或4个DenseBlock,最后的DenseBlock后连接了一个最大池化层,然后是一个包含1000个类别的全连接层,通过softmax激活函数得到类别属性。

DenseNet部分实现代码

直接上干货

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict__all__ = ["densenet121", "densenet161", "densenet169", "densenet201"]class DenseLayer(nn.Module):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient = False):super(DenseLayer,self).__init__()self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)self.drop_rate = float(drop_rate)self.memory_efficient = memory_efficientdef bn_function(self, inputs):concated_features = torch.cat(inputs, 1)bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))return bottleneck_outputdef any_requires_grad(self, input):for tensor in input:if tensor.requires_grad:return Truereturn False@torch.jit.unuseddef call_checkpoint_bottleneck(self, input):def closure(*inputs):return self.bn_function(inputs)return cp.checkpoint(closure, *input)def forward(self, input):if isinstance(input, torch.Tensor):prev_features = [input]else:prev_features = inputif self.memory_efficient and self.any_requires_grad(prev_features):if torch.jit.is_scripting():raise Exception("Memory Efficient not supported in JIT")bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:bottleneck_output = self.bn_function(prev_features)new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return new_featuresclass DenseBlock(nn.ModuleDict):def __init__(self,num_layers,num_input_features,bn_size,growth_rate,drop_rate,memory_efficient = False,):super(DenseBlock,self).__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.add_module("denselayer%d" % (i + 1), layer)def forward(self, init_features):features = [init_features]for name, layer in self.items():new_features = layer(features)features.append(new_features)return torch.cat(features, 1)class Transition(nn.Sequential):"""Densenet Transition Layer:1 × 1 conv2 × 2 average pool, stride 2"""def __init__(self, num_input_features, num_output_features):super(Transition,self).__init__()self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)class DenseNet(nn.Module):def __init__(self,growth_rate = 32,num_init_features = 64,block_config = None,num_classes = 1000,bn_size = 4,drop_rate = 0.,memory_efficient = False,):super(DenseNet,self).__init__()# First convolutionself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# Each denseblocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.features.add_module("denseblock%d" % (i + 1), block)num_features = num_features + num_layers * growth_rateif i != len(block_config) - 1:trans = Transition(num_input_features=num_features, num_output_features=num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# Final batch normself.features.add_module("norm5", nn.BatchNorm2d(num_features))# Linear layerself.classifier = nn.Linear(num_features, num_classes)# Official init from torch repo.for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return outdef densenet121(num_classes):"""Densenet-121 model"""return DenseNet(32, 64, (6, 12, 24, 16),num_classes=num_classes)def densenet161(num_classes):"""Densenet-161 model"""return DenseNet(48, 96, (6, 12, 36, 24),  num_classes=num_classes)def densenet169(num_classes):"""Densenet-169 model"""return DenseNet(32, 64, (6, 12, 32, 32),   num_classes=num_classes)def densenet201(num_classes):"""Densenet-201 model"""return DenseNet(32, 64, (6, 12, 48, 32), num_classes=num_classes)if __name__=="__main__":# from torchsummaryX import summarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = densenet121(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)# summary(net, torch.ones((1, 3, 224, 224)).to(device))

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/829714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java面试八股文-2024

面试指南 TMD,一个后端为什么要了解那么多的知识,真是服了。啥啥都得了解 MySQL MySQL索引可能在以下几种情况下失效: 不遵循最左匹配原则:在联合索引中,如果没有使用索引的最左前缀,即查询条件中没有包含…

Altera FPGA 配置flash读写

目录 一、读写控制器的配置 二、生成flash的配置文件 三、关于三种配置文件的大小 四、其他 一、读写控制器的配置 Altera ASMI Parallel(下文简称ASMI)这个IP就仅仅是个Flash读写控制器,可以自由的设计数据来源。 关于这个IP的使用,可以…

MAC有没有免费NTFS tuxera激活码 tuxera破解 tuxera for mac2023序列号直装版 ntfs formac教程

Tuxera NTFS 2023破解版是一款非常好用的在线磁盘读写工具,该软件允许mac用户在Windows NTFS格式的硬盘上进行读写操作,Mac的文件系统是HFS,而Windows则使用NTFS格式,这导致在Mac系统上不能直接读写Windows格式的硬盘。然而&#…

程序员:写好代码就行了,为什么要学写作

🍁 展望:关注我, AI 学习之旅上,我与您一同成长! 一、引言 在当今这个信息爆炸的时代,程序员们往往沉浸在代码的世界里,用代码来解决问题。然而,你是否曾想过,除了代码,…

INSTEAD OF 触发器的创建

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 INSTEAD OF 触发器,也称替换触发器,是一种特殊的触发器,和其他建立在数据表上的触发器不同,INSTEAD OF 触发器建立在视图上。…

Podman入门全指南:安装、配置与运行容器

欢迎来到我的博客,代码的世界里,每一行都是一个故事 Podman入门全指南:安装、配置与运行容器 前言Podman简介什么是 Podman?Podman 与 Docker 的主要区别 安装Podman支持的操作系统和环境安装步骤详解LinuxUbuntuCentOS/RHEL MacO…

双系统下删除ubuntu

絮絮叨叨 由于我在安装Ubuntu的时候没有自定义安装位置,而是使用与window共存的方式让Ubuntu自己选择安装位置,导致卸载时我不知道去格式化哪个分区,查阅多方资料后无果,后在大佬帮助下找到解决方案 解决步骤 1、 插上Ubuntu安…

Axure如何调起浏览器的打印功能

Axure如何调起浏览器的打印功能 答:javascript:window.print(); 不明白的继续往下看 应用场景: 原型设计中,页面上的打印按钮,需要模拟操作演示,需要点击指定的按钮时,唤起浏览器的打印功能&#xff08…

使用Pandas从Excel文件中提取满足条件的数据并生成新的文件

目录 一、引言 二、环境准备 三、读取Excel文件 四、数据筛选 五、保存为新的Excel文件 六、案例与代码总结 七、进阶用法与注意事项 八、结语 在数据处理的日常工作中,我们经常需要从大量数据中筛选出满足特定条件的数据集。Pandas是一个强大的Python数据分…

比 PSD.js 更强的下一代 PSD 解析器,支持 WebAssembly

比 PSD.js 更强的下一代 PSD 解析器,支持 WebAssembly 1.什么是 webtoon/ps webtoon/ps 是 Typescript 中轻量级 Adobe Photoshop .psd/.psb 文件解析器,对 Web 浏览器和 NodeJS 环境提供支持,且做到零依赖。 Fast zero-dependency PSD par…

2024 年最好的免费数据恢复软件,您可以尝试的几个数据恢复软件

由于系统崩溃而丢失数据可能会给用户带来麻烦。我们将重要的宝贵数据和个人数据保存在我们的 PC、笔记本电脑和其他数字设备上。您可能会因分区丢失、意外删除文件和文件夹、格式化硬盘驱动器而丢失数据。数据丢失是不幸的,如果您不小心从系统中删除了文件或数据&am…

深入理解 Srping IOC

什么是 Spring IOC? IOC 全称:Inversion of Control,翻译为中文就是控制反转,IOC 是一种设计思想,IOC 容器是 Spring 框架的核心,它通过控制和管理对象之间的依赖关系来实现依赖注入(Dependenc…

正点原子[第二期]ARM(I.MX6U)裸机篇学习笔记-1.2

前言: 本文是来自哔哩哔哩网站上视频“正点原子[第二期]Linux之ARM(MX6U)裸机篇”的学习笔记,在这里会记录下正点原子Linux ARM MX6ULL 开发板根据配套的哔哩哔哩学习视频所作的实验和笔记内容。本文大量的引用了正点原子哔哔哩网…

结构体内存对齐(未完成版)

前言 我们已经掌握了结构体的基本使用了。 现在我们深入讨论一个问题:计算机构体的大小。 这也是一个特别热门的考点:结构体内存对齐 练习导入 对齐规则

vue项目npm run build 打包之后如何在本地访问

vue项目npm run build 打包之后如何在本地访问 如果直接访问时,则会报错如下的信息: 报错码: Access to script at file:///D:/assets/index-DDVBfHVo.js from origin null has been blocked by CORS policy: Cross origin requests are on…

【转载】如何在MacBookPro上把Ubuntu安装到移动硬盘里过程记录

以下主要目的是记录安装过程中的问题,安装步骤等信息怕忘记 环境信息: Mac :macOS High Sierra 10.13.6 内存8G(Swap时用到) Ubuntu: ubuntu-22.04.4-desktop-amd64.ios 金士顿U盘:Kingston-64G 烧录软件:balenaEtcher…

牛客NC371 验证回文字符串(二)【简单 双指针 C++/Java/Go/PHP】

题目 题目链接: https://www.nowcoder.com/practice/130e1a9eb88942239b66e53ec6e53f51 思路 直接看答案,不难参考答案C class Solution {public:/*** 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可…

Atlassian Jira 信息泄露漏洞(CVE-2019-3403) 排查思路

Atlassian Jira: 企业广泛使用的项目与事务跟踪工具,被广泛应用于缺陷跟踪、客户服务、需求收集、流程审批、任务跟踪、项目跟踪和敏捷管理等工作领域。 简述: 近日发现多个内网IP触发的Atlassian Jira 信息泄露漏洞的告警。 告警的检测规…

openvoice v2 声音克隆使用案例

参考: https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md https://www.wehelpwin.com/article/4940 安装 1)下载OpenVoice项目安装 2)MeloTTS安装 参考:https://blog.csdn.net/weixin_42357472/article/details/136320097 pip install git+https://gith…

2398.预算内最多的机器人数目

我第一个手搓的hard的单调队列题目......灵神yyds 思路解析: 我做的时候感觉这个题目有点歧义,我以为他的连续运行是时间上连续,所以我开始写的代码是选择最多的子序列(可以不连续),使得不超过budget,这个求最多子序列的代码会在最后给出,不保证完全正确(因为没有太多测试点),…