Pytorch从零开始实战16

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merage = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merage)x = ReLU()(x)return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)# epsilon为BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224,224,3]->[230,230,3]x = ZeroPadding2D((3, 3))(inputs)x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/611420.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rust基础类型之布尔类型和字符

布尔类型 Rust 中的布尔类型为 bool,仅仅有两个值,true 和 false。比如下方代码: let flag1 true; let flag2: bool false;字符 Rust 中的字符类型是 char,值用单引号括起来。 fn main() {let char1 z;let char2: char ℤ…

TensorRt(5)动态尺寸输入的分割模型测试

文章目录 1、固定输入尺寸逻辑2、动态输入尺寸2.1、模型导出2.2、推理测试2.3、显存分配问题2.4、完整代码 这里主要说明使用TensorRT进行加载编译优化后的模型engine进行推理测试,与前面进行目标识别、目标分类的模型的网络输入是固定大小不同,导致输入…

Java单元测试 1.SpringBoot 2.普通java工程

1)SpringBoot工程 1.注意:新建SpringBoot工程时,一定要根据你当前的jdk版本选择好SpringBoot的版本,比如:jdk为11,则不能创建3.x的SpringBoot工程,有时候2021版本的idea强制我创建3.x版本的SpringBoot&…

【打卡】牛客网:BM79 打家劫舍(二)

资料: dp.clear()会把dp的size变为0。 assign和insert的对比: v1.assign(v2.begin(), v2.end()); v1.insert(pos,n,elem); //在pos位置插入n个elem数据,无返回值。 v1.insert(pos,beg,end); //在pos位置插入[beg,end)区间的数据&#xff0…

【现代密码学】笔记3.4-3.7--构造安全加密方案、CPA安全、CCA安全 《introduction to modern cryphtography》

【现代密码学】笔记3.4-3.7--构造安全加密方案、CPA安全、CCA安全 《introduction to modern cryphtography》 写在最前面私钥加密与伪随机性 第二部分流加密与CPA多重加密 CPA安全加密方案CPA安全实验、预言机访问(oracle access) 操作模式伪随机函数PR…

Java微服务系列之 ShardingSphere - ShardingSphere-JDBC

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 系列专栏目录 [Java项…

复习python从入门到实践——类

复习python从入门到实践——类 学好继承,你就初步学习了程序面向对象啦~本章我依旧为大家整理好了运行的步骤,以及举了尽可能清晰简单的例子为大家展示。 目录 复习python从入门到实践——类1. Syntax注意: 2.修改属性值3.继承4…

报错解决:No module named ‘pytorch_lightning‘ 安装pytorch_lightning

报错记录 执行如下代码: import pytorch_lightning报错: No module named ‘pytorch_lightning’ 解决方式 安装pytorch_lightning包即可。 一般情况下,缺失的包通过pip安装,即: pip install pytorch_lightning然…

C++性能优化简单总结

什么样的代码是高度优化的? 我们先出去数据结构和算法本身的使用。C 的高效代码通常是利用了各种编译器优化和语言特性来最大程度地提高执行效率和资源利用率的代码。我们需要编写编译器友好的代码来让编译器优化或者说编写出不用编译器高度优化优化也能达到同样效果…

1 快速前端开发

1 前端开发 目的:开发一个平台(网站)- 前端开发:HTML、CSS、JavaScript- Web框架:接收请求并处理- MySQL数据库:存储数据地方快速上手:基于Flask Web框架让你快速搭建一个网站出来。1.快速开发…

HarmonyOS应用开发学习笔记 应用上下文Context 获取文件夹路径

1、 HarmoryOS Ability页面的生命周期 2、 Component自定义组件 3、HarmonyOS 应用开发学习笔记 ets组件生命周期 4、HarmonyOS 应用开发学习笔记 ets组件样式定义 Styles装饰器:定义组件重用样式 Extend装饰器:定义扩展组件样式 5、HarmonyOS 应用开发…

14-股票K线图功能-个股日K线SQL分析__ev

需求:统计个股日K线数据,也就是把某只股票每天的最高价,开盘价,收盘价,最低价形成K线图。

山西电力市场日前价格预测【2024-01-11】

日前价格预测 预测说明: 如上图所示,预测明日(2024-01-11)山西电力市场全天平均日前电价为231.43元/MWh。其中,最高日前电价为422.21元/MWh,预计出现在18:00。最低日前电价为0.00元/MWh,预计出…

现代软件测试中的自动化测试工具

自动化测试的重要性和优势 引言:随着软件开发的不断发展,自动化测试工具在现代软件测试中扮演着重要角色。提高效率:自动化测试可以加快测试流程,减少人工测试所需的时间和资源。提升准确性:自动化测试工具可以减少人…

二分图最大匹配算法:匈牙利、KM

文章目录 基础定义匹配二分图二分图的矩阵覆盖交错路与增广路匈牙利算法饱和X的匹配不管X、Y求最大匹配KM算法可行顶点标号、相等子图相等子图的若干性质KM算法的正确性基于以下定理:算法流程描述1描述2基础定义 匹配 匹配:给定一个无向图 G = < V , E > G=<V,E>…

基于云平台技术的车外视频隐私合规的浅谈

基于云平台技术的车外视频隐私合规创新旨在确保车外视频数据的合法、合规使用&#xff0c;同时保护个人隐私不受侵犯。以下是基于云平台技术的车外视频隐私合规的创新实践和考虑因素&#xff1a; 实践&#xff1a; 数据采集&#xff1a;对车外视频数据进行采集时&#xff0c;…

PACS医学影像报告管理系统源码带CT三维后处理技术

PACS从各种医学影像检查设备中获取、存储、处理影像数据&#xff0c;传输到体检信息系统中&#xff0c;生成图文并茂的体检报告&#xff0c;满足体检中心高水准、高效率影像处理的需要。 自主知识产权&#xff1a;拥有完整知识产权&#xff0c;能够同其他模块无缝对接 国际标准…

Linux CentOS 7.6安装JDK详细保姆级教程

一、检查系统是否自带jdk java --version 如果有的话&#xff0c;找到对应的文件删除 第一步&#xff1a;先查看Linux自带的JDK有几个&#xff0c;用命令&#xff1a; rpm -qa | grep -i java第二步:删除JDK&#xff0c;执行命令&#xff1a; rpm -qa | grep -i java | xarg…

ubuntu设置ssh登录,设置公钥无密登录

sudo apt-get install openssh-server vim sudo vim /etc/ssh/sshd_config 修改这几行 PubkeyAuthentication yes #指定公钥数据库文件 AuthorsizedKeysFile.ssh/authorized_keys #PasswordAuthentication yes 改为 PasswordAuthentication no 然后重启ssh服务 systemctl res…

企业的 Android 移动设备管理 (MDM) 解决方案

移动设备管理可帮助您在不影响最终用户体验的情况下&#xff0c;通过无线方式管理和保护组织的移动设备群&#xff0c;现代 MDM 解决方案还可以控制 App、内容和安全性&#xff0c;因此员工可以毫无顾虑地在托管设备上工作。移动设备管理软件可有效管理个人设备上的公司空间。M…