GBDT算法原理及Python实现

一、概述

  GBDT(Gradient Boosting Decision Tree,梯度提升决策树)是集成学习中提升(Boosting)方法的典型代表。它以决策树(通常是 CART 树,即分类回归树)作为弱学习器,通过迭代的方式,不断拟合残差(回归任务)或负梯度(分类任务),逐步构建一系列决策树,最终将这些树的预测结果进行累加,得到最终的预测值。

二、算法原理

1. 梯度下降思想​

  梯度下降是一种常用的优化算法,用于寻找函数的最小值。在 GBDT 中,它扮演着至关重要的角色。假设我们有一个损失函数 L ( y , y ^ ) L\left( y,\hat{y} \right) L(y,y^),其中 y y y是真实值, y ^ \hat y y^是预测值。梯度下降的目标就是通过不断调整模型参数,使得损失函数的值最小化。具体来说,每次迭代时,沿着损失函数关于参数的负梯度方向更新参数,以逐步接近最优解。在 GBDT 中,虽然没有显式地更新参数(通过构建多颗决策树来拟合目标),但拟合的目标是损失函数的负梯度,本质上也是利用了梯度下降的思想。

2. 决策树的构建​

  GBDT 使用决策树作为弱学习器。决策树是一种基于树结构的预测模型,它通过对数据特征的不断分裂,将数据划分成不同的子集,每个子集对应树的一个节点。在每个节点上,通过某种准则(如回归任务中的平方误差最小化,分类任务中的基尼指数最小化)选择最优的特征和分裂点,使得划分后的子集在目标变量上更加 “纯净” 或具有更好的区分度。通过递归地进行特征分裂,直到满足停止条件(如达到最大树深度、节点样本数小于阈值等),从而构建出一棵完整的决策树。

3. 迭代拟合的过程​

(1) 初始化模型

  首先,初始化一个简单的模型,通常是一个常数模型,记为 f 0 ( X ) f_0(X) f0(X) ,其预测值为所有样本真实值的均值(回归任务)或多数类(分类任务),记为 y ^ 0 \hat y_0 y^0。此时,模型的预测结果与真实值之间存在误差。

(2) 计算残差或负梯度

  在回归任务中,计算每个样本的残差,即真实值 y i y_i yi与当前模型预测值 y ^ i , t − 1 \hat y_{i,t-1} y^i,t1的差值 r i , t = y i − y ^ i , t − 1 r_{i,t}=y_i-\hat y_{i,t-1} ri,t=yiy^i,t1,其中表示迭代的轮数。在分类任务中,计算损失函数关于当前模型预测值的负梯度 g i , t = − ϑ L ( y i , y ^ i , t − 1 ) ϑ y ^ i , t − 1 g_{i,t}=-\frac{\vartheta L(y_i,\hat y_{i,t-1})}{\vartheta \hat y_{i,t-1}} gi,t=ϑy^i,t1ϑL(yi,y^i,t1)

(3) 拟合决策树

  使用计算得到的残差(回归任务)或负梯度(分类任务)作为新的目标值,训练一棵新的决策树 f t ( X ) f_t(X) ft(X)。这棵树旨在拟合当前模型的误差,从而弥补当前模型的不足。

(4) 更新模型

  根据新训练的决策树,更新当前模型。更新公式为 y ^ i , t = y ^ i , t − 1 + α f t ( x i ) \hat y_{i,t}=\hat y_{i,t-1}+\alpha f_t(x_i) y^i,t=y^i,t1+αft(xi),其中是学习率(也称为步长),用于控制每棵树对模型更新的贡献程度。学习率较小可以使模型训练更加稳定,但需要更多的迭代次数;学习率较大则可能导致模型收敛过快,甚至无法收敛。

(5) 重复迭代

  重复步骤 (2)–(4)步,不断训练新的决策树并更新模型,直到达到预设的迭代次数、损失函数收敛到一定程度或满足其他停止条件为止。最终,GBDT 模型由多棵决策树组成,其预测结果是所有决策树预测结果的累加。

算法过程图示

在这里插入图片描述
  GBDT 算法将梯度下降思想与决策树相结合,通过迭代拟合残差或负梯度,逐步构建一个强大的集成模型。它在处理复杂数据和非线性关系时表现较为出色,在数据挖掘、机器学习等领域得到了广泛的应用。然而,GBDT 也存在一些缺点,如训练时间较长、对异常值较为敏感等,在实际应用中需要根据具体情况进行优化和调整 。

