CatBoost中的预测偏移和排序提升

CatBoost 中,预测偏移(Prediction Shift)排序提升(Ordered Boosting) 是其关键概念和创新点。CatBoost 通过引入 排序提升 解决了梯度提升决策树(GBDT)算法中常见的 预测偏移问题,从而提高了模型的稳定性和性能。以下是对这两个概念的详细解释:


1. 预测偏移(Prediction Shift)

概念

预测偏移是指在梯度提升决策树(GBDT)训练过程中,由于在模型训练阶段同时使用特征和目标变量可能会导致未来信息泄漏,从而影响模型性能和稳定性。

  • 原因
    在标准 GBDT 算法中,训练样本的目标变量会被用来更新模型,而同时目标变量也会被用于特征变换(如目标统计编码)。这种特征变换过程中可能会使用目标变量的全局信息,从而导致未来样本的信息被泄漏到当前训练样本中。

  • 结果

    • 训练误差较低,但在测试集上表现较差(过拟合)。
    • 对目标变量统计不准确,尤其是分类特征的目标统计编码可能引入偏差。

示例

假设有一个分类特征 x x x 和目标变量 y y y

样本 i i i分类特征 x i x_i xi目标变量 y i y_i yi
1A1
2A0
3A1

如果在第 3 个样本训练过程中使用整个数据集的目标统计均值(包括第 3 个样本本身的 y 3 = 1 y_3 = 1 y3=1),则会导致信息泄漏。例如:

目标统计编码:
编码值 = 总目标值 总样本数 = 1 + 0 + 1 3 = 0.67 \text{编码值} = \frac{\text{总目标值}}{\text{总样本数}} = \frac{1 + 0 + 1}{3} = 0.67 编码值=总样本数总目标值=31+0+1=0.67

这会将第 3 个样本的目标变量泄漏到其特征变换中。


2. 排序提升(Ordered Boosting)

概念

排序提升是 CatBoost 提出的用于解决 预测偏移问题 的方法。其核心思想是在每一轮训练中,严格按照样本的时间或排列顺序,只使用当前样本之前的数据计算特征变换(如目标统计编码),避免了未来信息泄漏。


Ordered Boosting 的实现原理

CatBoost 的 排序提升 使用了一种特殊的数据划分和特征计算方式:

  1. 样本顺序化:

    • 假设样本被排列为 ( x σ 1 , y σ 1 ) , ( x σ 2 , y σ 2 ) , … , ( x σ n , y σ n ) (x_{\sigma_1}, y_{\sigma_1}), (x_{\sigma_2}, y_{\sigma_2}), \dots, (x_{\sigma_n}, y_{\sigma_n}) (xσ1,yσ1),(xσ2,yσ2),,(xσn,yσn),其中 σ \sigma σ 表示样本的排列顺序。
    • 在训练第 i i i 个样本时,仅使用前 i − 1 i-1 i1 个样本的数据来计算特征值。
  2. 目标统计的顺序计算:

    • 对于分类特征,目标统计值的计算严格遵循样本顺序。例如,计算第 i i i 个样本的目标统计值 T S ( x i ) TS(x_i) TS(xi) 时,仅基于样本 ( x 1 , y 1 ) , ( x 2 , y 2 ) , … , ( x i − 1 , y i − 1 ) (x_1, y_1), (x_2, y_2), \dots, (x_{i-1}, y_{i-1}) (x1,y1),(x2,y2),,(xi1,yi1) 的目标变量 y j y_j yj
    • 避免了将当前样本 y i y_i yi 或未来样本的目标值泄漏到统计值中。
  3. 模型更新的顺序化:

    • CatBoost 使用排序提升算法训练决策树时,每棵树的分裂决策仅基于当前模型状态和之前的数据更新。

排序提升算法的伪代码

如下图 14-2 中描述的伪代码:

  1. 对训练样本 ( x , y ) (x, y) (x,y) 按顺序排列为 ( x σ 1 , y σ 1 ) , ( x σ 2 , y σ 2 ) , … , ( x σ n , y σ n ) (x_{\sigma_1}, y_{\sigma_1}), (x_{\sigma_2}, y_{\sigma_2}), \dots, (x_{\sigma_n}, y_{\sigma_n}) (xσ1,yσ1),(xσ2,yσ2),,(xσn,yσn)
  2. 初始化模型 M 0 M_0 M0
  3. 对于每个样本 i i i
    • 根据模型状态 M i − 1 M_{i-1} Mi1 和前 i − 1 i-1 i1 个样本的目标变量 y σ j y_{\sigma_j} yσj 计算目标统计值。
    • 更新模型状态 M i = M i − 1 + Δ M M_i = M_{i-1} + \Delta M Mi=Mi1+ΔM,其中 Δ M \Delta M ΔM 是模型的增量更新(如一棵树的增量效果)。
  4. 输出最终模型 M n M_n Mn

