Tabnet介绍(Decision Manifolds)和PyTorch TabNet之TabNetRegressor

Tabnet介绍(Decision Manifolds)和PyTorch TabNet之TabNetRegressor

  • Decision Manifolds
  • TabNet
    • 1.核心思想
    • 2. 架构组成
    • 3. 工作流程
    • 4. 优点
  • PyTorch TabNet
    • TabNetRegressor参数
      • 1. 模型相关参数
        • `n_d`
        • `n_a`
        • `n_steps`
        • `gamma`
        • `cat_idxs`
        • `cat_dims`
        • `cat_emb_dim`
      • 2. 训练相关参数
        • `optimizer_fn`
        • `optimizer_params`
        • `scheduler_fn`
        • `scheduler_params`
        • `mask_type`
      • 3. 其他参数
        • `seed`
        • `verbose`
        • `device_name`
    • TabNetRegressor.fit 参数详解
      • 1. 核心训练数据参数
        • `X_train`
        • `y_train`
      • 2. 验证数据参数
        • `eval_set`
        • `eval_name`
        • `eval_metric`
      • 3. 训练控制参数
        • `max_epochs`
        • `patience`
        • `batch_size`
        • `virtual_batch_size`
        • `num_workers`
      • 4. 回调与日志参数
        • `drop_last`
        • `callbacks`
        • `from_unsupervised`
        • `loss_fn `
    • TabNetRegressor.predict 参数
      • 1. 核心参数
        • `X`
        • `batch_size`
        • `num_workers`
        • `from_unsupervised`
        • `return_proba`
        • `verbose`
      • 2. 返回值
  • 参考

Decision Manifolds

指在决策树模型中,数据点通过一系列超平面的分割形成的决策边界。具体来说:

  • 在决策树模型中:决策流形由一系列垂直于特征轴的超平面组成,这些超平面将数据空间划分为多个区域,每个区域代表一个决策区域。例如,一个简单的决策树可能通过比较特征值与某个阈值来决定数据点的分类或回归结果。
  • 适用于表格数据:由于表格数据通常具有结构化特征,决策流形的这种分割方式能够有效地捕捉数据中的线性关系,尤其是在特征维度较低的情况下,能够实现较好的分类或回归性能。
  • 可解释性:决策流形的直观分割使得模型的决策过程易于理解,每个分割超平面都对应一个特定的特征阈值,便于人类解释和理解模型的决策依据。
  • 对比神经网络:与依赖于高维非线性映射的神经网络不同,决策流形提供了一种更直接、更简单的决策方式,这在某些情况下使得决策树模型在表格数据上表现更佳。

此外,决策流形的概念也与模型的归纳偏差相关,即模型在学习过程中倾向于生成符合某种先验知识或规则的解。对于表格数据,决策树模型的决策流形天生具备线性分割的归纳偏差,这有助于它在没有过多参数调整的情况下,仍然能够有效地学习到数据的结构。

TabNet

一种专门为结构化数据(表格数据)设计的深度学习模型,由 Google 提出。它通过注意力机制和可解释性设计,解决了传统神经网络在处理表格数据时透明度不足的问题。以下是 TabNet 的详细解析:

1.核心思想

  • 稀疏注意力机制:TabNet 使用稀疏注意力机制来选择输入特征的子集进行处理,从而减少计算量并提高模型的可解释性。
  • 逐步特征选择:模型逐步选择重要的特征,并忽略不相关的特征,这使得 TabNet 能够专注于对任务最重要的特征。

2. 架构组成

TabNet 的架构主要由以下几个部分组成:

  • Feature Transformer:负责对输入特征进行非线性变换。
  • Attention Mechanism:通过注意力机制选择重要的特征子集。
  • Masking Mechanism:生成掩码,决定哪些特征被选中参与下一步计算。
  • Decoder:用于预测任务(如分类或回归)。

3. 工作流程

  • 输入层:将表格数据输入到模型中。
  • 特征变换:通过 Feature Transformer 对特征进行非线性变换。
  • 注意力选择:使用注意力机制选择重要的特征子集。
  • 掩码生成:生成掩码以决定哪些特征参与下一步计算。
  • 输出层:通过 Decoder 输出预测结果。

