UNet 改进(5):结合SE模块提升图像分割性能

U-Net是医学图像分割领域最成功的架构之一,其对称的编码器-解码器结构和跳跃连接使其能够有效捕捉多尺度特征。本文将解析一个改进版的U-Net实现,该版本通过引入Squeeze-and-Excitation(SE)模块进一步提升了模型性能。

一、架构概览

这个改进的U-Net保持了经典U-Net的核心结构,但在每个卷积块后添加了SE模块,主要包含以下几个关键组件:

  1. SE注意力模块:增强重要通道的特征响应

  2. 双卷积块:基础特征提取单元

  3. 编码器-解码器结构:逐步下采样和上采样

  4. 跳跃连接:结合低层和高层特征

二、核心组件详解

1. SE注意力模块 (SELayer)

class SELayer(nn.Module):def __init__(self, in_channels, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(in_channels // reduction, in_channels, bias=False),nn.Sigmoid())

SE模块通过以下步骤工作:

  1. 使用全局平均池化将空间信息压缩为一个通道描述符

  2. 通过两个全连接层学习通道间的依赖关系

  3. 使用Sigmoid激活生成通道权重

  4. 将权重应用于原始特征图

这种机制让模型能够自适应地强调重要特征通道,抑制不重要的通道。

2. 改进的双卷积块 (DoubleConv)

class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),SELayer(out_channels)  # 添加 SE 模块)

每个双卷积块包含:

  • 两个3×3卷积层,保持空间分辨率(padding=1)

  • 每个卷积后接批量归一化和ReLU激活

  • 最后添加SE模块进行通道注意力加权

3. 完整的改进U-Net (ImprovedUNet)

编码器部分通过最大池化逐步下采样,解码器部分通过转置卷积上采样,并结合跳跃连接:

class ImprovedUNet(nn.Module):def __init__(self, n_channels, n_classes):# 初始化各层...def forward(self, x):# 编码过程x1 = self.inc(x)       # 初始卷积x2 = self.down1(x1)    # 下采样1x3 = self.down2(x2)    # 下采样2x4 = self.down3(x3)    # 下采样3x5 = self.down4(x4)    # 下采样4# 解码过程x = self.up1(x5)x = self.double_conv_up1(torch.cat([x, F.interpolate(x4, size=x.shape[2:])], dim=1))# ...类似处理其他上采样层return self.outc(x)

三、创新点与优势

  1. SE模块集成:在每个双卷积块后添加SE模块,使模型能够自适应地重新校准通道特征响应

  2. 改进的特征融合:使用双线性插值调整跳跃连接特征图尺寸,确保精确对齐

  3. 参数效率:通过factor参数控制解码器通道数,平衡模型容量和计算成本

四、性能分析

这个改进版U-Net相比原始U-Net有以下潜在优势:

  • 更好的特征选择能力,通过SE模块突出重要特征

  • 更稳定的训练,得益于批量归一化的广泛使用

  • 更精确的边界预测,得益于改进的特征融合方式

五、使用示例

# 创建模型实例
model = ImprovedUNet(n_channels=3, n_classes=1)# 随机输入测试
input_tensor = torch.randn(2, 3, 256, 256)  # 2张256x256的RGB图像
output = model(input_tensor)  # 输出形状为[2, 1, 256, 256]

六、完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F# SE 模块
class SELayer(nn.Module):def __init__(self, in_channels, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(in_channels // reduction, in_channels, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)# 改进的卷积块
class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),SELayer(out_channels)  # 添加 SE 模块)def forward(self, x):return self.double_conv(x)# 改进的 U-Net 模型
class ImprovedUNet(nn.Module):def __init__(self, n_channels, n_classes):super().__init__()self.n_channels = n_channelsself.n_classes = n_classesself.inc = DoubleConv(n_channels, 64)self.down1 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(64, 128))self.down2 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(128, 256))self.down3 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(256, 512))factor = 2self.down4 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(512, 1024 // factor))self.up1 = nn.Sequential(nn.ConvTranspose2d(1024 // factor, 512 // factor, kernel_size=2, stride=2))self.double_conv_up1 = DoubleConv(512 // factor + 512, 512 // factor)self.up2 = nn.Sequential(nn.ConvTranspose2d(512 // factor, 256 // factor, kernel_size=2, stride=2))self.double_conv_up2 = DoubleConv(256 // factor + 256, 256 // factor)self.up3 = nn.Sequential(nn.ConvTranspose2d(256 // factor, 128 // factor, kernel_size=2, stride=2))self.double_conv_up3 = DoubleConv(128 // factor + 128, 128 // factor)self.up4 = nn.Sequential(nn.ConvTranspose2d(128 // factor, 64, kernel_size=2, stride=2))self.double_conv_up4 = DoubleConv(64 + 64, 64)self.outc = nn.Conv2d(64, n_classes, kernel_size=1)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5)x = self.double_conv_up1(torch.cat([x, F.interpolate(x4, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up2(x)x = self.double_conv_up2(torch.cat([x, F.interpolate(x3, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up3(x)x = self.double_conv_up3(torch.cat([x, F.interpolate(x2, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up4(x)x = self.double_conv_up4(torch.cat([x, F.interpolate(x1, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))logits = self.outc(x)return logits# 创建改进的 U-Net 模型实例
model = ImprovedUNet(n_channels=3, n_classes=1)
print(model)# 生成一个随机输入
input_tensor = torch.randn(2, 3, 256, 256)# 前向传播
output = model(input_tensor)
print(output.shape)

七、适用场景

这种改进的U-Net特别适合以下任务:

  • 医学图像分割(CT/MRI)

  • 遥感图像解析

  • 任何需要精确边界预测的密集预测任务

八、总结

通过在U-Net中集成SE模块,我们获得了能够自适应关注重要特征的改进架构。这种设计在不显著增加计算成本的情况下,提高了模型的特征选择能力,使其在各种图像分割任务中表现更加出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器人拧螺丝紧固装配(Robot screw fastening assembly)

机器人拧螺丝紧固装配技术正以其高精度、高效率和高灵活性,重塑着传统制造业的生产范式。这项融合了机械臂定位、扭矩控制、视觉引导与数据分析的自动化解决方案,不仅将工人从重复性高强度劳动中解放出来,更通过实时数据反馈与精准执行&#…

图像处理中的 Gaussina Blur 和 SIFT 算法

Gaussina Blur 高斯模糊 高斯模糊的数学定义 高斯模糊是通过 高斯核(Gaussian Kernel) 对图像进行卷积操作实现的. 二维高斯函数定义为 G ( x , y , σ ) 1 2 π σ 2 e − x 2 y 2 2 σ 2 G(x, y, \sigma) \frac{1}{2\pi \sigma^2} e^{-\frac{x^2 y^2}{2\sigma^2}} G(x…

在Unity中实现《幽灵行者》风格的跑酷动作

基础设置 角色控制器选择: 使用Character Controller组件或Rigidbody Capsule Collider 推荐使用Character Controller以获得更精确的运动控制 输入系统: 使用Unity的新输入系统(Input System Package)处理玩家输入 滑铲实现 public class Slide…

青蛙吃虫--dp

1.dp数组有关元素--路长和次数 2.递推公式 3.遍历顺序--最终影响的是路长&#xff0c;在外面 其次次数遍历&#xff0c;即这次路长所有情况都更新 最后&#xff0c;遍历次数自然就要遍历跳长 4.max时时更新 dp版本 #include<bits/stdc.h> using namespace std; #def…

Tiktok 关键字 视频及评论信息爬虫(2) [2025.04.07]

&#x1f64b;‍♀️Tiktok APP的基于关键字检索的视频及评论信息爬虫共分为两期&#xff0c;希望对大家有所帮助。 第一期&#xff1a;基于关键字检索的视频信息爬取 第二期见下文。 1.Node.js环境配置 首先配置 JavaScript 运行环境&#xff08;如 Node.js&#xff09;&…

Matlab绘图—‘‘错误使用 plot输入参数的数目不足‘‘

原因1&#xff1a; ❤️ 文件列名不是合法变量名 在excel中数据列名称为Sample:float,将:删除就解决了

Kotlin问题汇总

Kotlin问题汇总 真机安装调试 查看真机的Android版本&#xff0c;将build.gradle文件中的minSdk改为手机的Android版本&#xff0c;点Sync Now更新设置 apk安装失败 在gradle.properties全局配置中设置android.injected.testOnlyfalse Unresolved reference: 在activity_…

基于VMware的Cent OS Stream 8安装与配置及远程连接软件的介绍

1.VMware Workstation 简介&#xff1a; VMware Workstation&#xff08;中文名“威睿工作站”&#xff09;是一款功能强大的桌面虚拟计算机软件&#xff0c;提供用户可在单一的桌面上同时运行不同的操作系统&#xff0c;和进行开发、测试 、部署新的应用程序的最佳解决方案。…

Go语言从零构建SQL数据库(4)-解析器

SQL解析器&#xff1a;数据库的"翻译官"图解与代码详解 图解SQL解析过程 SQL解析器就像是人类语言与计算机之间的翻译官&#xff0c;将我们书写的SQL语句转换成数据库能够理解和执行的结构。 #mermaid-svg-f9gAqHutDLL4McGy {font-family:"trebuchet ms"…

十道海量数据处理面试题与十个方法总结

一、十道海量数据处理面试题 ♟️1、海量日志数据&#xff0c;提取出某日访问百度次数最多的那个IP。(分治思想 哈希表) 首先&#xff0c;从日志中提取出所有访问百度的IP地址&#xff0c;将它们逐个写入一个大文件中&#xff0c;便于后续处理。 考虑到IP地址是32位的&#…

SolidWorks2025三维计算机辅助设计(3D CAD)软件超详细图文安装教程(2025最新版保姆级教程)

目录 前言 一、SolidWorks下载 二、SolidWorks安装 三、启动SolidWorks 前言 SolidWorks 是一款由法国达索系统&#xff08;Dassault Systmes&#xff09;公司开发的三维计算机辅助设计&#xff08;3D CAD&#xff09;软件&#xff0c;广泛用于机械设计、工程仿真和产品开…

IntelliJ IDEA 2020~2024 创建SpringBoot项目编辑报错: 程序包org.springframework.boot不存在

目录 前奏解决结尾 前奏 哈&#xff01;今天在处理我的SpringBoot项目时&#xff0c;突然遇到了一些让人摸不着头脑的错误提示&#xff1a; java: 程序包org.junit不存在 java: 程序包org.junit.runner不存在 java: 程序包org.springframework.boot.test.context不存在 java:…

CPU 压力测试命令大全

CPU 压力测试命令大全 以下是 Linux/Unix 系统下常用的 CPU 压力测试命令和工具&#xff0c;可用于测试 CPU 性能、稳定性和散热能力。 1. 基本压力测试命令 1.1 使用 yes 命令 yes > /dev/null & # 启动一个无限循环进程 yes > /dev/null & # 启动第二个进…

#SVA语法滴水穿石# (003)关于 sequence 和 property 的区别和联系

在 SystemVerilog Assertions (SVA) 中,sequence 和 property 是两个核心概念,它们既有区别又紧密相关。对于初学者,可能不需要过多理解;但是要想写出复杂精美的断言,深刻理解两者十分重要。今天,我们汇总和学习一下该知识点。 1. 区别 特性sequenceproperty定义描述一系…

WordPress浮动广告插件+飘动效果客服插件

源码介绍 WordPress浮动广告插件飘动效果客服插件 将源码上传到wordpress的插件根目录下&#xff0c;解压&#xff0c;然后后台启用即可 截图 源码免费获取 WordPress浮动广告插件飘动效果客服插件

虚幻基础:蓝图基础知识

文章目录 组件蓝图创建时&#xff0c;优先创建组件&#xff0c;如c一样。 UI控件控件不会自动创建&#xff0c;而是在蓝图创建函数中手动创建。 函数内使用S序列接退出&#xff0c;并不会等所有执行完再退出&#xff0c;而是一个执行完后直接退出 组件 蓝图创建时&#xff0c;…

《AI大模型应知应会100篇》加餐篇:LlamaIndex 与 LangChain 的无缝集成

加餐篇&#xff1a;LlamaIndex 与 LangChain 的无缝集成 问题背景&#xff1a;在实际应用中&#xff0c;开发者常常需要结合多个框架的优势。例如&#xff0c;使用 LangChain 管理复杂的业务逻辑链&#xff0c;同时利用 LlamaIndex 的高效索引和检索能力构建知识库。本文在基于…

深度学习项目--分组卷积与ResNext网络实验探究(pytorch复现)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 ResNext是分组卷积的开始之作&#xff0c;这里本文将学习ResNext网络&#xff1b;本文复现了ResNext50神经网络&#xff0c;并用其进行了猴痘病分类实验…

从代码学习深度学习 - RNN PyTorch版

文章目录 前言一、数据预处理二、辅助训练工具函数三、绘图工具函数四、模型定义五、模型训练与预测六、实例化模型并训练训练结果可视化总结前言 循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一…

JS DOM节点增删改查

增加节点 通过document.createNode()函数创建对象 // 创建节点 const div document.createElement(div) // 追加节点 document.body.appendChild(div) 克隆节点 删除节点