排序提升的优点
  1. 避免信息泄漏:

    • 通过按顺序计算特征值和模型更新,确保每个样本的特征计算只依赖于之前的样本信息。
    • 解决了传统梯度提升算法中的预测偏移问题。
  2. 提高模型鲁棒性:

    • 排序提升能够更好地适应分类特征中高基数、稀疏类别的情况。
    • 即使样本数量有限,也能生成稳定的特征统计值。
  3. 改进模型性能:

    • 避免了模型过拟合,提升了测试集上的性能。

排序提升的一个例子

假设训练样本如下:

样本 i i i分类特征 x i x_i xi目标变量 y i y_i yi
1A1
2B0
3A1
4B1
目标统计值的计算:

对于分类特征 x x x,计算目标统计值 T S ( x i ) TS(x_i) TS(xi) 时:

  1. 第 1 行:

    • T S ( x 1 ) TS(x_1) TS(x1):没有之前的样本,所以使用全局均值 p p p
  2. 第 2 行:

    • T S ( x 2 ) TS(x_2) TS(x2):类别 B B B 的目标统计值基于之前样本:
      T S ( x 2 ) = p TS(x_2) = p TS(x2)=p
  3. 第 3 行:

    • T S ( x 3 ) TS(x_3) TS(x3):类别 A A A 的目标统计值基于第 1 行:
      T S ( x 3 ) = y 1 1 = 1 TS(x_3) = \frac{y_1}{1} = 1 TS(x3)=1y1=1
  4. 第 4 行:

    • T S ( x 4 ) TS(x_4) TS(x4):类别 B B B 的目标统计值基于第 2 行:
      T S ( x 4 ) = y 2 1 = 0 TS(x_4) = \frac{y_2}{1} = 0 TS(x4)=1y2=0

总结

  1. 预测偏移(Prediction Shift)

    • 是由于目标变量泄漏到特征变换中引起的模型训练问题,导致过拟合和不稳定性。
  2. 排序提升(Ordered Boosting)

    • 是 CatBoost 的核心创新,通过严格按照时间或排列顺序训练模型,避免了预测偏移问题。
    • 在分类特征处理、目标统计值计算和模型更新中都有应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/60376.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云aliyun gradle安装包下载地址

阿里云 查找你要下载的安装包 macports-distfiles-gradle安装包下载_开源镜像站-阿里云 https://mirrors.aliyun.com/macports/distfiles/gradle/gradle-8.9-bin.zip 腾讯 https://mirrors.cloud.tencent.com/gradle/ https://mirrors.cloud.tencent.com/gradle/ https…

《揭秘观察者模式:作用与使用场景全解析》

在软件开发的世界中,设计模式就像是建筑师手中的蓝图,指导着软件系统的构建。其中,观察者模式是一种极为重要且广泛应用的设计模式。今天,我们就来深入探讨一下观察者模式的作用和使用场景。 一、观察者模式是什么? …

SpringBoot(九)使用Jsoup解析html字符串

目前在做博客相关的功能,在显示文章详情的时候,我看到那些大的博客社区,文章中的图片都是可以点击放大的,我感觉这个功能非常好,我也想做,在PHP版本的博客中已经实现了。 实现原理其实很简单,使用PHP的simple_html_dom库解析HTML字符串,找到其中的img标签,在img标签上…

Spring——容器:IoC

容器:IoC IoC 是 Inversion of Control 的简写,译为“控制反转”,它不是一门技术,而是一种设计思想,是一个重要的面向对象编程法则,能够指导我们如何设计出松耦合、更优良的程序。 Spring 通过 IoC 容器来…

uniapp—android原生插件开发(4uniapp引用aar插件)

本篇文章从实战角度出发,将UniApp集成新大陆PDA设备RFID的全过程分为四部曲,涵盖环境搭建、插件开发、AAR打包、项目引入和功能调试。通过这份教程,轻松应对安卓原生插件开发与打包需求! 一、将android程序打包成aar插件包 直接使…

RedisTemplate序列化设置

前言 在使用 Redis 作为缓存数据库时,我们通常会使用 RedisTemplate 来简化与 Redis 进行交互的操作。而其中一个重要的配置项就是序列化设置,它决定了数据在存储到 Redis 中时的格式。本文将介绍如何进行 RedisTemplate 的序列化设置,以及一…

如何优化Elasticsearch的查询性能?

优化Elasticsearch查询性能可以从以下几个方面进行: 合理设计索引和分片: 确保设置合理的分片和副本数,考虑数据量、节点数和集群大小。根据数据量和节点数量调整分片数量,避免使用过多分片,因为每个分片都需要额外的…