4. 优点

  • 可解释性:由于稀疏注意力机制,TabNet 可以明确指出哪些特征对预测结果有贡献。
  • 高效性:通过选择重要特征,减少了不必要的计算。
  • 灵活性:适用于多种任务,包括分类和回归。

PyTorch TabNet

pytorch_tabnet 是基于 PyTorch 实现的 TabNet 模型库,专为结构化数据(表格数据)设计。它提供了高效的特征选择和可解释性功能,适用于分类和回归任务。

TabNetRegressor参数

TabNetRegressor 是一种基于 TabNet 架构的回归模型,适用于结构化数据的回归任务。以下是其主要参数的详细说明:


1. 模型相关参数

n_d
  • 类型: int
  • 默认值: 8
  • 描述: 表示决策路径中每个步骤的维度大小。较大的值会增加模型的表达能力,但也可能导致过拟合。
n_a
  • 类型: int
  • 默认值: 8
  • 描述: 表示注意力机制的维度大小。与 n_d 类似,控制模型的复杂度。
n_steps
  • 类型: int
  • 默认值: 3
  • 描述: 表示 TabNet 模型中的步数(steps),即模型在每轮迭代中选择特征的次数。更大的值可以捕获更多的特征组合。
gamma
  • 类型: float
  • 默认值: 1.3
  • 描述: 控制特征稀疏性的超参数。较大的值会导致更少的特征被选中。
cat_idxs
  • 类型: list[int]
  • 默认值: []
  • 描述: 指定分类特征的索引列表。如果数据集中包含分类变量,需要通过此参数指定。
cat_dims
  • 类型: list[int]
  • 默认值: []
  • 描述: 指定分类特征的类别数量。与 cat_idxs 配合使用,用于定义分类变量的嵌入维度。
cat_emb_dim
  • 类型: int 或 list[int]
  • 默认值: 1
  • 描述: 分类特征的嵌入维度。如果为整数,则所有分类特征共享相同的嵌入维度;如果为列表,则每个分类特征可以有不同的嵌入维度。

2. 训练相关参数

optimizer_fn
  • 类型: function
  • 默认值: Adam
  • 描述: 优化器函数,默认使用 PyTorch 的 Adam 优化器。
optimizer_params
  • 类型: dict
  • 默认值: {‘lr’: 0.02}
  • 描述: 传递给优化器的参数字典,例如学习率(lr)。
scheduler_fn
  • 类型: function
  • 默认值: None
  • 描述: 学习率调度器函数。如果需要动态调整学习率,可以通过此参数指定。
scheduler_params
  • 类型: dict
  • 默认值: None
  • 描述: 传递给学习率调度器的参数字典。
mask_type
  • 类型: str
  • 默认值: “sparsemax”
  • 描述: 特征选择掩码的类型,可选值为 "sparsemax""entmax""sparsemax" 更加常用。

mask_type 参数用于指定特征选择掩码的类型,控制模型在每个决策步骤中选择特征的方式。


3. 其他参数

seed
  • 类型: int
  • 默认值: 0
  • 描述: 随机种子,用于确保结果的可重复性。
verbose
  • 类型: int
  • 默认值: 1
  • 描述: 控制输出日志的详细程度。0 表示静默模式,1 表示普通模式,2 表示调试模式。
device_name
  • 类型: str
  • 默认值: “auto”
  • 描述: 指定计算设备。"auto" 会自动检测是否有 GPU 可用。

TabNetRegressor.fit 参数详解

1. 核心训练数据参数

X_train
  • 必须为 numpy.ndarray 格式,不支持直接传入 pandas.DataFrame 或 pandas.Series。
y_train
  • 必须为 numpy.ndarray 格式,且形状需调整为 (n_samples, 1)。

2. 验证数据参数

eval_set
  • 类型: list[tuple]
  • 默认值: None
  • 描述: 验证集数据列表,格式为 [(X_valid, y_valid)]。支持多个验证集。
eval_name
  • 类型: list[str]
  • 默认值: None
  • 描述: 每个验证集的名称,便于在日志中区分不同验证集。
