分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):super().__init__()self.ps = patch_sizeself.c1 = Conv2D(in_c, out_c)encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)self.r1 = residual_block(out_c, out_c)def forward(self, inputs):x = self.c1(inputs)b, c, h, w = x.shapenum_patches = (h*w)//(self.ps**2)x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))x = self.te(x)x = torch.reshape(x, (b, c, h, w))x = self.c2(x)s = self.c3(inputs)x = self.relu(x + s)x = self.r1(x)return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c),nn.LeakyReLU(negative_slope=0.1, inplace=True),nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c))self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),nn.BatchNorm2d(out_c))self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)def forward(self, inputs):x = self.conv(inputs)s = self.shortcut(inputs)return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):def __init__(self):super().__init__()""" Encoder """backbone = resnet50()self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)self.layer2 = backbone.layer2self.layer3 = backbone.layer3self.layer4 = backbone.layer4self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)""" Decoder """self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)self.r1 = residual_transformer_block(64+64, 64, dim=64)self.r2 = residual_transformer_block(64+64, 64, dim=256)self.r3 = residual_block(64+64, 64)""" Classifier """self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)def forward(self, inputs):""" Encoder """x0 = inputsx1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]e1 = self.e1(x1)e2 = self.e2(x2)e3 = self.e3(x3)e4 = self.e4(x4)""" Decoder """x = self.up(e4)x = torch.cat([x, e3], axis=1)x = self.r1(x)x = self.up(x)x = torch.cat([x, e2], axis=1)x = self.r2(x)x = self.up(x)x = torch.cat([x, e1], axis=1)x = self.r3(x)x = self.up(x)""" Classifier """outputs = self.outputs(x)return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734825.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch搭建LeNet训练集详细实现

一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数: 把图像…

git学习(创建项目提交代码)

操作步骤如下 git init //初始化git remote add origin https://gitee.com/aydvvs.git //建立连接git remote -v //查看git add . //添加到暂存区git push 返送到暂存区git status // 查看提交代码git commit -m初次提交git push -u origin "master"//提交远程分支 …

微信小程序(五十二)开屏页面效果

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.使用控件模拟开屏界面 2.倒计时逻辑 3.布局方法 4.TabBar隐藏复现 源码&#xff1a; components/openPage/openPage.wxml <view class"openPage-box"><image src"{{imagePath}}"…

20个Python中列表(list)最常用的方法和函数。

本篇文章中,我们分别介绍10个Python中列表(list)最常用的方法和函数。 列表方法示例 1. 创建空列表和使用 append() 方法 # 创建一个空列表 my_list = []# 使用 append() 方法在列表末尾添加元素 my_list.append(5) my_list.append(10) print("After append:",…

离线数仓建设

一.数据仓库分层 ODS(Operation Data Store)层&#xff1a;原始数据层&#xff0c;存放加载原始日志、数据&#xff0c;数据保持原貌不做处理。 DWD(Data warehouse detail)层&#xff1a;对ODS层数据进行清洗&#xff08;去除空值&#xff0c;超过极限范围的数据&#xff09;、…

三维不同坐标系下点位姿态旋转平移变换

文章目录 前言正文计算方法思路Python实现总结前言 本文主要说明以下几种场景3D变换的应用: 3D相机坐标系下长方体物体,有本身坐标系,沿该物体长边方向移动一段距离,并绕长边轴正旋转方向转90度,求解当前物体中心点在相机坐标系下的位置和姿态多关节机器人末端沿工具坐标…

介绍Android UI绘制过程以及注意事项

Android UI绘制是一个复杂的过程&#xff0c;它涉及到多个步骤&#xff0c;从测量&#xff08;measure&#xff09;到布局&#xff08;layout&#xff09;再到绘制&#xff08;draw&#xff09;。以下是这个过程的简要介绍以及一些注意事项&#xff1a; 1. **测量&#xff08;…

计算机网络-网络应用服务器(四)

1.Samba服务器&#xff1a; Samba是Linux上实现和Windows系统局域网上共享文件和打印机的一种通信协议&#xff0c;由服务器及客户端程序构成。支持SMB/CIFS协议&#xff0c;实现共享资源。最主要的一个配置文件smb.conf&#xff0c;可以使用vi编辑器修改。守护进程&#xff1a…

STM32 利用FlashDB库实现在线扇区数据管理不丢失

STM32 利用FlashDB库实现在线扇区数据管理不丢失 &#x1f4cd;FalshDB地址:https://gitee.com/Armink/FlashDB ✨STM32没有片内EEPROM这样的存储区&#xff0c;虽然有备份寄存器&#xff0c;仅可以实现对少量数据的频繁存储&#xff0c;但是依赖备份电源&#xff08;BAT引脚&a…