使用R语言survminer获取生存分析高风险和低风险的最佳截断值cut-off

使用R语言进行Cox比例风险模型分析和最佳截断值寻找 引言 在生存分析中,Cox比例风险模型是一种常用的统计方法,用于评估多个变量对生存时间的影响。在临床研究中,我们经常需要根据某些连续变量的预测值来对患者进行分组,以便更好…

ORU——ORAN 无线电单元参考架构

ORU ORU-开放无线电单元ORU 类型O-RU“A类”O-RU“B类” 参考相关文章 ORU-开放无线电单元 ORU(开放无线电单元)的目的是将天线发送和接收的无线电信号转换为数字信号,该数字信号可通过前传传输到分布式单元(DU)。考虑…

FFMPEG录屏(22)--- Linux 下基于X11枚举所有显示屏,并获取大小和截图等信息

众人拾柴火焰高,github给个star行不行? open-traa/traa traa is a versatile project aimed at recording anything, anywhere. The primary focus is to provide robust solutions for various recording scenarios, making it a highly adaptable tool…

卷积核里面的数字表示什么意思?

卷积核里面的数字表示的是一种权重,这些权重在与输入数据进行卷积操作时起着至关重要的作用。简单来说,卷积核是一个小型矩阵,它里面的每个数字都对应着输入数据中某个位置的数值在特征提取过程中的一个系数。 当卷积核在输入数据上滑动时&am…

多线程和线程同步复习

多线程和线程同步复习 进程线程区别创建线程线程退出线程回收全局写法传参写法 线程分离线程同步同步方式 互斥锁互斥锁进行线程同步 死锁读写锁api细说读写锁进行线程同步 条件变量生产者消费者案例问题解答加强版生产者消费者 总结信号量信号量实现生产者消费者同步-->一个…

FlinkPipelineComposer 详解

FlinkPipelineComposer 详解 原文 背景 在flink-cdc 3.0中引入了pipeline机制,提供了除Datastream api/flink sql以外的一种方式定义flink 任务 通过提供一个yaml文件,描述source sink transform等主要信息 由FlinkPipelineComposer解析&#xff0c…

Zustand浅学习

道阻且长,行而不辍,未来可期 之前只是会使用zustand,也没仔细看过zustand的文档,前段时间一个合约朋友问我前端的zustand怎么用,啊,这,是那个笑起来明媚的不像话的帅哥问我问题诶,那我得认真一下…

海量数据迁移:Elasticsearch到OpenSearch的无缝迁移策略与实践

文章目录 一.迁移背景二.迁移分析三.方案制定3.1 使用工具迁移3.2 脚本迁移 四.方案建议 一.迁移背景 目前有两个es集群,版本为5.2.2和7.16.0,总数据量为700T。迁移过程需要不停服务迁移&#…

TypeScript:现代 JavaScript 的超级集

目录 为什么使用 TypeScript? TypeScript 的基本特性 TypeScript 的优势 TypeScript项目实战 简单的命令行任务管理系统 TypeScript 是由微软开发的一个开源编程语言,它是 JavaScript 的一个严格超集。TypeScript 的核心特性是静态类型检查,使得开发者可以在编写代码时…

‌MySQL 5.7和8.0版本在多个方面存在显著区别,主要包括性能优化、新特性引入以及安全性提升

性能优化‌ ‌编码器和解码器‌:MySQL 8.0引入了更快和更高效的编码器和解码器,支持压缩、加密、并发等方面的优化,而MySQL 5.7的编码器和解码器相对较慢。‌认证方式‌:MySQL 8.0默认使用caching_sha2_password作为登录认证插件&…

【贪心算法】贪心算法三

贪心算法三 1.买卖股票的最佳时机2.买卖股票的最佳时机 II3.K 次取反后最大化的数组和4.按身高排序5.优势洗牌(田忌赛马) 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励&#…

QtLua

描述 QtLua 库旨在使用 Lua 脚本语言使 Qt4/Qt5 应用程序可编写脚本。它是 QtScript 模块的替代品。 QtLua 不会为 Qt 生成或使用生成的绑定代码。相反,它提供了有用的 C 包装器类,使 C 和 lua 对象都可以从 lua 和 C 访问。它利用 Qt 元对象系统将 QOb…

Devops业务价值流:敏捷测试最佳实践

在迭代增量开发模式下,我们强调按照用户故事的优先级进行软件小功能的频繁交付。由于迭代周期紧凑,测试与开发活动往往并行进行,测试时间相对有限。为确保在这种快节奏的开发环境中依然能够保持产品质量,我们特制定以下测试阶段的…