
图像分割:PyTorch从零开始实现SegFormer语义分割

  • 前言
  • 环境要求
  • 相关介绍
  • SegFormer核心模块:编码器(MiT)和解码器(All-MLP解码器)。
    • 编码器(MiT):
      • 分层结构,产生多尺度特征(通常有4个阶段,每个阶段特征图尺寸递减)。
      • 每个阶段由多个Transformer块组成,每个块包含:
        • 重叠块嵌入(Overlapped Patch Embedding)
        • 高效自注意力(Efficient Self-Attention)
        • 混合前馈网络(Mix FeedForward Network)
    • 解码器(All-MLP):
      • 将多尺度特征上采样到相同尺寸并拼接。
      • 通过多层感知机(MLP)得到分割结果。
  • 具体实现
    • 导入相关库
    • 准备数据集
    • 定义网络模型
    • 训练验证
    • 推理预测
    • 主函数
    • 输出结果
    • 完整代码
  • 参考


前言

环境要求

Package                Version      Editable project location
---------------------- ------------ ----------------------------------------------
addict                 2.4.0
aliyun-python-sdk-core 2.16.0
aliyun-python-sdk-kms  2.16.5
certifi                2025.8.3
cffi                   2.0.0
charset-normalizer     3.4.3
click                  8.3.0
colorama               0.4.6
contourpy              1.3.2
crcmod                 1.7
cryptography           46.0.1
cycler                 0.12.1
einops                 0.8.1
filelock               3.14.0
fonttools              4.60.0
fsspec                 2025.9.0
ftfy                   6.3.1
huggingface-hub        0.35.1
idna                   3.10
jmespath               0.10.0
kiwisolver             1.4.9
Markdown               3.9
markdown-it-py         4.0.0
matplotlib             3.10.6
mdurl                  0.1.2
mmcv                   2.1.0
mmcv-full              1.2.7
mmengine               0.10.7
mmsegmentation         0.11.0
model-index            0.1.11
numpy                  1.26.3
opencv-python          4.6.0.66
opendatalab            0.0.10
openmim                0.3.9
openxlab               0.1.2
ordered-set            4.1.0
oss2                   2.17.0
packaging              24.2
pandas                 2.3.2
pillow                 11.3.0
pip                    23.0.1
platformdirs           4.4.0
polars                 1.33.1
prettytable            3.16.0
psutil                 7.1.0
pycparser              2.23
pycryptodome           3.23.0
Pygments               2.19.2
pyparsing              3.2.5
python-dateutil        2.9.0.post0
pytz                   2023.4
pywin32                311
PyYAML                 6.0.3
regex                  2025.9.18
requests               2.28.2
rich                   13.4.2
safetensors            0.6.2
scipy                  1.15.3
setuptools             60.2.0
six                    1.17.0
tabulate               0.9.0
termcolor              3.1.0
terminaltables         3.1.10
timm                   1.0.20
tomli                  2.2.1
torch                  1.13.1+cu116
torchaudio             0.13.1+cu116
torchvision            0.14.1+cu116
tqdm                   4.65.2
typing_extensions      4.15.0
tzdata                 2025.2
ultralytics            8.3.203
ultralytics-thop       2.0.17
urllib3                1.26.20
wcwidth                0.2.14
yapf                   0.43.0

相关介绍

  • Python 是一种跨平台的计算机程序设计语言,是一种结合了解释性、交互性和面向对象特性的高级脚本语言。它最初被设计用于编写自动化脚本(shell),随着版本的不断更新和新功能的加入,越来越多地被用于独立的大型项目开发。
  • PyTorch 是一个深度学习框架,封装了许多网络结构和深度学习相关的工具,方便我们直接调用,而不必逐一从零实现。它分为 CPU 和 GPU 版本,同类框架还有 TensorFlow、Caffe 等。PyTorch 由 Facebook 人工智能研究院(FAIR)基于 Torch 推出,是一个基于 Python 的科学计算包,提供两大核心功能:1、具有强大 GPU 加速能力的张量计算(类似 NumPy);2、构建深度神经网络时的自动微分机制。
  • SegFormer 是一个简单、高效但功能强大的语义分割框架,它将 Transformers 与轻量级多层感知器 (MLP) 解码器结合在一起。
  • SegFormer 有两个吸引人的特点:
    1. SegFormer 包含一个新颖的分层结构 Transformer 编码器,可输出多尺度特征。它不需要位置编码,从而避免了当测试分辨率与训练分辨率不同时,对位置编码进行插值所导致的性能下降。
    2. SegFormer 避免了复杂的解码器。所提出的 MLP 解码器汇聚了来自不同层的信息,从而将局部注意力和全局注意力结合起来,呈现出强大的表征。
  • 这种简单轻便的设计是在 Transformer 上实现高效分割的关键。通过模型缩放,得到了从 SegFormer-B0 到 SegFormer-B5 的一系列模型,其性能和效率明显优于之前的同类方法。
  • 例如,SegFormer-B4 以 64M 参数在 ADE20K 上取得 50.3% 的 mIoU,模型大小约为之前最佳方法的 1/5,mIoU 高出 2.2%;最佳模型 SegFormer-B5 在 Cityscapes 验证集上达到 84.0% 的 mIoU,并在 Cityscapes-C 上表现出出色的零样本(zero-shot)鲁棒性。
  • 官方源代码: https://github.com/NVlabs/SegFormer.git
  • Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers. 2021

SegFormer核心模块:编码器(MiT)和解码器(All-MLP解码器)。


class Segformer(nn.Module):
    def __init__(
        self,
        *,
        dims=(32, 64, 160, 256),
        heads=(1, 2, 5, 8),
        ff_expansion=(8, 8, 4, 4),
        reduction_ratio=(8, 4, 2, 1),
        num_layers=2,
        channels=3,
        decoder_dim=256,
        num_classes=4
    ):
        super().__init__()
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
        self.mit = MiT(
            channels=channels,
            dims=dims,
            heads=heads,
            ff_expansion=ff_expansion,
            reduction_ratio=reduction_ratio,
            num_layers=num_layers
        )
        self.to_fused = nn.ModuleList([nn.Sequential(
            nn.Conv2d(dim, decoder_dim, 1),
            nn.Upsample(scale_factor=2 ** i)
        ) for i, dim in enumerate(dims)])
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(4 * decoder_dim, decoder_dim, 1),
            nn.Conv2d(decoder_dim, num_classes, 1),
        )

    def forward(self, x):
        H, W = x.shape[-2:]  # 原始输入高宽
        layer_outputs = self.mit(x, return_layer_outputs=True)
        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        fused = torch.cat(fused, dim=1)
        out = self.to_segmentation(fused)
        # 关键修复:上采样到原始输入尺寸
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
        return out

编码器(MiT):

  • 论文中的MiT:
    • 分层设计的Transformer编码器
    • 4个阶段,每个阶段下采样2倍
    • 使用重叠块嵌入(Overlapped Patch Embedding)
class MiT(nn.Module):
    def __init__(
        self,
        *,
        channels,
        dims,
        heads,
        ff_expansion,
        reduction_ratio,
        num_layers
    ):
        super().__init__()
        # 四个阶段的下采样配置,对应论文中的阶段1-4
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
        dims = (channels, *dims)
        dim_pairs = list(zip(dims[:-1], dims[1:]))
        self.stages = nn.ModuleList([])
        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
            get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
            overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
            layers = nn.ModuleList([])
            for _ in range(num_layers):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
                ]))
            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    def forward(
        self,
        x,
        return_layer_outputs=False
    ):
        h, w = x.shape[-2:]
        layer_outputs = []
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            x = get_overlap_patches(x)
            num_patches = x.shape[-1]
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x
            layer_outputs.append(x)
        ret = x if not return_layer_outputs else layer_outputs
        return ret

分层结构,产生多尺度特征(通常有4个阶段,每个阶段特征图尺寸递减)。

class MiT(nn.Module):
    def __init__(self, *, channels, dims, heads, ff_expansion, reduction_ratio, num_layers):
        # 四个阶段的下采样配置,对应论文中的阶段1-4
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
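
按上面的 stage_kernel_stride_pad 配置,四个阶段的步幅依次为 4、2、2、2,因此各阶段输出特征图约为输入分辨率的 1/4、1/8、1/16、1/32。下面是一个简单的推算小片段(非原实现的一部分,仅作演示):

# 各阶段输出分辨率推算(假设输入为 256x256)
size = 256
for i, stride in enumerate((4, 2, 2, 2), start=1):
    size //= stride
    print(f"Stage {i}: {size}x{size}")  # 依次输出 64x64, 32x32, 16x16, 8x8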

每个阶段由多个Transformer块组成,每个块包含:

重叠块嵌入(Overlapped Patch Embedding)
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
高效自注意力(Efficient Self-Attention)
  • 论文创新点:
    • 序列缩减机制,降低计算复杂度
    • 使用reduction_ratio对K,V进行下采样
class EfficientSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads,
        reduction_ratio
    ):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False)  # 关键:序列缩减
        self.to_out = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        h, w = x.shape[-2:]
        heads = self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
        return self.to_out(out)
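
其中 to_kv 用步幅为 reduction_ratio 的卷积同时完成 K、V 的投影和空间下采样,注意力矩阵的规模从 N×N 降为 N×(N/R²),即论文中的序列缩减(Sequence Reduction)。下面是一个形状自查的小片段(假设已导入 torch 且上面的 EfficientSelfAttention 已定义,张量尺寸为演示用的假设值):

# 序列缩减的形状自查(假设输入为 1x32x64x64,reduction_ratio=8,仅作演示)
attn = EfficientSelfAttention(dim=32, heads=1, reduction_ratio=8)
x = torch.randn(1, 32, 64, 64)
print(attn.to_kv(x).shape)  # torch.Size([1, 64, 8, 8]),K/V 的空间尺寸被压缩为 8x8
print(attn(x).shape)        # torch.Size([1, 32, 64, 64]),输出与输入同尺寸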
混合前馈网络(Mix FeedForward Network)
  • 论文创新点:
    • 使用3×3深度可分离卷积增强局部特征提取
    • 替换标准MLP
class MixFeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        expansion_factor
    ):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),  # 升维
            DsConv2d(hidden_dim, hidden_dim, 3, padding=1),  # 深度可分离卷积
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)  # 降维
        )

    def forward(self, x):
        return self.net(x)
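
DsConv2d 是深度可分离卷积(定义见后文"定义网络模型"部分),它把标准 3×3 卷积拆成逐通道的 3×3 深度卷积加 1×1 逐点卷积,参数量大幅下降。下面以 hidden_dim=256 做一个粗略的参数量对比(忽略偏置,仅作演示):

# 标准3x3卷积 与 深度可分离卷积 的参数量对比(忽略偏置)
hidden_dim = 256
standard_conv = hidden_dim * hidden_dim * 3 * 3                      # 589824
depthwise_separable = hidden_dim * 3 * 3 + hidden_dim * hidden_dim   # 2304 + 65536 = 67840
print(standard_conv, depthwise_separable)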

解码器(All-MLP):

  • 论文创新点:
    • 简单的MLP结构,无需复杂设计
    • 多尺度特征融合

将多尺度特征上采样到相同尺寸并拼接。

# 多尺度特征融合
self.to_fused = nn.ModuleList([nn.Sequential(
    nn.Conv2d(dim, decoder_dim, 1),  # 统一通道数
    nn.Upsample(scale_factor=2 ** i)  # 上采样到1/4尺度
) for i, dim in enumerate(dims)])

通过多层感知机(MLP)得到分割结果。

self.to_segmentation = nn.Sequential(
    nn.Conv2d(4 * decoder_dim, decoder_dim, 1),  # 特征融合
    nn.Conv2d(decoder_dim, num_classes, 1),  # 分类头
)
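
以 256×256 输入、decoder_dim=256 为例:四个阶段的特征经 1×1 卷积统一通道,再分别上采样 1、2、4、8 倍后都变为 64×64,拼接得到 4×256=1024 通道,最后由两个 1×1 卷积输出类别通道的粗分辨率结果。下面是一个形状推演的小片段(假设已导入 torch 与 torch.nn,仅作演示):

# 解码器特征融合的形状推演(batch=1,仅作演示)
feats = [torch.randn(1, c, s, s) for c, s in zip((32, 64, 160, 256), (64, 32, 16, 8))]
to_fused = nn.ModuleList([
    nn.Sequential(nn.Conv2d(c, 256, 1), nn.Upsample(scale_factor=2 ** i))
    for i, c in enumerate((32, 64, 160, 256))
])
fused = torch.cat([f(x) for f, x in zip(to_fused, feats)], dim=1)
print(fused.shape)  # torch.Size([1, 1024, 64, 64]),再经两个1x1卷积得到 [1, num_classes, 64, 64]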

具体实现

导入相关库

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import einsum
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
from math import sqrt
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# helpers
def exists(val):
    return val is not None

def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth
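
其中 cast_tuple 把单个标量配置广播成长度为 depth 的元组,使 num_layers=2 这类写法与逐阶段元组写法等价,例如:

# cast_tuple 用法示例
print(cast_tuple(2, depth=4))             # (2, 2, 2, 2)
print(cast_tuple((3, 4, 6, 3), depth=4))  # (3, 4, 6, 3),元组原样返回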

准备数据集

# ============== MockSegmentationDataset ==============
class MockSegmentationDataset(Dataset):
    def __init__(self, size=256, num_samples=1000, num_classes=4):
        self.size = size
        self.num_samples = num_samples
        self.num_classes = num_classes
        # 图像变换
        self.image_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 使用固定模式而不是完全随机,让模型容易学习
        rng = np.random.RandomState(idx)  # 固定随机种子,让数据可重复
        # 生成更结构化的背景
        img = np.full((self.size, self.size, 3), 128, dtype=np.uint8)  # 固定灰色背景
        seg_map = np.zeros((self.size, self.size), dtype=np.uint8)
        # 固定位置和尺寸的形状,减少随机性
        positions = [
            (self.size//4, self.size//4),      # 左上
            (3*self.size//4, self.size//4),    # 右上
            (self.size//4, 3*self.size//4),    # 左下
            (3*self.size//4, 3*self.size//4),  # 右下
        ]
        # 为每个样本固定选择2个形状,确保类别平衡
        shape_indices = [idx % 3 + 1, (idx + 1) % 3 + 1]  # 循环使用类别1,2,3
        for i, cls in enumerate(shape_indices[:2]):  # 只画2个形状
            pos = positions[i]
            if cls == 1:  # 圆形
                cv2.circle(seg_map, pos, 25, int(cls), -1)
                cv2.circle(img, pos, 25, (255, 0, 0), -1)  # 红色
            elif cls == 2:  # 矩形
                pt1 = (pos[0]-25, pos[1]-20)
                pt2 = (pos[0]+25, pos[1]+20)
                cv2.rectangle(seg_map, pt1, pt2, int(cls), -1)
                cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
            elif cls == 3:  # 椭圆
                cv2.ellipse(seg_map, pos, (30, 15), 45, 0, 360, int(cls), -1)
                cv2.ellipse(img, pos, (30, 15), 45, 0, 360, (0, 0, 255), -1)  # 蓝色
        # 应用图像变换
        img = Image.fromarray(img)
        img = self.image_transform(img)
        # 标签直接转换为tensor,不应用与图像相同的变换
        seg_map = torch.from_numpy(seg_map).long()
        return img, seg_map
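
正式训练前,可以先取一个样本检查图像与标签的形状、类型和类别分布是否符合预期。下面是一个简单的自查小片段(仅作演示):

# 数据集自查示例
ds = MockSegmentationDataset(size=256, num_samples=10)
img, mask = ds[0]
print(img.shape, img.dtype)    # torch.Size([3, 256, 256]) torch.float32
print(mask.shape, mask.dtype)  # torch.Size([256, 256]) torch.int64
print(mask.unique())           # 例如 tensor([0, 1, 2]):背景加两个形状类别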

定义网络模型

# ============== SegFormer模型定义 ==============
class DsConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride=1, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        )

    def forward(self, x):
        return self.net(x)

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class EfficientSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads,
        reduction_ratio
    ):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False)  # 关键:序列缩减
        self.to_out = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        h, w = x.shape[-2:]
        heads = self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
        return self.to_out(out)

class MixFeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        expansion_factor
    ):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),  # 升维
            DsConv2d(hidden_dim, hidden_dim, 3, padding=1),  # 深度可分离卷积
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)  # 降维
        )

    def forward(self, x):
        return self.net(x)

class MiT(nn.Module):
    def __init__(
        self,
        *,
        channels,
        dims,
        heads,
        ff_expansion,
        reduction_ratio,
        num_layers
    ):
        super().__init__()
        # 四个阶段的下采样配置,对应论文中的阶段1-4
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
        dims = (channels, *dims)
        dim_pairs = list(zip(dims[:-1], dims[1:]))
        self.stages = nn.ModuleList([])
        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
            get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
            overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
            layers = nn.ModuleList([])
            for _ in range(num_layers):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
                ]))
            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    def forward(
        self,
        x,
        return_layer_outputs=False
    ):
        h, w = x.shape[-2:]
        layer_outputs = []
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            x = get_overlap_patches(x)
            num_patches = x.shape[-1]
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x
            layer_outputs.append(x)
        ret = x if not return_layer_outputs else layer_outputs
        return ret

class Segformer(nn.Module):
    def __init__(
        self,
        *,
        dims=(32, 64, 160, 256),
        heads=(1, 2, 5, 8),
        ff_expansion=(8, 8, 4, 4),
        reduction_ratio=(8, 4, 2, 1),
        num_layers=2,
        channels=3,
        decoder_dim=256,
        num_classes=4
    ):
        super().__init__()
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
        self.mit = MiT(
            channels=channels,
            dims=dims,
            heads=heads,
            ff_expansion=ff_expansion,
            reduction_ratio=reduction_ratio,
            num_layers=num_layers
        )
        # 多尺度特征融合
        self.to_fused = nn.ModuleList([nn.Sequential(
            nn.Conv2d(dim, decoder_dim, 1),  # 统一通道数
            nn.Upsample(scale_factor=2 ** i)  # 上采样到1/4尺度
        ) for i, dim in enumerate(dims)])
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(4 * decoder_dim, decoder_dim, 1),  # 特征融合
            nn.Conv2d(decoder_dim, num_classes, 1),  # 分类头
        )

    def forward(self, x):
        H, W = x.shape[-2:]  # 原始输入高宽
        layer_outputs = self.mit(x, return_layer_outputs=True)
        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        fused = torch.cat(fused, dim=1)
        out = self.to_segmentation(fused)
        # 关键修复:上采样到原始输入尺寸
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
        return out
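
模型定义完成后,建议先用随机张量做一次前向形状自查,确认输出为 [B, num_classes, H, W]。下面是一个最小示例(仅作演示):

# 前向形状自查示例
model = Segformer(num_classes=4)
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([2, 4, 256, 256])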

训练验证

# ============== 训练函数 ==============
def get_segformer(model_name='b0', num_classes=4, decoder_dim=256):
    config = {
        'b0': dict(dims=(32, 64, 160, 256), num_layers=(2, 2, 2, 2)),
        'b1': dict(dims=(64, 128, 320, 512), num_layers=(2, 2, 2, 2)),
        'b2': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 6, 3)),
        'b3': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 18, 3)),
        'b4': dict(dims=(64, 128, 320, 512), num_layers=(3, 8, 27, 3)),
        'b5': dict(dims=(64, 128, 320, 512), num_layers=(3, 6, 40, 3)),
    }
    if model_name not in config:
        raise ValueError(f"Unsupported model: {model_name}")
    cfg = config[model_name]
    ff_expansion = (4, 4, 4, 4) if model_name == 'b5' else (8, 8, 4, 4)
    return Segformer(
        dims=cfg['dims'],
        heads=(1, 2, 5, 8),
        ff_expansion=ff_expansion,
        reduction_ratio=(8, 4, 2, 1),
        num_layers=cfg['num_layers'],
        decoder_dim=decoder_dim,
        num_classes=num_classes
    )

def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-4, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # 确保输出和标签维度匹配
            # outputs: [batch, num_classes, H, W]
            # labels: [batch, H, W]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", leave=False):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)
        # 每2个epoch可视化一次训练样本的预测
        if epoch % 2 == 0 or epoch == num_epochs - 1:
            model.eval()
            with torch.no_grad():
                # 取一个训练样本
                sample_img, sample_label = next(iter(train_loader))
                sample_img, sample_label = sample_img[:1].to(device), sample_label[:1].to(device)
                output = model(sample_img)
                pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
            # 可视化
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.imshow(sample_img[0].cpu().permute(1, 2, 0))
            plt.title('Input')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(sample_label[0].cpu(), cmap='jet', vmin=0, vmax=3)
            plt.title('Ground Truth')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(pred, cmap='jet', vmin=0, vmax=3)
            plt.title(f'Prediction Epoch {epoch}')
            plt.axis('off')
            plt.savefig(f'train_debug_epoch_{epoch}.png', dpi=100, bbox_inches='tight')
            plt.close()
            print(f"Debug visualization saved to train_debug_epoch_{epoch}.png")
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    return model, train_losses, val_losses
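
train_model 返回的 train_losses 和 val_losses 可以直接用于绘制损失曲线,便于观察收敛情况。下面是一个可选的小工具函数(假设 matplotlib 已按前文导入,仅作演示):

# 损失曲线可视化(仅作演示)
def plot_losses(train_losses, val_losses, save_path='loss_curve.png'):
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('CrossEntropy Loss')
    plt.legend()
    plt.savefig(save_path, dpi=100, bbox_inches='tight')
    plt.close()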

推理预测

# ============== 推理函数 ==============
def load_model(model_path, model_name='b0', num_classes=4, device='cuda'):
    """加载训练好的模型"""
    model = get_segformer(model_name=model_name, num_classes=num_classes)
    # 加载模型权重
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print(f"模型加载成功,参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    return model

def predict(model, image_path, device='cuda'):
    model = model.to(device)
    model.eval()
    # 加载并预处理图像
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    # Move to device
    image = image.to(device)
    # Predict
    with torch.no_grad():
        output = model(image)
    # Get prediction (argmax along channels)
    pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    return image.squeeze(0).cpu().numpy(), pred

# ============== 可视化函数 ==============
def visualize_results(original, prediction, save_path=None):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(original, (1, 2, 0)))
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(prediction, cmap='jet', vmin=0, vmax=3)
    plt.title('Segmentation Prediction')
    plt.axis('off')
    if save_path:
        plt.savefig(save_path, dpi=100, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")
    else:
        plt.show()

def generate_sample_image_and_label(save_img_path="sample_image.png", save_label_path=None, size=256):
    """
    生成一张带几何形状的模拟图像和对应的标签图(可选保存)。
    - 背景: 类别 0
    - 红色圆: 类别 1
    - 绿色矩形: 类别 2
    - 蓝色椭圆: 类别 3
    """
    # 创建灰色背景图像
    img = np.full((size, size, 3), 128, dtype=np.uint8)
    label = np.zeros((size, size), dtype=np.uint8)
    # 1. 红色圆(类别 1)
    center1 = (80, 80)
    radius1 = 25
    cv2.circle(img, center1, radius1, (255, 0, 0), -1)  # 红色
    cv2.circle(label, center1, radius1, 1, -1)
    # 2. 绿色矩形(类别 2)
    pt1 = (150, 60)
    pt2 = (200, 110)
    cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
    cv2.rectangle(label, pt1, pt2, 2, -1)
    # 3. 蓝色椭圆(类别 3)
    center2 = (120, 180)
    axes = (30, 15)
    cv2.ellipse(img, center2, axes, 45, 0, 360, (0, 0, 255), -1)  # 蓝色
    cv2.ellipse(label, center2, axes, 45, 0, 360, 3, -1)
    # 保存图像
    cv2.imwrite(save_img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    print(f"Sample image saved to {save_img_path}")
    if save_label_path:
        # 保存标签为可视化灰度图(0~255 映射)
        label_vis = (label * 60).astype(np.uint8)  # 0,60,120,180 便于肉眼区分
        cv2.imwrite(save_label_path, label_vis)
        print(f"Label visualization saved to {save_label_path}")
    return img, label
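
除了可视化之外,语义分割通常还用像素 IoU/mIoU 来定量评估预测结果。下面是一个基于 numpy 的简化 mIoU 计算示意(非原文实现,仅作演示):

# 简化版 mIoU 计算(仅作演示)
def compute_miou(pred, label, num_classes=4):
    """pred/label: [H, W] 的类别索引数组;忽略未出现的类别,返回各类 IoU 的平均值"""
    ious = []
    for cls in range(num_classes):
        inter = np.logical_and(pred == cls, label == cls).sum()
        union = np.logical_or(pred == cls, label == cls).sum()
        if union > 0:
            ious.append(inter / union)
    return float(np.mean(ious))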

主函数

# ============== 主程序 ==============
if __name__ == "__main__":
    # 设置随机种子
    torch.manual_seed(42)
    np.random.seed(42)
    # 创建模拟数据集
    dataset = MockSegmentationDataset()
    # 划分训练集和验证集
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # 创建DataLoader
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)
    # # 初始化模型b0
    # model = Segformer(
    #     dims=(32, 64, 160, 256),      # 各阶段通道数 [C1, C2, C3, C4]
    #     heads=(1, 2, 5, 8),           # 各阶段注意力头数
    #     ff_expansion=(8, 8, 4, 4),    # FFN扩展因子
    #     reduction_ratio=(8, 4, 2, 1), # 序列缩减比例
    #     num_layers=2,                 # 各阶段层数
    #     decoder_dim=256,              # 解码器统一维度
    #     num_classes=4                 # 分割类别数
    # )
    model_name = 'b0'  # 可选 'b0', 'b1', 'b2', 'b3', 'b4', 'b5'
    model = get_segformer(model_name, num_classes=4)
    os.makedirs(model_name, exist_ok=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    # 训练模型
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model, train_losses, val_losses = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=5,
        # num_epochs=10,  # 增加到10个epoch
        learning_rate=1e-4,
        device=device
    )
    # 保存模型
    torch.save(model.state_dict(), f'{model_name}/segformer_model.pth')
    print(f"Model saved to '{model_name}/segformer_model.pth'")
    # 测试推理
    print("\nTesting inference on a sample image...")
    # 生成一个结构清晰的模拟图像用于推理
    sample_img, sample_label = generate_sample_image_and_label(
        save_img_path="sample_image.png",
        save_label_path="sample_label.png",  # 可选:保存标签用于对比
        size=256
    )
    sample_img_path = "sample_image.png"
    # 加载模型
    model = load_model(f'{model_name}/segformer_model.pth', model_name=model_name, num_classes=4, device=device)
    # 进行预测
    original, prediction = predict(model, sample_img_path, device=device)
    # 可视化结果
    visualize_results(original, prediction, save_path=f"{model_name}/segmentation_result.png")
    print(f"Inference completed. Result saved to '{model_name}/segmentation_result.png'")

输出结果

Model parameters: 7718244
Using device: cuda
Debug visualization saved to train_debug_epoch_0.png
Epoch 1/5, Train Loss: 0.1226, Val Loss: 0.0077
Epoch 2/5, Train Loss: 0.0052, Val Loss: 0.0037
Debug visualization saved to train_debug_epoch_2.png
Epoch 3/5, Train Loss: 0.0031, Val Loss: 0.0026
Epoch 4/5, Train Loss: 0.0022, Val Loss: 0.0019
Debug visualization saved to train_debug_epoch_4.png
Epoch 5/5, Train Loss: 0.0017, Val Loss: 0.0015
Model saved to 'b0/segformer_model.pth'
Testing inference on a sample image...
Sample image saved to sample_image.png
Label visualization saved to sample_label.png
模型加载成功,参数数量: 7718244
Visualization saved to b0/segmentation_result.png
Inference completed. Result saved to 'b0/segmentation_result.png'


完整代码

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import einsum
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
from math import sqrt
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# helpers
def exists(val):
    return val is not None

def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth
# ============== MockSegmentationDataset ==============
class MockSegmentationDataset(Dataset):
    def __init__(self, size=256, num_samples=1000, num_classes=4):
        self.size = size
        self.num_samples = num_samples
        self.num_classes = num_classes
        # 图像变换
        self.image_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 使用固定模式而不是完全随机,让模型容易学习
        rng = np.random.RandomState(idx)  # 固定随机种子,让数据可重复
        # 生成更结构化的背景
        img = np.full((self.size, self.size, 3), 128, dtype=np.uint8)  # 固定灰色背景
        seg_map = np.zeros((self.size, self.size), dtype=np.uint8)
        # 固定位置和尺寸的形状,减少随机性
        positions = [
            (self.size//4, self.size//4),      # 左上
            (3*self.size//4, self.size//4),    # 右上
            (self.size//4, 3*self.size//4),    # 左下
            (3*self.size//4, 3*self.size//4),  # 右下
        ]
        # 为每个样本固定选择2个形状,确保类别平衡
        shape_indices = [idx % 3 + 1, (idx + 1) % 3 + 1]  # 循环使用类别1,2,3
        for i, cls in enumerate(shape_indices[:2]):  # 只画2个形状
            pos = positions[i]
            if cls == 1:  # 圆形
                cv2.circle(seg_map, pos, 25, int(cls), -1)
                cv2.circle(img, pos, 25, (255, 0, 0), -1)  # 红色
            elif cls == 2:  # 矩形
                pt1 = (pos[0]-25, pos[1]-20)
                pt2 = (pos[0]+25, pos[1]+20)
                cv2.rectangle(seg_map, pt1, pt2, int(cls), -1)
                cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
            elif cls == 3:  # 椭圆
                cv2.ellipse(seg_map, pos, (30, 15), 45, 0, 360, int(cls), -1)
                cv2.ellipse(img, pos, (30, 15), 45, 0, 360, (0, 0, 255), -1)  # 蓝色
        # 应用图像变换
        img = Image.fromarray(img)
        img = self.image_transform(img)
        # 标签直接转换为tensor,不应用与图像相同的变换
        seg_map = torch.from_numpy(seg_map).long()
        return img, seg_map
# ============== SegFormer模型定义 ==============
class DsConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride=1, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        )

    def forward(self, x):
        return self.net(x)

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class EfficientSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads,
        reduction_ratio
    ):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads
        self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False)  # 关键:序列缩减
        self.to_out = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        h, w = x.shape[-2:]
        heads = self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
        return self.to_out(out)

class MixFeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        expansion_factor
    ):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),  # 升维
            DsConv2d(hidden_dim, hidden_dim, 3, padding=1),  # 深度可分离卷积
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)  # 降维
        )

    def forward(self, x):
        return self.net(x)

class MiT(nn.Module):
    def __init__(
        self,
        *,
        channels,
        dims,
        heads,
        ff_expansion,
        reduction_ratio,
        num_layers
    ):
        super().__init__()
        # 四个阶段的下采样配置,对应论文中的阶段1-4
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
        dims = (channels, *dims)
        dim_pairs = list(zip(dims[:-1], dims[1:]))
        self.stages = nn.ModuleList([])
        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
            get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
            overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
            layers = nn.ModuleList([])
            for _ in range(num_layers):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
                ]))
            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    def forward(
        self,
        x,
        return_layer_outputs=False
    ):
        h, w = x.shape[-2:]
        layer_outputs = []
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            x = get_overlap_patches(x)
            num_patches = x.shape[-1]
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x
            layer_outputs.append(x)
        ret = x if not return_layer_outputs else layer_outputs
        return ret

class Segformer(nn.Module):
    def __init__(
        self,
        *,
        dims=(32, 64, 160, 256),
        heads=(1, 2, 5, 8),
        ff_expansion=(8, 8, 4, 4),
        reduction_ratio=(8, 4, 2, 1),
        num_layers=2,
        channels=3,
        decoder_dim=256,
        num_classes=4
    ):
        super().__init__()
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
        self.mit = MiT(
            channels=channels,
            dims=dims,
            heads=heads,
            ff_expansion=ff_expansion,
            reduction_ratio=reduction_ratio,
            num_layers=num_layers
        )
        # 多尺度特征融合
        self.to_fused = nn.ModuleList([nn.Sequential(
            nn.Conv2d(dim, decoder_dim, 1),  # 统一通道数
            nn.Upsample(scale_factor=2 ** i)  # 上采样到1/4尺度
        ) for i, dim in enumerate(dims)])
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(4 * decoder_dim, decoder_dim, 1),  # 特征融合
            nn.Conv2d(decoder_dim, num_classes, 1),  # 分类头
        )

    def forward(self, x):
        H, W = x.shape[-2:]  # 原始输入高宽
        layer_outputs = self.mit(x, return_layer_outputs=True)
        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        fused = torch.cat(fused, dim=1)
        out = self.to_segmentation(fused)
        # 关键修复:上采样到原始输入尺寸
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
        return out
# ============== 训练函数 ==============
def get_segformer(model_name='b0', num_classes=4, decoder_dim=256):
    config = {
        'b0': dict(dims=(32, 64, 160, 256), num_layers=(2, 2, 2, 2)),
        'b1': dict(dims=(64, 128, 320, 512), num_layers=(2, 2, 2, 2)),
        'b2': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 6, 3)),
        'b3': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 18, 3)),
        'b4': dict(dims=(64, 128, 320, 512), num_layers=(3, 8, 27, 3)),
        'b5': dict(dims=(64, 128, 320, 512), num_layers=(3, 6, 40, 3)),
    }
    if model_name not in config:
        raise ValueError(f"Unsupported model: {model_name}")
    cfg = config[model_name]
    ff_expansion = (4, 4, 4, 4) if model_name == 'b5' else (8, 8, 4, 4)
    return Segformer(
        dims=cfg['dims'],
        heads=(1, 2, 5, 8),
        ff_expansion=ff_expansion,
        reduction_ratio=(8, 4, 2, 1),
        num_layers=cfg['num_layers'],
        decoder_dim=decoder_dim,
        num_classes=num_classes
    )

def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-4, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # 确保输出和标签维度匹配
            # outputs: [batch, num_classes, H, W]
            # labels: [batch, H, W]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", leave=False):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)
        # 每2个epoch可视化一次训练样本的预测
        if epoch % 2 == 0 or epoch == num_epochs - 1:
            model.eval()
            with torch.no_grad():
                # 取一个训练样本
                sample_img, sample_label = next(iter(train_loader))
                sample_img, sample_label = sample_img[:1].to(device), sample_label[:1].to(device)
                output = model(sample_img)
                pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
            # 可视化
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.imshow(sample_img[0].cpu().permute(1, 2, 0))
            plt.title('Input')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(sample_label[0].cpu(), cmap='jet', vmin=0, vmax=3)
            plt.title('Ground Truth')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(pred, cmap='jet', vmin=0, vmax=3)
            plt.title(f'Prediction Epoch {epoch}')
            plt.axis('off')
            plt.savefig(f'train_debug_epoch_{epoch}.png', dpi=100, bbox_inches='tight')
            plt.close()
            print(f"Debug visualization saved to train_debug_epoch_{epoch}.png")
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    return model, train_losses, val_losses
# ============== 推理函数 ==============
def load_model(model_path, model_name='b0', num_classes=4, device='cuda'):
    """加载训练好的模型"""
    model = get_segformer(model_name=model_name, num_classes=num_classes)
    # 加载模型权重
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print(f"模型加载成功,参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    return model

def predict(model, image_path, device='cuda'):
    model = model.to(device)
    model.eval()
    # 加载并预处理图像
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    # Move to device
    image = image.to(device)
    # Predict
    with torch.no_grad():
        output = model(image)
    # Get prediction (argmax along channels)
    pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    return image.squeeze(0).cpu().numpy(), pred

# ============== 可视化函数 ==============
def visualize_results(original, prediction, save_path=None):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(original, (1, 2, 0)))
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(prediction, cmap='jet', vmin=0, vmax=3)
    plt.title('Segmentation Prediction')
    plt.axis('off')
    if save_path:
        plt.savefig(save_path, dpi=100, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")
    else:
        plt.show()

def generate_sample_image_and_label(save_img_path="sample_image.png", save_label_path=None, size=256):
    """
    生成一张带几何形状的模拟图像和对应的标签图(可选保存)。
    - 背景: 类别 0
    - 红色圆: 类别 1
    - 绿色矩形: 类别 2
    - 蓝色椭圆: 类别 3
    """
    # 创建灰色背景图像
    img = np.full((size, size, 3), 128, dtype=np.uint8)
    label = np.zeros((size, size), dtype=np.uint8)
    # 1. 红色圆(类别 1)
    center1 = (80, 80)
    radius1 = 25
    cv2.circle(img, center1, radius1, (255, 0, 0), -1)  # 红色
    cv2.circle(label, center1, radius1, 1, -1)
    # 2. 绿色矩形(类别 2)
    pt1 = (150, 60)
    pt2 = (200, 110)
    cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
    cv2.rectangle(label, pt1, pt2, 2, -1)
    # 3. 蓝色椭圆(类别 3)
    center2 = (120, 180)
    axes = (30, 15)
    cv2.ellipse(img, center2, axes, 45, 0, 360, (0, 0, 255), -1)  # 蓝色
    cv2.ellipse(label, center2, axes, 45, 0, 360, 3, -1)
    # 保存图像
    cv2.imwrite(save_img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    print(f"Sample image saved to {save_img_path}")
    if save_label_path:
        # 保存标签为可视化灰度图(0~255 映射)
        label_vis = (label * 60).astype(np.uint8)  # 0,60,120,180 便于肉眼区分
        cv2.imwrite(save_label_path, label_vis)
        print(f"Label visualization saved to {save_label_path}")
    return img, label
# ============== 主程序 ==============
if __name__ == "__main__":
    # 设置随机种子
    torch.manual_seed(42)
    np.random.seed(42)
    # 创建模拟数据集
    dataset = MockSegmentationDataset()
    # 划分训练集和验证集
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # 创建DataLoader
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)
    # # 初始化模型b0
    # model = Segformer(
    #     dims=(32, 64, 160, 256),      # 各阶段通道数 [C1, C2, C3, C4]
    #     heads=(1, 2, 5, 8),           # 各阶段注意力头数
    #     ff_expansion=(8, 8, 4, 4),    # FFN扩展因子
    #     reduction_ratio=(8, 4, 2, 1), # 序列缩减比例
    #     num_layers=2,                 # 各阶段层数
    #     decoder_dim=256,              # 解码器统一维度
    #     num_classes=4                 # 分割类别数
    # )
    model_name = 'b0'  # 可选 'b0', 'b1', 'b2', 'b3', 'b4', 'b5'
    model = get_segformer(model_name, num_classes=4)
    os.makedirs(model_name, exist_ok=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    # 训练模型
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model, train_losses, val_losses = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=5,
        # num_epochs=10,  # 增加到10个epoch
        learning_rate=1e-4,
        device=device
    )
    # 保存模型
    torch.save(model.state_dict(), f'{model_name}/segformer_model.pth')
    print(f"Model saved to '{model_name}/segformer_model.pth'")
    # 测试推理
    print("\nTesting inference on a sample image...")
    # 生成一个结构清晰的模拟图像用于推理
    sample_img, sample_label = generate_sample_image_and_label(
        save_img_path="sample_image.png",
        save_label_path="sample_label.png",  # 可选:保存标签用于对比
        size=256
    )
    sample_img_path = "sample_image.png"
    # 加载模型
    model = load_model(f'{model_name}/segformer_model.pth', model_name=model_name, num_classes=4, device=device)
    # 进行预测
    original, prediction = predict(model, sample_img_path, device=device)
    # 可视化结果
    visualize_results(original, prediction, save_path=f"{model_name}/segmentation_result.png")
    print(f"Inference completed. Result saved to '{model_name}/segmentation_result.png'")

参考

[1] Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers. 2021
[2] https://github.com/NVlabs/SegFormer.git
[3] https://github.com/bubbliiiing/segformer-pytorch.git