三、Python实现

(环境:Python 3.11,scikit-learn 1.5.1)

分类情形

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import metrics# 生成样本数据
X, y = make_classification(n_samples=1000, n_features=50, n_informative=10, n_redundant=5, random_state=1)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)# 创建GDBT分类模型
gbc = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=1)# 训练模型
gbc.fit(X_train, y_train)# 进行预测
y_pred = gbc.predict(X_test)# 计算准确率
accuracy = metrics.accuracy_score(y_test,y_pred)
print('准确率为:',accuracy)

在这里插入图片描述

回归情形

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 生成样本数据
X, y = make_regression(n_samples=1000, n_features=10, n_informative=5, noise=0.1, random_state=42)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建GDBT回归模型
model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, random_state=42)# 训练模型
model.fit(X_train, y_train)# 在测试集上进行预测
y_pred = model.predict(X_test)# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"MSE: {mse}")

在这里插入图片描述

End.



下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/80973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WordPress开心导航站_一站式网址_资源与资讯垂直行业主题模板

一款集网址、资源与资讯于一体的导航类主题,专为追求高效、便捷用户体验的垂直行业网站而设计无论您是构建行业资讯门户、资源聚合平台还是个人兴趣导航站,这款开心版导航主题都能成为您理想的选择。 核心特色: 一体化解决方案:整合了网址导航、资源下载…

马井堂-区块链技术:架构创新、产业变革与治理挑战(马井堂)

区块链技术:架构创新、产业变革与治理挑战 摘要 区块链技术作为分布式账本技术的革命性突破,正在重构数字时代的信任机制。本文系统梳理区块链技术的核心技术架构,分析其在金融、供应链、政务等领域的实践应用,探讨共识算法优化、…

从像素到驾驶决策:Python与OpenCV赋能自动驾驶图像识别

从像素到驾驶决策:Python与OpenCV赋能自动驾驶图像识别 引言:图像识别的力量驱动自动驾驶 自动驾驶技术正以令人惊叹的速度改变交通方式,而其中最核心的技术之一便是图像识别。作为车辆的“视觉系统”,图像识别可以实时获取道路信息,识别交通标志、车辆、行人等关键目标…

Spring计时器StopWatch 统计各个方法执行时间和占比

Spring计时器StopWatch 用法代码 返回结果是毫秒 一毫秒等于千分之一秒(0.001秒)。因此,如果你有一个以毫秒为单位的时间值,你可以通过将这个值除以1000来将其转换为秒。例如,500毫秒等于0.5秒。 import org.springf…

2.2.2goweb内置的 HTTP 处理程序2

http.StripPrefix http.StripPrefix 是 Go 语言 net/http 包中的一个函数,它的主要作用是创建一个新的 HTTP 处理程序。这个新处理程序会在处理请求之前,从请求的 URL 路径中移除指定的前缀,然后将处理工作委托给另一个提供的处理程序。 使…

【Fifty Project - D20】

今日完成记录 TimePlan完成情况7:30 - 11:30收拾行李闪现广州 & 《挪威的森林》√10:00 - 11:00Leetcode√16:00 - 17:00健身√ Leetcode 每日一题 每日一题来到了滑动窗口系列,今天是越…

【图片识别改名】批量读取图片区域文字识别后批量改名,基于Python和腾讯云的实现方案

项目场景 ​​办公文档管理​​:将扫描的发票、合同等文档按编号、日期自动重命名。例如,识别“编号:2023001 日期:20230403”生成“2023001_20230403.jpg”。​​产品图片整理​​:电商产品图片按产品编号、名称自动命名。例如,…

生物化学笔记:神经生物学概论04 视觉通路简介视网膜视网膜神经细胞大小神经节细胞(视错觉)

视觉通路简介 神经节细胞的胞体构成一明确的解剖层次,其外邻神经纤维层,内接内丛状层,该层在鼻侧厚约10~20μm,最厚在黄斑区约60~80μm。 全部细胞数约为120万个(1000000左右)。 每个细胞有一轴突&#xff…

「Mac畅玩AIGC与多模态08」开发篇04 - 基于 OpenAPI Schema 开发专用 Agent 插件

