好的!我来把这段代码整理成博客园风格的笔记,一段代码一段讲解:
FCN-ResNet18 语义分割完整实现详解
1. 导入必要的库
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
代码说明:
torch:PyTorch深度学习框架torchvision:提供预训练模型和数据集nn:神经网络模块F:函数式接口d2l:《动手学深度学习》工具库
2. VOC数据集类别和颜色定义
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle','bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog','horse', 'motorbike', 'person', 'potted plant', 'sheep','sofa', 'train', 'tv/monitor'
]VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]
]
代码说明:
VOC_CLASSES:21个语义类别名称VOC_COLORMAP:每个类别对应的RGB颜色值- 背景为黑色
[0,0,0],飞机为红色[128,0,0]等 - 这是PASCAL VOC数据集的官方定义
3. 创建颜色到标签的映射字典
colormap2label = torch.zeros(256 ** 3, dtype=torch.long)
for i, colormap in enumerate(VOC_COLORMAP):# 将RGB颜色转换为唯一索引:R*256^2 + G*256 + Bcolor_index = (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]colormap2label[color_index] = i
代码说明:
- 创建大小为256³的查找表(覆盖所有RGB颜色)
- 将每个VOC颜色映射到对应的类别ID
- 例如:黑色
[0,0,0]→ 索引0→ 类别0(背景)
4. 高效的标签转换函数
def voc_label_indices_fast(colormap, colormap2label):"""使用查找表快速将RGB标签图转换为类别ID"""# colormap: (H, W, 3) RGB图像# 将RGB图像转换为索引:R*256^2 + G*256 + Bindices = (colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]# 使用查找表直接映射到类别IDreturn colormap2label[indices]def preprocess_mask(mask, colormap2label):"""预处理掩码:RGB → 类别ID"""if mask.dim() == 3 and mask.shape[-1] == 3: # 如果是RGB图像mask = voc_label_indices_fast(mask, colormap2label)return mask
代码说明:
voc_label_indices_fast:批量处理整个图像,比逐像素循环快很多- 利用向量化操作一次性计算所有像素的索引
preprocess_mask:封装函数,自动判断输入格式
5. 双线性插值卷积核
def bilinear_kernel(in_channels, out_channels, kernel_size):factor = (kernel_size + 1) // 2center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5og = torch.arange(kernel_size).reshape(-1, 1), torch.arange(kernel_size).reshape(1, -1)filt = (1 - torch.abs(og[0] - center) / factor) * (1 - torch.abs(og[1] - center) / factor)weight = torch.zeros((in_channels, out_channels, kernel_size, kernel_size))weight[range(in_channels), range(out_channels), :, :] = filtreturn weight
代码说明:
- 创建双线性插值权重核
- 中心权重最大,向边缘逐渐减小
- 用于初始化转置卷积,实现平滑的上采样
6. 构建FCN-ResNet18网络
# 1) 加载预训练 ResNet18,做 encoder
pretrained_net = torchvision.models.resnet18(pretrained=True)
net = nn.Sequential(*list(pretrained_net.children())[:-2])# 2) segmentation head
num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv',nn.ConvTranspose2d(num_classes, num_classes,kernel_size=64, padding=16, stride=32))# 3) 初始化反卷积为双线性插值
net.transpose_conv.weight.data.copy_(bilinear_kernel(num_classes, num_classes, 64))
代码说明:
- 编码器:使用ResNet18(去掉最后两层)
- 1×1卷积:将512特征通道转换为21个类别通道
- 转置卷积:32倍上采样,恢复原始分辨率
- 双线性初始化:避免棋盘伪影,加速收敛
7. 演示颜色映射过程
def demonstrate_colormap2label():"""演示colormap2label的使用方法"""print("=== 演示colormap2label映射 ===")test_colors = [[0, 0, 0], # 背景 - 黑色[128, 0, 0], # 飞机 - 红色[0, 128, 0], # 自行车 - 绿色[192, 128, 128], # 人 - 灰色[255, 255, 255] # 不在VOC中的颜色]for color in test_colors:color_index = (color[0] * 256 + color[1]) * 256 + color[2]class_id = colormap2label[color_index].item()if class_id == 0 and color != [0, 0, 0]:class_name = "未知类别"else:class_name = VOC_CLASSES[class_id]print(f"RGB{color} -> 索引{color_index} -> 类别{class_id}: {class_name}")# 运行演示
demonstrate_colormap2label()
输出示例:
=== 演示colormap2label映射 ===
RGB[0, 0, 0] -> 索引0 -> 类别0: background
RGB[128, 0, 0] -> 索引8388608 -> 类别1: aeroplane
RGB[0, 128, 0] -> 索引32768 -> 类别2: bicycle
RGB[192, 128, 128] -> 索引12632256 -> 类别15: person
RGB[255, 255, 255] -> 索引16777215 -> 类别0: 未知类别
8. 加载VOC数据集
batch_size = 32
crop_size = (320, 480)
print("\n加载VOC数据集...")
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)# 检查一个批次
for X, Y in train_iter:print(f"输入图像形状: {X.shape}") # (batch, 3, H, W)print(f"标签形状: {Y.shape}") # (batch, H, W) - 已经是类别IDbreak
代码说明:
d2l.load_data_voc自动处理数据加载和预处理- 输入图像:
(32, 3, 320, 480)- 批量32,3通道,320×480分辨率 - 标签:
(32, 320, 480)- 每个像素是0-20的类别ID
9. 定义损失函数
def loss(inputs, targets):return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)
代码说明:
inputs:(N, 21, H, W)- 21个类别的概率图targets:(N, H, W)- 每个像素的真实类别ID- 先计算每个像素的交叉熵,然后在空间维度取平均
10. 模型训练
num_epochs = 5
lr = 0.001
wd = 1e-3
devices = d2l.try_all_gpus()print(f"\n开始训练,使用设备: {devices}")
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)print("训练完成!")
代码说明:
- 训练5个epoch,学习率0.001
- 使用SGD优化器,权重衰减1e-3
- 自动检测并使用所有可用的GPU
d2l.train_ch13封装了标准的训练流程
关键技术点总结
- 全卷积网络:去除全连接层,支持任意尺寸输入
- 编码器-解码器结构:ResNet18编码 + 转置卷积解码
- 双线性初始化:转置卷积权重初始化为双线性插值
- 逐像素分类:每个像素独立进行21分类
- 颜色映射:RGB标签图 → 类别ID图
这个实现展示了现代语义分割网络的核心思想,结合了迁移学习和端到端训练的优势。