美国签证|附面签相关事项√

小伙伴最近都忙着办签证吧&#xff01;但是需要注意的是&#xff0c;美国的签证跟其他任何国家的签证不同&#xff0c;并不是办理了就一定拿得到&#xff0c;据说概率是50%左右。所以办理美国签证&#xff0c;不要太着急啦&#xff01;先来了解一下美国签证的相片该怎么拍叭 ✅…

题目 2073: [STL训练]亲和串

题目描述: 人随着岁数的增长是越大越聪明还是越大越笨&#xff0c;这是一个值得全世界科学家思考的问题,同样的问题Eddy也一直在思考&#xff0c;因为他在很小的时候就知道亲和串如何判断了&#xff0c;但是发现&#xff0c;现在长大了却不知道怎么去判断亲和串了&#xff0c;…

RocketMQ的事务消息流程

什么是事务消息&#xff1f; 事务消息是一种在发送方和接收方之间保证消息传递的一致性和可靠性的消息传递机制。在消息发送过程中&#xff0c;生产者可以将消息发送到消息队列&#xff0c;但不会立即被消费者接收和处理。相反&#xff0c;消息会先进入一种“准备”状态&#x…

用chatgpt写insar地质灾害的论文,重复率只有1.8%,chatgpt4.0写论文不是梦

突发奇想&#xff0c;想用chatgpt写一篇论文&#xff0c;并看看查重率&#xff0c;结果很惊艳&#xff0c;说明是确实可行的&#xff0c;请看下图。 下面是完整的文字内容。 InSAR (Interferometric Synthetic Aperture Radar) 地质灾害监测技术是一种基于合成孔径雷达…

【JavaScript】JavaScript 变量 ① ( JavaScript 变量概念 | 变量声明 | 变量类型 | 变量初始化 | ES6 简介 )

文章目录 一、JavaScript 变量1、变量概念2、变量声明3、ES6 简介4、变量类型5、变量初始化 二、JavaScript 变量示例1、代码示例2、展示效果 一、JavaScript 变量 1、变量概念 JavaScript 变量 是用于 存储数据 的 容器 , 通过 变量名称 , 可以 获取 / 修改 变量 中的数据 ; …

第十五届蓝桥杯模拟赛(第三期)

大家好&#xff0c;我是晴天学长&#xff0c;本次分享&#xff0c;制作不易&#xff0c;本次题解只用于学习用途&#xff0c;如果有考试需要的小伙伴请考完试再来看题解进行学习&#xff0c;需要的小伙伴可以点赞关注评论一波哦&#xff01;蓝桥杯省赛就要开始了&#xff0c;祝…

【DimPlot】【FeaturePlot】使用小tips

目录 DimPlot函数参数解析 栅格化点图 放大 ggplot2 图例的点&#xff0c;修改图例的标题 FeaturePlot函数参数解析 调整FeaturePlot颜色 分组绘制featureplot 随手笔记&#xff0c;持续更新中。。。 Reference DimPlot函数参数解析 object: 一个Seurat对象&#xff0c;…

工作纪实46-关于微服务的上线发布姿势

蓝绿部署 在部署时&#xff0c;不需要将旧版本的服务停掉&#xff0c;而是将新版本与旧版本同时运行&#xff0c;新版本测试无误之后再将旧版本停掉。这样可以避免再升级的过程中如果失败服务不可用的问题&#xff0c;因为同时部署了两个版本的程序&#xff0c;使得硬件资源是…

【项目笔记】java微服务:黑马头条(day01)

文章目录 环境搭建、SpringCloud微服务(注册发现、服务调用、网关)1)课程对比2)项目概述2.1)能让你收获什么2.2)项目课程大纲2.3)项目概述2.4)项目术语2.5)业务说明 3)技术栈4)nacos环境搭建4.1)虚拟机镜像准备4.2)nacos安装 5)初始工程搭建5.1)环境准备5.2)主体结构 6)登录6.1…

Ubuntu用扩展分区加载home目录步骤

如果你想要将新的磁盘挂载到默认的 /home 目录下&#xff0c;可以按照以下步骤进行操作&#xff1a; 创建挂载点&#xff1a; 首先&#xff0c;确保新磁盘已连接并识别。然后&#xff0c;创建一个临时挂载点&#xff0c;以便将新磁盘挂载到该点。sudo mkdir /mnt/new_home挂载磁…

JavaScript中的Set和Map:理解与使用

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…