基于UNet算法的农业遥感图像语义分割——补充版

前言

本案例希望建立一个UNET网络模型,来实现对农业遥感图像语义分割的任务。本篇博客主要包括对上一篇博客中的相关遗留问题进行解决,并对网络结构进行优化调整以适应个人的硬件设施——NVIDIA GeForce RTX 3050。

本案例的前两篇博客直达链接基于UNet算法的农业遥感图像语义分割(下)和基于UNet算法的农业遥感图像语义分割(上)

1.模型简化

1.1 二分类语义分割效果解答

上一篇博客最终的预测结果为二分类的语义分割,即经过彩色映射后,结果只有黑和蓝两种颜色。原因是因为模型虽然参数更新了1400多次,但其实从遍历数据集的角度考虑也就65个epoch.在这里插入图片描述
同时网络模型参数量约7.7M,模型并未充分学习到训练集上的信息。之所以会出现二分类的预测结果,是与模型初始化权重有关。

1.2网络模型调整

因此针对上述情况,我将模型改成了单层的编码器-解码器架构,同时将Block模块中做进一步特征融合的卷积层移除,具体结构如下所示:

class Block(nn.Module):def __init__(self, in_channels, out_channels):super(Block, self).__init__()self.relu = nn.ReLU(inplace=False)self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)x = self.relu(x)# x = self.conv2(x)# x = self.relu(x)return xclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.relu = nn.ReLU(inplace=False)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 编码器部分self.conv1 = Block(3, 32)# 解码器部分self.up2 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)self.conv2 = Block(32, 32)self.conv3 = nn.Conv2d(32, 4, kernel_size=1)def forward(self, x):# 编码器conv1 = self.conv1(x)  # 32, 512, 512pool1 = self.pool(conv1)  # 32, 256, 256# 解码器up2 = self.up2(pool1)  # 32, 512, 512conv2 = self.conv2(up2)  # 32, 512, 512conv3 = self.conv3(conv2)  # 4, 512, 512return conv3

此时查看模型的信息如下所示:
在这里插入图片描述
模型的参数量已经减少至14.4k,可以预见结果并不会很好。因为输入的图像尺寸就已经512×512×3,相比而言,该模型显然不能充分拟合该任务。

2.训练策略调整

2.1训练损失波动解答

因为统计的损失是按照每个iter进行统计的,每次的迭代过程在该批次下的参数更新朝着当前批次损失变小的方向进行,但对其他批次可能损失会升高,因此损失波动剧烈,但整体呈下降趋势。
这里的解决方案如下:

  1. 将参数更新过程中记录的iter次数进行减少,如将iter%10==0调整成iter%200==0
  2. 将参数更新过程中的记录的结果转换成累积量,即将10个iter中损失进行累加或者将一个epoch中的所有损失进行累加(本案例后续改进采用该方式)。
  3. 将参数更新过程中的记录的结果转换成平均量,即将10个iter中损失进行平均或者将一个epoch中的所有损失进行平均。

2.2训练过程调整

因为本案例的数据集本身就很小,所以这里采用的是将一个epoch中的所有损失进行累加统计进行输出可视化。同时为了避免模型参数保存冗余问题,将模型保存策略进行调整,只保存在验证集上损失最小的模型,同时使用覆盖原则将之前的保存模型进行覆盖,以节省空间开销,具体代码调整如下:

    # 创建一个 SummaryWriter 对象,用于将数据写入 TensorBoardwriter = SummaryWriter("dataset/logs")epoch = 0best_val_loss = float('inf')# best_val_loss = 7.899# model.load_state_dict(torch.load('./models/secweights_40.pth'))while epoch < 500:epoch += 1print("---------第{}轮训练开始---------".format(epoch))train_loss = 0for i, (img, label) in tqdm(enumerate(dataloader_train)):img = img.to(device).float()label = label.long().to(device)model.train()output = model(img)# output = torch.argmax(output, dim=1).double()# iter_num += 1loss = getLoss(output, label)train_loss += loss.item()loss.backward()optimizer.step()optimizer.zero_grad()# print("---------第{}轮训练结束---------".format(epoch))print("第{}轮训练的损失为:{}".format(epoch, train_loss))writer.add_scalar('Training Loss3', train_loss, epoch)if epoch % 10 == 0:# torch.save(model.state_dict(), './models/thirdweights_{}.pth'.format(epoch))val_loss = 0with torch.no_grad():model.eval()for i, (img, label) in tqdm(enumerate(dataloader_val)):img = img.to(device).float()label = label.long().to(device)output = model(img)loss = getLoss(output, label)val_loss += loss.item()print("第{}轮验证的损失为:{}".format(epoch, val_loss))if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), './models/best_model2.pth')print("Saved new best model")writer.add_scalar('Validation Loss3', val_loss, epoch)writer.close()

