PointNet系列【语义分割】自定义数据的模型训练

目录

一、平台

二、数据

三、代码

3.1 文件组织结构

3.2 lasDataLoader.py 读取数据

3.3 修改原始模型的通道数量

3.4 lasTrainSS.py【训练】

3.5 lasTestSS.py【预测】

一、平台

Windows 10

GPU RTX 3090 + CUDA 11.1 + cudnn 8.9.6

Python 3.9

Torch 1.9.1 + cu111

所用的原始代码:https://github.com/yanx27/Pointnet_Pointnet2_pytorch

二、数据

有Classification属性的已经分类的LAS点云

三、代码

分享给有需要的人,代码质量勿喷。

对原始代码进行了简化和注释。

分割结果保存成txt,或者利用 laspy 生成点云。

别问为啥在C盘,问就是2T的三星980Pro

完整代码:https://download.csdn.net/download/xinjiang666/88755213?spm=1001.2014.3001.5501

3.1 文件组织结构

3.2 lasDataLoader.py 读取数据

# 6通道特征:块相对坐标;全局相对坐标:有效果# ### 分类的类别
classNumber = 2 #0-未分类;1-路面# 训练用
class lasDataset(Dataset):def __init__(self, split='train', data_root='dataset', train_ratio=0.6,val_ratio=0.2,test_ratio=0.2, num_point=1024, block_size=1.0, sample_rate=1.0, transform=None):# 局部坐标XYZ(m) rgbpoints = np.transpose(np.array([las.X*lasHeader.scales[0],las.Y*lasHeader.scales[1],las.Z*lasHeader.scales[2],las.red,las.green,las.blue]))self.las_points.append(points)coordMIN, coordMAX = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]self.las_coord_MIN.append(coordMIN), self.las_coord_MAX.append(coordMAX)# labellabels = np.transpose(np.array([las.classification]))labels[labels == 11] = 1self.las_labels.append(labels)num_point_all.append(labels.size)# 标签的统计直方图tmp, _ = np.histogram(labels, range(classNumber + 1))labelweights += tmpdef __getitem__(self, idx):las_idx = self.las_idxs[idx]points = self.las_points[las_idx]   # N * 6labels = self.las_labels[las_idx]   # NN_points = points.shape[0]# # normalize old 9通道特征# selected_points = points[selected_point_idxs, :]  # num_point * 6# current_points = np.zeros((self.num_point, 9))  # num_point * 9# current_points[:, 6] = selected_points[:, 0] / self.las_coord_MAX[las_idx][0]# current_points[:, 7] = selected_points[:, 1] / self.las_coord_MAX[las_idx][1]# current_points[:, 8] = selected_points[:, 2] / self.las_coord_MAX[las_idx][2]# selected_points[:, 0] = selected_points[:, 0] - center[0] # 相对块中心的x# selected_points[:, 1] = selected_points[:, 1] - center[1] # 相对块中心的y# selected_points[:, 3:6] /= 255.0# current_points[:, 0:6] = selected_points# region ### normalize 6通道特征:块相对坐标;全局相对坐标(无颜色)selected_points = points[selected_point_idxs, 0:3]  # num_point * 3block_points = points[selected_point_idxs, 0:3]block_points[:, 0] = selected_points[:, 0] - center[0] # 相对块中心的xblock_points[:, 1] = selected_points[:, 1] - center[1] # 相对块中心的yblock_points[:, 2] = selected_points[:, 2]current_points = np.zeros((self.num_point, 6))  # num_point * 6current_points[:, 0:3] = block_pointscurrent_points[:, 3] = (selected_points[:, 0]-self.las_coord_MIN[las_idx][0]) / (self.las_coord_MAX[las_idx][0]-self.las_coord_MIN[las_idx][0])current_points[:, 4] = (selected_points[:, 1]-self.las_coord_MIN[las_idx][1]) / (self.las_coord_MAX[las_idx][1]-self.las_coord_MIN[las_idx][1])current_points[:, 5] = (selected_points[:, 2]-self.las_coord_MIN[las_idx][2]) / (self.las_coord_MAX[las_idx][2]-self.las_coord_MIN[las_idx][2])# endregion# 测试用
class testDatasetToPred():# prepare to give prediction on each pointsdef __init__(self, data_root, block_points=1024, split='test', stride=0.5, block_size=1.0, padding=0.001):for file in self.file_list:# las文件的绝对路径pathLAS = os.path.join(data_root, file)# 读取文件:ndarrya:点数量,7(xyz rgb l)las = laspy.read(pathLAS)# 头文件信息 偏移和尺度lasHeader = las.headerself.las_offset.append(lasHeader.offsets), self.las_scales.append(lasHeader.scales)# 局部坐标XYZ(m) rgb 真实标签data = np.transpose(np.array([las.X * lasHeader.scales[0], las.Y * lasHeader.scales[1], las.Z * lasHeader.scales[2],las.red,las.green,las.blue, las.classification])) # ndarray=点数量,7# 局部坐标XYZ(m)points = data[:, :3]coordMIN, coordMAX = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]self.las_coord_MIN.append(coordMIN), self.las_coord_MAX.append(coordMAX)self.scene_points_list.append(data[:, :6])# 真实标签labels = data[:, 6]labels[labels == 11] = 1self.semantic_labels_list.append(labels)def __getitem__(self, index):for index_x in range(0, grid_x):# region ### 6通道特征:块相对坐标;全局相对坐标(无颜色)data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0)data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0)normlized_xyz = np.zeros((point_size, 3))temp = points[point_idxs, :]normlized_xyz[:, 0] = (temp[:, 0]-coordMIN[0]) / (coordMAX[0]-coordMIN[0])normlized_xyz[:, 1] = (temp[:, 1]-coordMIN[1]) / (coordMAX[1]-coordMIN[1])normlized_xyz[:, 2] = (temp[:, 2]-coordMIN[2]) / (coordMAX[2]-coordMIN[2])### 6通道特征组合data_batch = np.concatenate((data_batch[:,0:3], normlized_xyz), axis=1)#endregion