一、概述 本篇介绍如何在 macOS 环境下,通过编写 OpenAPI Schema,开发自定义的专用插件,让智能体可以调用外部 API,扩展功能至任意在线服务。实践内容基于 Dify 平台,适配 macOS 开发环境。 二、环境准备 1. 确认本地开发环境 macOS 系统Dify 平台已完成部署并可访问本…

【计算机视觉】深度解析MediaPipe:谷歌跨平台多媒体机器学习框架实战指南

深度解析MediaPipe:谷歌跨平台多媒体机器学习框架实战指南 技术架构与设计哲学核心设计理念系统架构概览 核心功能与预构建解决方案1. 人脸检测2. 手势识别3. 姿势估计4. 物体检测与跟踪 实战部署指南环境配置基础环境准备获取源码 构建第一个示例(手部追…

NVIDIA高级辅助驾驶领域的创新实践与云计算教育启示

AI与高级辅助驾驶的时代浪潮 人工智能正在重塑现代交通的面貌,而高级辅助驾驶技术无疑是这场变革中最具颠覆性的力量之一。作为全球AI计算的领军企业,NVIDIA凭借其全栈式技术生态和创新实践,为高级辅助驾驶的产业化落地树立了标杆。从芯片到…

头歌实训之存储过程、函数与触发器

🌟 各位看官好,我是maomi_9526! 🌍 种一棵树最好是十年前,其次是现在! 🚀 今天来学习C语言的相关知识。 👍 如果觉得这篇文章有帮助,欢迎您一键三连,分享给更…

医学图像处理软件中几种MPR

1:设备厂商的MPR 2:后处理的MPR 3:阅片PACS的MPR 4:手术导航 手术规划的MPR 设备厂商的MPR需求更多是扫描线、需要3DMPR ,三条定位线的任意角度旋转。 后处理的MPR,需求更多的是算法以及UI工具的研发&a…

java 类的实例化过程,其中的相关顺序 包括有继承的子类等复杂情况,静态成员变量的初始化顺序,这其中jvm在干什么

Java类的实例化过程及初始化顺序 Java类的实例化过程涉及多个步骤,特别是在存在继承关系和静态成员的情况下。下面我将详细解释整个过程,包括JVM在其中的角色。 1. 类加载阶段(JVM的工作) 在实例化一个类之前,JVM首…

Sce2DriveX: 用于场景-到-驾驶学习的通用 MLLM 框架——论文阅读

《Sce2DriveX: A Generalized MLLM Framework for Scene-to-Drive Learning》2025年2月发表,来自中科院软件所和中科院大学的论文。 端到端自动驾驶直接将原始传感器输入映射到低级车辆控制,是Embodied AI的重要组成部分。尽管在将多模态大语言模型&…

【题解-Acwing】870. 约数个数

题目:870. 约数个数 题目描述 给定 n 个正整数 ai,请你输出这些数的乘积的约数个数,答案对 109+7 取模。 输入 第一行包含整数 n。 接下来 n 行,每行包含一个整数 ai。 输出 输出一个整数,表示所给正整数的乘积的约数个数,答案需对 109+7 取模。 数据范围 1 ≤ …

创龙全志T536全国产(4核A55 ARM+RISC-V+NPU 17路UART)工业开发板硬件说明书

前 言 本文档主要介绍TLT536-EVM评估板硬件接口资源以及设计注意事项等内容。 T536MX-CXX/T536MX-CEN2处理器的IO电平标准一般为1.8V、3.3V,上拉电源一般不超过3.3V或1.8V,当外接信号电平与IO电平不匹配时,中间需增加电平转换芯片或信号隔离芯片。按键或接口需考虑ESD设计…

Redis 持久化双雄:RDB 与 AOF 深度解析

Redis 是一种内存数据库,为了保证数据在服务器重启或故障时不丢失,提供了两种持久化方式:RDB(Redis Database)和 AOF(Append Only File)。以下是它们的详细介绍: 一、RDB 持久化 工…

数据结构|并查集

Hello !朋友们,这是我在学习过程中梳理的笔记,以作以后复习回顾,有时略有潦草,一些话是我用自己的话描述的,可能不够准确,还是感谢大家的阅读! 目录 一、并查集Quickfind 二、两种算…

【GPU 微架构技术】Pending Request Table(PRT)技术详解

PRT(Pending Request Table)是 GPU 中用于管理 未完成内存请求(outstanding memory requests)的一种硬件结构,旨在高效处理大规模并行线程的内存访问需求。与传统的 MSHR(Miss Status Handling Registers&a…