eval_metric
  • 类型: list[str] 或 callable
  • 默认值: [‘rmse’]
  • 描述: 评估指标,可选值包括 'rmse''mse' 等。也可以传入自定义的评估函数。

3. 训练控制参数

max_epochs
  • 类型: int
  • 默认值: 100
  • 描述: 最大训练轮数。如果提前收敛,则可能在达到最大轮数之前停止。
patience
  • 类型: int
  • 默认值: 10
  • 描述: 早停机制的耐心值。如果验证集性能在连续 patience 轮内没有提升,则停止训练。
batch_size
  • 类型: int
  • 默认值: 1024
  • 描述: 每次迭代的批量大小。较大的批量可能会加速训练,但需要更多的内存。
virtual_batch_size
  • 类型: int
  • 默认值: 128
  • 描述: 虚拟批量大小,用于模拟小批量梯度下降,减少内存占用。
num_workers
  • 类型: int
  • 默认值: 0
  • 描述: 数据加载器中的工作线程数。设置为 0 表示使用主进程加载数据。

4. 回调与日志参数

drop_last
  • 类型: bool
  • 默认值: False
  • 描述: 是否丢弃最后一个不完整的批量数据。
callbacks
  • 类型: list[Callback]
  • 默认值: None
  • 描述: 自定义回调函数列表,例如学习率调度器、早停等。
from_unsupervised
  • 类型: TabNetPretrainer
  • 默认值: None
  • 描述: 如果提供了预训练的 TabNetPretrainer 模型,则从无监督预训练阶段继续训练。

loss_fn
  • 类型: Callable(可调用对象,例如函数或类方法)
  • 默认值: 默认使用均方误差(MSE)损失函数。
  • 描述: 允许用户自定义训练过程中使用的损失函数。

TabNetRegressor.predict 参数

1. 核心参数

X
  • 数据格式应与训练时使用的 X_train 一致。
batch_size
  • 类型: int
  • 默认值: 1024
  • 描述: 每次预测的批量大小。较大的批量可能会加速预测过程,但需要更多的内存。
num_workers
  • 类型: int
  • 默认值: 0
  • 描述: 数据加载器中的工作线程数。设置为 0 表示使用主进程加载数据。
from_unsupervised
  • 类型: TabNetPretrainer
  • 默认值: None
  • 描述: 如果提供了预训练的 TabNetPretrainer 模型,则从无监督预训练阶段继续预测。
return_proba
  • 类型: bool
  • 默认值: False
  • 描述: 是否返回预测的概率分布(仅适用于分类任务)。对于回归任务,此参数无效。
verbose
  • 类型: int
  • 默认值: 1
  • 描述: 控制输出日志的详细程度。0 表示静默模式,1 表示普通模式,2 表示调试模式。

2. 返回值

  • 当 TabNetRegressor.predict 返回的预测结果是一个二维数组(例如 (n_samples, 1))时,可以使用 flatten 方法将其转换为一维数组 (n_samples,)。

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=2025)
for trn_idx, val_idx in skf.split(train_feats[feature_names], train_feats['V_bins']):X_train = train_feats.loc[trn_idx][feature_names].valuesY_train = train_feats.loc[trn_idx]['V'].values.reshape(-1, 1)X_val = train_feats.loc[val_idx][feature_names].valuesY_val = train_feats.loc[val_idx]['V'].values.reshape(-1, 1)print("Train Num: ", len(Y_train))print("Val Num: ", len(Y_val))model = tab_model.TabNetRegressor(n_d = 8,n_a = 8,n_steps = 1,gamma = 1.6,lambda_sparse = 6e-5,n_independent = 4,n_shared = 2,optimizer_fn = torch.optim.AdamW,optimizer_params = dict(lr=0.025),scheduler_fn = torch.optim.lr_scheduler.ReduceLROnPlateau,scheduler_params = dict(mode='min', factor=0.6, patience=3),mask_type = 'entmax',seed=2025,device_name = 'cuda',verbose = 1,)model.fit(X_train=X_train,y_train=Y_train,eval_set=[(X_val, Y_val)],eval_name=['val'],eval_metric=['rmse'],patience=10,max_epochs=100,batch_size=512, virtual_batch_size=128, num_workers=1, drop_last=False,)pred = model.predict()pred = pred.flatten()