3.3 修改原始模型的通道数量

pointnet_sem_seg.pyclass get_model(nn.Module):def __init__(self, num_class):super(get_model, self).__init__()self.k = num_classself.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=6) ###6通道特征
pointnet2_sem_seg.pyclass get_model(nn.Module):def __init__(self, num_classes):super(get_model, self).__init__()self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 6 + 3, [32, 32, 64], False)### 6代表输入网络的通道数量

3.4 lasTrainSS.py【训练】

# 参考
# https://github.com/yanx27/Pointnet_Pointnet2_pytorch
# 先在Terminal运行:python -m visdom.server
# 再运行本文件# True为PointNet++
PN2bool = True
# PN2bool = False# 当前文件的路径
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))# 输出 PointNet训练模型的路径: PointNet
dirModel1 = ROOT_DIR + '/trainModel/pointnet_model'
if not os.path.exists(dirModel1):os.makedirs(dirModel1)
# 输出 PointNet++训练模型的路径
dirModel2 = ROOT_DIR + '/trainModel/PointNet2_model'
if not os.path.exists(dirModel2):os.makedirs(dirModel2)# 日志的路径
pathLog = os.path.join(ROOT_DIR, 'LOG_train.txt')# 训练数据集的路径
pathDataset = os.path.join(ROOT_DIR, 'dataset/lasDatasetClassification/')# 点云语义分割的类别名称:这里只分了2类
classNumber = 2
classes = ['un', 'rs']
class2label = {cls: i for i, cls in enumerate(classes)}
seg_classes = class2label
seg_label_to_cat = {}
for i, cat in enumerate(seg_classes.keys()):seg_label_to_cat[i] = cat

3.5 lasTestSS.py【预测】