3.结果分析

3.1训练过程损失

在训练过程中的损失记录如下:

在这里插入图片描述
通过结果可以看出上述修改方式确实取得了不错的效果,模型训练集的抖动已大幅度减小。
从曲线角度考虑,训练集损失已经趋向于平稳,同时验证集上损失也趋向于平稳,由此判断模型已经基本收敛,但训练集的损失仍停留在较高水平,大概率是因为模型过于简单,难以拟合该任务的需求。

3.2模型预测结果

这里将模型最终保存的结果加载进来,对未知图片进行预测,代码如下:

import matplotlib.pyplot as plt
import torch
import cv2
import numpy as np
from torch.utils.tensorboard import SummaryWriterfrom Net2 import Net# I=cv2.imread('dataset/0.9/image/16213.png')#dataset/test.png
I=cv2.imread('dataset/test.png')
I=np.transpose(I, (2, 0, 1))
I=I/255.0
I=I.reshape(1,3,512,512)
I=torch.tensor(I)
model=Net().double()
model.load_state_dict(torch.load('models/best_model2.pth'))
output=model(I)
# print(output.shape)
# print(output[0,:,:5,:5])
predicted_classes = torch.argmax(output, dim=1).squeeze(0).numpy()color_map = {0: [0, 0, 0],  # 黑色1: [255, 0, 0],  # 红色2: [0, 255, 0],  # 绿色3: [0, 0, 255]  # 蓝色
}height, width = predicted_classes.shape
colored_image = np.zeros((height, width, 3), dtype=np.uint8)
for i in range(height):for j in range(width):class_id = predicted_classes[i, j]colored_image[i, j] = color_map[class_id]plt.imshow(colored_image)
plt.axis('off')
plt.show()
print(colored_image.shape)
colored_image=np.transpose(colored_image, (2, 0, 1))
writer=SummaryWriter('dataset/logs')writer.add_image('test3',colored_image)
writer.close()

预测结果如下:
在这里插入图片描述
从结果角度考虑,确实实现了四分类的语义分割效果,但预测的效果并不是很好,因此需要进一步修改网络结构。

4.网络模型优化

具体修改主要包括引入批量规范化BatchNormalization的处理和增加了Dropout的机制以及对网络结构调整为三层的编码器-解码器架构。

4.1 BatchNormalization

批量规范化的核心思想是对每一层的输入进行归一化处理,使得每一层的输入分布在训练过程中保持相对稳定。具体来说,它将输入数据的每个特征维度都归一化到均值为 0、方差为 1 的标准正态分布。这样可以减少内部协变量偏移的影响,加快训练速度。

这里还有其他的逐层归一化方式,这里不做详细介绍。因为BatchNormalization聚焦于小批量层面,更适用于该任务,或者说更适用视觉图像处理方面

在这里插入图片描述
图片来源:本校《深度学习》课程的PPT

4.2 Dropout的机制

Dropout的机制能有效防止过拟合,在训练神经网络时,它通过以一定的概率随机将神经元的输出设置为0,即暂时“丢弃”这些神经元及其连接,每次迭代训练时在训练一个不同的子网络,通过多个子网络的综合效果来提高模型的泛化能力。类似于基学习器集成学习的思想。

4.3网络模型代码

上述的两种方式是针对Block模块的,这里为了更好的拟合语义分割的任务,需要进一步加深网络结构,考虑到硬件资源有限,于是使用的是三层编码器-解码器架构,修改后的网络模型完整代码如下:

class Block(nn.Module):def __init__(self, in_channels, out_channels, dropout_rate=0.1):super(Block, self).__init__()self.relu = nn.ReLU(inplace=False)self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.dropout1 = nn.Dropout2d(p=dropout_rate)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.dropout2 = nn.Dropout2d(p=dropout_rate)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.dropout1(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.dropout2(x)return xclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.relu = nn.ReLU(inplace=False)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 编码器部分self.conv1 = Block(3, 32)self.conv2 = Block(32, 64)self.conv3 = Block(64, 128)# 解码器部分self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.conv4 = Block(128, 64)self.up5 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)self.conv5 = Block(64, 32)self.conv6 = nn.Conv2d(32, 4, kernel_size=1)def forward(self, x):# 编码器conv1 = self.conv1(x)  # 32, 512, 512pool1 = self.pool(conv1)  # 32, 256, 256conv2 = self.conv2(pool1)  # 64, 256, 256pool2 = self.pool(conv2)  # 64, 128, 128conv3 = self.conv3(pool2)  # 128, 128, 128# 解码器up4 = self.up4(conv3)  # 64, 256, 256conv4 = torch.cat([up4, conv2], dim=1)  # 128, 256, 256conv4 = self.conv4(conv4)  # 64, 256, 256up5 = self.up5(conv4)  # 32, 512, 512conv5 = torch.cat([up5, conv1], dim=1)  # 64, 512, 512conv5 = self.conv5(conv5)  # 32, 512, 512conv6 = self.conv6(conv5)  # 4, 512, 512return conv6

5.改进模型结果分析

训练策略和之前保持不变,这里就不重复解释,只对结果进行说明。

5.1训练过程损失

训练过程损失记录如下:
在这里插入图片描述
通过结果看到,训练集和验证集损失也基本趋于平稳,因此判断模型基本收敛。

5.2模型预测结果

将之前训练好的模型参数加载进来,对未知图片进行预测,结果如下:
在这里插入图片描述
通过结果可以看出,预测结果相对于之前有了很大的改善,基本实现了语义分割的效果,只是在微小内容上,识别的并不准确。可能是因为模型还是不够复杂,不足以拟合该任务。

6.结语

至此,基于UNET算法的农业遥感图像语义分割任务到此结束,期望能够对你有所帮助。同时该项目也是我接触的第一个语义分割项目,解释的如有不足还请批评指出!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903624.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Compose笔记(二十一)--AnimationVisibility

这一节主要了解一下Compose的AnimationVisibility,AnimatedVisibility 是 Jetpack Compose 里用于实现组件可见性动画效果的组件&#xff0c;借助它能让组件在显示和隐藏时带有平滑的过渡动画&#xff0c;从而提升用户体验。现总结如下: API 1. visible 含义&#xff1a;这是一…

基于 HT 构建 2D 智慧仓储可视化系统的技术解析

在当今数字化时代&#xff0c;仓储管理对于企业的运营效率和成本控制愈发关键。图扑软件&#xff08;Hightopo&#xff09;凭借其强大的 HT for Web 产品&#xff0c;打造出 2D 智慧仓储可视化平台&#xff0c;为仓储管理带来了全新的技术解决方案。 HT 是一款基于 WebGL、can…

HTML ASCII 编码详解

HTML ASCII 编码详解 引言 HTML&#xff08;HyperText Markup Language&#xff09;是一种用于创建网页的标准标记语言。在HTML中&#xff0c;字符的表示方式非常重要&#xff0c;因为它直接影响到网页内容的显示效果。ASCII编码作为一种基本的字符编码方式&#xff0c;在HTM…

pinia-plugin-persistedstate的使用

pinia持久化存储的使用 安装 npm install pinia-plugin-persistedstate 注册 import { createPinia } from pinia import piniaPluginPersistedstate from pinia-plugin-persistedstateconst pinia createPinia() pinia.use(piniaPluginPersistedstate)export default pinia …

Vue:el-table-tree懒加载数据

目录 一、出现场景二、具体使用三、修改时重新加载树节点四、新增、删除重新加载树节点 一、出现场景 在项目的开发过程中&#xff0c;我们经常会使用到表格树的格式&#xff0c;但是犹豫数据较多&#xff0c;使用分页又不符合项目需求时&#xff0c;就需要对树进行懒加载的操…

ChipCN IDE KF32 导入工程后,无法编译的问题

使用ChipON IDE for KungFu32 导入已有的工程是时&#xff0c;发现能够编译&#xff0c;但是点击&#xff0c;同时选择硬件调试时 没有任何响应。查看工程调试配置时&#xff0c;发现如下问题&#xff1a; 没有看到添加有启动配置&#xff0c;说明就是这里的问题了(应该是IDE的…

前端笔记-Element-Plus

结束了vue的基础学习&#xff0c;现在进一步学习组件 Element-Plus部分学习目标&#xff1a; Element Plus1、查阅官方文档指南2、学习常用组件的使用方法3、Table、Pagination、Form4、Input、Input Number、Switch、Select、Date Picker、Button5、Message、MessageBox、N…

C++入门小馆: 模板

嘿&#xff0c;各位技术潮人&#xff01;好久不见甚是想念。生活就像一场奇妙冒险&#xff0c;而编程就是那把超酷的万能钥匙。此刻&#xff0c;阳光洒在键盘上&#xff0c;灵感在指尖跳跃&#xff0c;让我们抛开一切束缚&#xff0c;给平淡日子加点料&#xff0c;注入满满的pa…

强化学习之基于无模型的算法之基于值函数的深度强化学习算法

3、基于值函数的深度强化学习算法 1&#xff09;深度Q网络&#xff08;DQN&#xff09; 核心思想 DQN是一种将Q学习与深度神经网络结合的方法&#xff0c;用于解决高维状态空间的问题。 它以环境的状态作为输入&#xff0c;通过神经网络输出每个动作的 Q 值&#xff0c;智能体…

网络规划和设计

1.结构化综合布线系统包括建筑物综合布线系统PDS&#xff0c;智能大夏布线系统IBS和工业布线系统IDS 2.GB 50311-2016综合布线系统工程设计规范 GB/T 50312-2016综合布线系统工程验收规范 3.结构化布线系统分为6个子系统&#xff1a; 工作区子系统&#xff1b;水平布线子系…

软件设计师-错题笔记-计算机硬件和体系

1. 解析&#xff1a;循环冗余校验码也叫CRC校验码&#xff0c;其中运算包括了模2&#xff08;异或&#xff09;来构造校验位。别的三种没有用到模2的方法。 2. 解析&#xff1a;如果是正数&#xff0c;则是首位为0&#xff0c;其余位全为1&#xff0c;这时最大数(2^(n-1))-1…

OpenCV 4.7企业级开发实战:从图像处理到目标检测的全方位指南

简介 OpenCV作为工业级计算机视觉开发的核心工具库,其4.7版本在图像处理、视频分析和深度学习模型推理方面实现了显著优化。 本文将从零开始,系统讲解OpenCV 4.7的核心特性和功能更新,同时结合企业级应用场景,提供详细代码示例和实战项目,帮助读者掌握从基础图像处理到复…

LeetCode算法题 (除自身以外数组的乘积)Day14!!!C/C++

https://leetcode.cn/problems/product-of-array-except-self/description/ 一、题目分析 给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀…

如何写好Verilog状态机

还记得之前软件的同事说过的一句话。怎么凸显自己的工作量&#xff0c;就是自己给自己写BUG。 看过夏宇闻老师书的都知道&#xff0c;verilog的FSM有moore和mealy,然后有一段&#xff0c;二段&#xff0c;三段式。记得我还是学生的时候&#xff0c;看到这里的时候&#xff0c;感…

晶振频率/稳定度/精度/温度特性的深度解析与测量技巧

在电子设备的精密世界里&#xff0c;晶振如同跳动的心脏&#xff0c;为各类系统提供稳定的时钟信号。晶振的频率、稳定度、精度以及温度特性&#xff0c;这些关键参数不仅决定了设备的性能&#xff0c;更在不同的应用场景中发挥着至关重要的作用。 一、频率选择的本质&#xff…

Kafka-可视化工具-Offset Explorer

安装&#xff1a; 下载地址&#xff1a;Offset Explorer 安装好后如图&#xff1a; 1、下载安装完毕&#xff0c;进行新增连接&#xff0c;启动offsetexplorer.exe&#xff0c;在Add Cluster窗口Properties 选项下填写Cluster name 和 kafka Cluster Version Cluster name (集…

LabVIEW模板之温度监测应用

这是一个温度监测应用程序&#xff0c;基于 Continuous Measurement and Logging 示例项目构建&#xff0c;用于读取模拟温度值&#xff0c;当温度超出给定范围时发出警报 。 这个。 详细说明 运行操作&#xff1a;直接运行该 VI 程序。点击 “Start” 按钮&#xff0c;即可开…

后端[特殊字符][特殊字符]看前端之Row与Col

是的&#xff0c;在 Ant Design 的栅格布局系统中&#xff0c;每个 <Row> 组件确实对应页面上的一个独立行。以下是更详细的解释&#xff1a; 核心概念 组件作用类比现实场景<Row>横向容器&#xff0c;定义一行内容类似 Excel 表格中的一行<Col>纵向分割&am…

[特殊字符] SpringCloud项目中使用OpenFeign进行微服务远程调用详解(含连接池与日志配置)

&#x1f4da; 目录 为什么要用OpenFeign&#xff1f; 在cart-service中整合OpenFeign 2.1 引入依赖 2.2 启用OpenFeign 2.3 编写Feign客户端 2.4 调用Feign接口 开启连接池&#xff0c;优化Feign性能 3.1 引入OkHttp 3.2 配置启用OkHttp连接池 3.3 验证连接池生效 Feign最佳…