参考

1.https://github.com/dreamquark-ai/tabnet
2.https://github.com/google-research/google-research/tree/master/tabnet
3.https://arxiv.org/abs/1908.07442

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/900978.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图像变换方式区别对比(Opencv)

1. 变换示例 import cv2 import matplotlib.pyplot as plotimg cv2.imread(url) img_cut img[100:200, 200:300] img_rsize cv2.resize(img, (50, 50)) (hight,width) img.shape[:2] rotate_matrix cv2.getRotationMatrix2D((hight//2, width//2), 50, 1) img_wa cv2.wa…

Navicat分组、查询分享

1、分组 有些项目业务表比较多,多达几百张,如果通过人眼看,很容易头晕。这时候可以通过Navicat表分组来进行分类。 使用场景 按版本分组按业务功能分组 创建分组 示例:按版本分组,可以将1.0版本的表放到1.0中。 分组…

大模型在初治CLL成人患者诊疗全流程风险预测与方案制定中的应用研究

目录 一、绪论 1.1 研究背景与意义 1.2 国内外研究现状 1.3 研究目的与内容 二、大模型技术与慢性淋巴细胞白血病相关知识 2.1 大模型技术原理与特点 2.2 慢性淋巴细胞白血病的病理生理与诊疗现状 三、术前风险预测与手术方案制定 3.1 术前数据收集与预处理 3.2 大模…

for循环的优化方式、循环的种类、使用及平替方案。

本篇文章主要围绕for循环,来讲解循环处理数据中常见的六种方式及其特点,性能。通过本篇文章你可以快速了解循环的概念,以及循环在实际使用过程中的调优方案。 作者:任聪聪 日期:2025年4月11日 一、循环的种类 1.1 默认有以下类型 原始 for 循环 for(i = 0;i<10;i++){…

穿透三层内网VPC1

网络拓扑: 打开入口web服务 信息收集发现漏洞CVE-2024-4577 PHP CGI Windows平台远程代码执行漏洞&#xff08;CVE-2024-4577&#xff09;复现_cve-2024-4577漏洞复现-CSDN博客 利用POC&#xff1a; 执行成功&#xff0c;那么直接上传马子&#xff0c;注意&#xff0c;这里要…

【计算机网络】同步操作 vs 异步操作:核心区别与实战场景解析

&#x1f4cc; 引言 在网络通信和分布式系统中&#xff0c;**同步&#xff08;Synchronous&#xff09;和异步&#xff08;Asynchronous&#xff09;**是两种基础却易混淆的操作模式。本文将通过代码示例、生活类比和对比表格&#xff0c;帮你彻底理解它们的区别与应用场景。 1…

TensorFlow充分并行化使用CPU

关键字&#xff1a;TensorFlow 并行化、TensorFlow CPU多线程 场景&#xff1a;在没有GPU或者GPU性能一般、环境不可用的机器上&#xff0c;对于多核CPU&#xff0c;有时TensorFlow或上层的Keras默认并没有完全利用机器的计算能力&#xff08;CPU占用没有接近100%&#xff09;…

Kubernetes容器编排与云原生实践

第一部分&#xff1a;Kubernetes基础架构与核心原理 第1章 容器技术的演进与Kubernetes的诞生 1.1 虚拟化技术的三次革命 物理机时代&#xff1a;资源浪费严重&#xff0c;利用率不足15% 虚拟机突破&#xff1a;VMware与Hyper-V实现硬件虚拟化&#xff0c;利用率提升至50% …

Windows 录音格式为什么是 M4A?M4A 怎样转为 MP3 格式

M4A 格式凭借其高效的压缩技术和卓越的音质表现脱颖而出&#xff0c;成为了包括 Windows 在内的众多操作系统默认的录音格式选择。然而&#xff0c;尽管 M4A 格式拥有诸多优点&#xff0c;不同的应用场景有时需要将这些文件转换为其他格式以满足特定需求。 本文将探讨 M4A 格式…

Qt之OpenGL使用Qt封装好的着色器和编译器

代码 #include "sunopengl.h"sunOpengl::sunOpengl(QWidget *parent) {}unsigned int VBO,VAO; float vertices[]{0.5f,0.5f,0.0f,0.5f,-0.5f,0.0f,-0.5f,-0.5f,0.0f,-0.5f,0.5f,0.0f };unsigned int indices[]{0,1,3,1,2,3, }; unsigned int EBO; sunOpengl::~sunO…

HCIP-17 BGP基础2

HCIP-17 BGP基础2 一、bgp的路由黑洞问题 1.bgp的同步功能 ipv4-family unicast IPV4的地址簇 undo synchronization 关闭BGP同步功能 bgp的同步功能原理 当边界路由器从ibgp邻居收到一条路由后&#xff0c;会使用该路由和igp路由表进行比较。 如果在igp路由表中存在…

leetcode_15. 三数之和_java

15. 三数之和https://leetcode.cn/problems/3sum/ 1、题目 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。…

Open Interpreter:重新定义人机交互的开源革命

引言 在人工智能技术蓬勃发展的今天&#xff0c;人机交互的方式正经历着前所未有的变革。Open Interpreter&#xff0c;作为一个开源项目&#xff0c;正在重新定义我们与计算机的互动方式。它允许大型语言模型&#xff08;LLMs&#xff09;在本地运行代码&#xff0c;通过自然…

【JavaScript】错误处理与调试

个人主页&#xff1a;Guiat 归属专栏&#xff1a;HTML CSS JavaScript 文章目录 1. JavaScript 错误处理基础1.1 错误类型1.2 try...catch 语句 2. 错误抛出与自定义错误2.1 throw 语句2.2 自定义错误类型 3. 异步错误处理3.1 Promise 错误处理3.2 async/await 错误处理 4. 调试…

算法基础模板

高精度加法 #include <bits/stdc.h> using namespace std; const int N10005; int A[N],B[N],C[N],al,bl,cl; void add(int A[],int B[],int C[]) {for(int icl-1;~i;i--){C[cl]A[i]B[i];C[cl1]C[cl]/10;C[cl]%10;}if(C[cl])cl; } int main() {string a,b;cin>>a&…

自行搭建一个Git仓库托管平台

1.安装Git sudo apt install git 2.Git本地仓库创建&#xff08;自己选择一个文件夹&#xff09; git init 这里我在 /home/test 下面初始化了代码仓库 1. 首先在仓库中新建一个txt文件&#xff0c;并输入一些内容 2. 将文件添加到仓库 git add test.txt 执行之后没有任何输…

[MySQL]数据库与表创建

欢迎来到啾啾的博客&#x1f431;。 这是一个致力于构建完善 Java 程序员知识体系的博客&#x1f4da;。 它记录学习点滴&#xff0c;分享工作思考和实用技巧&#xff0c;偶尔也分享一些杂谈&#x1f4ac;。 欢迎评论交流&#xff0c;感谢您的阅读&#x1f604;。 本篇简单记录…

相机回调函数为静态函数原因

在注册相机SDK的回调函数时&#xff0c;是否需要设置为静态函数取决于具体SDK的设计要求&#xff0c;但通常需要遵循以下原则&#xff1a; 1. 必须使用静态函数的情况 当相机SDK是C语言接口或要求普通函数指针时&#xff0c;回调必须声明为静态成员函数或全局函数&#xff1a;…

《Vue Router实战教程》4.路由的匹配语法

欢迎观看《Vue Router 实战&#xff08;第4版&#xff09;》视频课程 路由的匹配语法 大多数应用都会使用 /about 这样的静态路由和 /users/:userId 这样的动态路由&#xff0c;就像我们刚才在动态路由匹配中看到的那样&#xff0c;但是 Vue Router 可以提供更多的方式&#…

Debezium报错处理系列之第128篇:增量快照报错java.lang.OutOfMemoryError: Java heap space

Debezium报错处理系列之第128篇:增量快照报错java.lang.OutOfMemoryError: Java heap space 一、完整报错二、错误原因三、解决方法Debezium从入门到精通系列之:研究Debezium技术遇到的各种错误解决方法汇总: Debezium从入门到精通系列之:百篇系列文章汇总之研究Debezium技…