# 参考
# https://github.com/yanx27/Pointnet_Pointnet2_pytorch# True为PointNet++
PN2bool = True
# PN2bool = False# save to LAS
import laspy
def SaveResultLAS(newLasPath, las_offsets, las_scales,point_np, rgb_np, label1, label2):# datanewx = point_np[:, 0]+las_offsets[0]newy = point_np[:, 1]+las_offsets[1]newz = point_np[:, 2]+las_offsets[2]newred = rgb_np[:, 0]newgreen = rgb_np[:, 1]newblue = rgb_np[:, 2]newclassification = label1newuserdata = label2minx = min(newx)miny = min(newy)minz = min(newz)# create a new headernewheader = laspy.LasHeader(point_format=3, version="1.2")newheader.scales = np.array([0.0001, 0.0001, 0.0001])newheader.offsets = np.array([minx, miny, minz])newheader.add_extra_dim(laspy.ExtraBytesParams(name="Classification", type=np.uint8))newheader.add_extra_dim(laspy.ExtraBytesParams(name="UserData", type=np.uint8))# create a Lasnewlas = laspy.LasData(newheader)newlas.x = newxnewlas.y = newynewlas.z = newznewlas.red = newrednewlas.green = newgreennewlas.blue = newbluenewlas.Classification = newclassificationnewlas.UserData = newuserdata# writenewlas.write(newLasPath)# 当前文件的路径
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))# 模型的路径
pathTrainModel = os.path.join(ROOT_DIR, 'trainModel/pointnet_model')
if PN2bool:pathTrainModel = os.path.join(ROOT_DIR, 'trainModel/PointNet2_model')# 预测结果路径
visual_dir = ROOT_DIR + '/testResultPN/'
if PN2bool:visual_dir = ROOT_DIR + '/testResultPN2/'
visual_dir = Path(visual_dir)
visual_dir.mkdir(exist_ok=True)# 日志的路径
pathLog = os.path.join(ROOT_DIR, 'LOG_test_eval.txt')# 测试数据的路径
pathDataset = os.path.join(ROOT_DIR, 'dataset/lasDatasetClassification2/')# 点云语义分割的类别名称:这里只分了2类
classNumber = 2
classes = ['un', 'rs']
class2label = {cls: i for i, cls in enumerate(classes)}
seg_classes = class2label
seg_label_to_cat = {}
for i, cat in enumerate(seg_classes.keys()):seg_label_to_cat[i] = cat

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/640646.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每个人都可以是架构师,每个人都需要培养架构思维

您好, 如果喜欢我的文章或者想上岸大厂,可以关注公众号「量子前端」,将不定期关注推送前端好文、分享就业资料秘籍,也希望有机会一对一帮助你实现梦想 什么是架构 “架构”,即架设、构建。完成对于平台的合理架设&am…

VMware安装Linux-Redhat7.9 详细步骤

目录 一、安装准备二、安装步骤 一、安装准备 Redhat 7.9 镜像下载 VMware安装步骤可查看文章:https://blog.csdn.net/a2279338659/article/details/126346345 可去官网下载,或者加群下载镜像资源。 二、安装步骤 创建新的虚拟机: 我这边…

Java学习(二十二)--正则表达式

介绍 为什么需要 正则表达式是处理文本的利器; 基本介绍 正则表达式,又称规则表达式,(Regular Expression,在代码中常简写为regex、regexp或RE)。它是一个强大的字符串处理工具,可以对字符串进行查找、提…

语音模块学习——LSYT201B模组(深圳雷龙科技)

目录 引子 处理器 外设 音频 蓝牙 模组展示 引子 关注我的老粉们应该知道我之前用过语音模块做东西,那个比较贵要50多。 今天这个淘宝20元左右比那个便宜,之前那个内核是51的,一个8位机。 后面我做东西的时候语音模块可能会换成这个&…

【计算机网络】Socket的TCP_NODELAY选项与Nagle算法

TCP_NODELAY是一个套接字选项,用于控制TCP套接字的延迟行为。当TCP_NODELAY选项被启用时,即设置为true,就会禁用Nagle算法,从而实现TCP套接字的无延迟传输。这意味着每次发送数据时都会立即发送,不会等待缓冲区的填充或…

代码随想录算法训练营DAY24|回溯1

算法训练DAY24|回溯1 第77题. 组合 力扣题目链接 给定两个整数 n 和 k,返回 1 ... n 中所有可能的 k 个数的组合。 示例: 输入: n 4, k 2 输出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ] 上面我们说了要解决 n为100,k为50的情况&#xff0…

vscode连不上虚拟机,一直密码错误

最近在做毕设,但是vscode使用连接不上虚拟机,我以为是网络配置的问题,一顿查阅没找到原因。 后来查了一下ssh的日志,发现ssh有消息,但是也提示密码错误。 没找到密码配置格式什么的,经查看sshd配置文件发现…

DLL注入技术

源地址 注入程序 #include <Windows.h> #include <iostream> #include <Tlhelp32.h> #include <stdio.h> #include <tchar.h> #include <iostream>using namespace std;BOOL getProcess32Info(PROCESSENTRY32 *info, const TCHAR proces…

Go语言的映射reflect使用大全

目录 前言 一、映射的基本用法 1.获取类型信息 2.获取值 3.读取和设置值 4.使用Kind来区分类型 5.操作结构体 6.创建新实例 7.调用方法 8.调用方法 二、使用实例 总结 前言 Go语言作为一个高性能的静态语言&#xff0c;我们在写函数的时候&#xff0c;由于go语言的特性&#x…

工业相机与镜头参数及选型

文章目录 1、相机成像系统模型1.1 视场1.2 成像简化模型 2、工业相机参数2.1 分辨率2.2 靶面尺寸2.3 像元尺寸2.4 帧率/行频2.5 像素深度2.6 动态范围2.7 信噪比2.8 曝光时间2.9 相机接口 3、工业镜头参数3.1 焦距3.2 光圈3.3 景深3.4 镜头分辨率3.5 工作距离&#xff08;Worki…

微信小程序入门,学习全局配置与页面配置

目录 一、微信小程序 二、微信小程序的全局配置 三、微信小程序的页面配置 四、全局配置与页面配置的区别 一、微信小程序 微信小程序是一种基于微信平台的应用程序&#xff0c;它可以在微信内部直接运行&#xff0c;无需下载安装。微信小程序具有以下特点和优势&#xff…

Spring Boot自动配置原理

1.SpringBootApplication注解 springboot是基于spring的新型的轻量级框架&#xff0c;最厉害的地方当属**自动配置。**那我们就可以根据启动流程和相关原理来看看&#xff0c;如何实现传奇的自动配置 SpringBootApplication//标注在某个类上&#xff0c;表示这个类是SpringBo…

【技术预研】starRocks高性价比替换hbase

hbase作为类列数据库&#xff0c;更准确说是列族数据库。本质上是一个文件查询系统&#xff0c;追求极限的写入和读取。 而starRocks作为olap数据库&#xff0c;在保持优秀的关联计算能力的前提下&#xff0c;还有不错的查询效率&#xff0c;当然和hbase本身比还有一定差距。 但…

<蓝桥杯软件赛>零基础备赛20周--第15周--快速幂+素数

报名明年4月蓝桥杯软件赛的同学们&#xff0c;如果你是大一零基础&#xff0c;目前懵懂中&#xff0c;不知该怎么办&#xff0c;可以看看本博客系列&#xff1a;备赛20周合集 20周的完整安排请点击&#xff1a;20周计划 每周发1个博客&#xff0c;共20周。 在QQ群上交流答疑&am…

【设计模式】张一鸣笔记:责任链接模式怎么用?

我将通过一个贴近现实的故事——请假审批流程&#xff0c;带你了解和掌握责任链模式。 什么是责任链模式&#xff1f; 责任链模式是一种行为设计模式&#xff0c;它让你可以避免将请求的发送者与接收者耦合在一起&#xff0c;让多个对象都有处理请求的机会将这个对象连成一条…

python基础教程九 抽象二(函数参数)

1. 值从哪里来 定义函数时&#xff0c;你可能心存疑虑&#xff0c;参数的值是怎么来的呢&#xff1f; 在def语句中&#xff0c;位于函数名后面的变量通常称为形式参数&#xff0c;在调用函数时提供的值称为实参&#xff0c;但在本书不做严格区分。 2. 我能修改参数吗 函数通…

同样是IT行业,测试和开发薪资真就差这么大吗?

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

Java进阶之旅第六天

Java进阶之旅第六天 Stream流 Stream的思想 Stream流中引入函数式编程的思想,以类似流水线的方式处理数据,使得代码更加高效整洁Stream中提供并行处理的能力,可以将数据分成多个子任务,并行处理 各类型的调用方法 类型方法说明单列集合default Stream streamCollection中默…

请写出js中的两种定时器,区别是什么?怎么清除定时器?

在JavaScript中有两种常用的定时器&#xff1a;setTimeout 和 setInterval。 setTimeout&#xff1a;此函数用于在指定的毫秒数后执行一次函数或计算出的表达式。例如&#xff0c;如果你想在5秒后打印一条消息&#xff0c;你可以这样做&#xff1a; var myTimer setTimeout(f…

下载csdn文章,并保存md笔记中的图片链接至本地

推荐1个下载别人csdn文章笔记的java项目&#xff1a;csdn-blog2markword-downloader 拿到别人的md笔记后&#xff0c;但是笔记中的图片又是以链接的格式给的&#xff0c;这个链接说不定后面就失效了&#xff0c;笔记也就看不到图片了。手动右键也可以保存图片&#xff0c;但是…