9.XGBoost

本教程是机器学习系列的一部分。 在此步骤中,您将学习如何使用功能强大的xgboost库构建和优化模型。

What is XGBoost

XGBoost是处理标准表格数据的领先模型(您在Pandas DataFrames中存储的数据类型,而不是像图像和视频这样的更奇特的数据类型)。 XGBoost模型在许多Kaggle比赛中占据主导地位。

为了达到峰值精度,XGBoost模型比Random Forest等技术需要更多的知识和模型调整。 在本教程之后,你将能够

  •      遵循XGBoost的完整建模工作流程
  •      微调XGBoost模型以获得最佳性能

XGBoost是Gradient Boosted决策树算法的一种实现(scikit-learn有另一个版本的算法,但XGBoost有一些技术优势。)什么是Gradient Boosted决策树? 我们将通过一个图表。

xgboost image

我们经历了重复构建新模型的循环,并将它们组合成一个整体模型。 我们通过计算数据集中每个观察的误差来开始循环。 然后我们构建一个新模型来预测。 我们将此误差预测模型的预测添加到“模型集合”中。

为了进行预测,我们添加了以前所有模型的预测。 我们可以使用这些预测来计算新误差,构建下一个模型,并将其添加到整体中。

那个周期之外还有一件。 我们需要一些基础预测来开始循环。 在实践中,最初的预测可能非常幼稚。 即使它的预测非常不准确,随后对整体的添加将解决这些错误。

这个过程可能听起来很复杂,但使用它的代码很简单。 我们将在下面的模型调整部分填写一些额外的解释性细节。

Example

我们将从预先加载到train_X,test_X,train_y,test_x的数据开始。

[1]

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Imputerdata = pd.read_csv('../input/train.csv')
data.dropna(axis=0, subset=['SalePrice'], inplace=True)
y = data.SalePrice
X = data.drop(['SalePrice'], axis=1).select_dtypes(exclude=['object'])
train_X, test_X, train_y, test_y = train_test_split(X.as_matrix(), y.as_matrix(), test_size=0.25)my_imputer = Imputer()
train_X = my_imputer.fit_transform(train_X)
test_X = my_imputer.transform(test_X)
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:9: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.if __name__ == '__main__':
/opt/conda/lib/python3.6/site-packages/sklearn/utils/deprecation.py:58: DeprecationWarning: Class Imputer is deprecated; Imputer was deprecated in version 0.20 and will be removed in 0.22. Import impute.SimpleImputer from sklearn instead.warnings.warn(msg, category=DeprecationWarning)

我们像在scikit-learn建立模型和拟合。

【2】

from xgboost import XGBRegressormy_model = XGBRegressor()
# Add silent=True to avoid printing out updates with each cycle
my_model.fit(train_X, train_y, verbose=False)

我们同样评估模型并像在scikit-learn中那样进行预测。

【3】

# make predictions
predictions = my_model.predict(test_X)from sklearn.metrics import mean_absolute_error
print("Mean Absolute Error : " + str(mean_absolute_error(predictions, test_y)))
Mean Absolute Error : 17543.750299657535

Model Tuning

XGBoost有一些参数可以显着影响您的模型的准确性和训练速度。您应该了解的第一个参数是:
n_estimators 和 early_stopping_rounds

n_estimators指定完成上述建模周期的次数。

在欠拟合vs过拟合图中,n_estimators将您向右移动。值太低会导致欠拟合,这对训练数据和新数据的预测都是不准确的。太大的值会导致过度拟合,这时对训练数据的预测准确,但对新数据的预测不准确(这是我们关心的)。您可以试验数据集以找到理想值。典型值范围从100到1000,但这很大程度上取决于下面讨论的学习率。

参数early_stopping_rounds提供了一种自动查找理想值的方法。过早停止会导致模型在验证分数停止改善时停止迭代,即使我们不是n_estimators的硬停止。为n_estimators设置一个高值然后使用early_stopping_rounds找到停止迭代的最佳时间是明智的。

由于随机机会有时会导致单轮,其中验证分数没有提高,因此您需要指定一个数字,以确定在停止前允许多少轮直线恶化。 early_stopping_rounds = 5是一个合理的值。因此,我们在连续5轮恶化的验证分数后停止。

以下是适合early_stopping的代码:

【4】

my_model = XGBRegressor(n_estimators=1000)
my_model.fit(train_X, train_y, early_stopping_rounds=5, eval_set=[(test_X, test_y)], verbose=False)
XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,max_depth=3, min_child_weight=1, missing=None, n_estimators=1000,n_jobs=1, nthread=None, objective='reg:linear', random_state=0,reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,silent=True, subsample=1)

使用early_stopping_rounds时,您需要留出一些数据来检查要使用的轮数。 如果您以后想要使用所有数据拟合模型,请将n_estimators设置为在过早停止运行时发现的最佳值。

learning_rate

对于更好的XGBoost模型,这是一个微妙但重要的技巧:

我们不是通过简单地将每个组件模型中的预测相加来获得预测,而是将每个模型的预测乘以一个小数字,然后再添加它们。这意味着我们添加到集合中的每个树都会减少我们的预测。 在实践中,这降低了模型过度拟合的倾向。

因此,您可以使用更高的n_estimators值而不会过度拟合。 如果使用过早停止,将自动设置适当数量的树。

一般来说,较小的学习率(以及大量的估算器)将产生更准确的XGBoost模型,尽管它也会使模型更长时间进行训练,因为它在整个循环中进行了更多的迭代。

修改上面的示例以包含学习率将产生以下代码:

【5】

my_model = XGBRegressor(n_estimators=1000, learning_rate=0.05)
my_model.fit(train_X, train_y, early_stopping_rounds=5, eval_set=[(test_X, test_y)], verbose=False)
XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,colsample_bytree=1, gamma=0, learning_rate=0.05, max_delta_step=0,max_depth=3, min_child_weight=1, missing=None, n_estimators=1000,n_jobs=1, nthread=None, objective='reg:linear', random_state=0,reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,silent=True, subsample=1)

n_jobs

在较大数据集上需要考虑运行时间,您可以使用并行性来更快地构建模型。通常将参数n_jobs设置为等于计算机上的核心数。在较小的数据集上,这无济于事。

由此产生的模型将不会更好,因此对于拟合时间的微优化通常只是分散注意力。但是,它在大型数据集中非常有用,否则您将在fit命令期间等待很长时间。

XGBoost有许多其他参数,但这些参数将消耗更长时间帮助您微调XGBoost模型以获得最佳性能。

Conclusion


XGBoost是目前用于在传统数据(也称为表格或结构数据)上构建精确模型的主要算法。去应用它来改进你的模型!

Your Turn


使用XGBoost转换你的模型。
使用提前停止为n_estimators找到一个好的值。然后使用所有训练数据和n_estimators的值重新估计模型。
完成后,返回学习机器学习,继续改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/440326.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

*【HDU - 5711】Ingress(tsp旅行商问题,优先队列贪心,状压dp,floyd最短路,图论)

题干: Brickgao, who profited from your accurate calculating last year, made a great deal of money by moving bricks. Now he became gay shy fool again and recently he bought an iphone and was deeply addicted into a cellphone game called Ingress. …

ajax get请求成功,成功()函数的AJAX GET请求

后不叫我有一个jQuery的AJAX脚本像下面:成功()函数的AJAX GET请求function FillCity() {var stateID $("#ddlState").val();$.ajax({url: Url.Action("Employee", "Index"),type: "GET",dataType: "json",data:…

《TCP/IP详解》学习笔记(二):数据链路层

数据链路层有三个目的: 为IP模块发送和 接收IP数据报。为ARP模块发送ARP请求和接收ARP应答。为RARP发送RARP请 求和接收RARP应答ip大家都听说过。至于ARP和RARP,ARP叫做地址解析协议,是用IP地址换MAC地址的一种协议,而RARP则叫…

【POJ - 2762】Going from u to v or from v to u?(Tarjan缩点,树形dp 或 拓扑排序,欧拉图相关)

题干: In order to make their sons brave, Jiajia and Wind take them to a big cave. The cave has n rooms, and one-way corridors connecting some rooms. Each time, Wind choose two rooms x and y, and ask one of their little sons go from one to the o…

《TCP/IP详解》学习笔记(三):IP协议、ARP协议

把这三个协议放到一起学习是因为这三个协议处于同一层,ARP 协议用来找到目标主机的 Ethernet 网卡 Mac 地址,IP 则承载要发 送的消息。数据链路层可以从 ARP 得到数据的传送信息,而从 IP 得到要传输的数据信息。 IP 协议 IP 协议是 TCP/IP 协议的核心,所有的 TCP,UDP,IMCP,IGCP…

光与夜之恋服务器维护中,光与夜之恋7月16日停服维护说明 维护详情一览

光与夜之恋7月16日停服维护说明维护详情一览。光与夜之恋7月16日停服维护更新了哪些内容?我们去了解一下。【7月16日停服维护说明】亲爱的设计师:为了给设计师们提供更好的游戏体验,光启市将于7月16日(周五)00:00进行预计5小时的停服维护,可…

10.Partial Dependence Plots

本教程是ML系列的一部分。 在此步骤中,您将学习如何创建和解释部分依赖图,这是从模型中提取洞察力的最有价值的方法之一。 What Are Partial Dependence Plots 有人抱怨机器学习模型是黑盒子。这些人会争辩说我们无法看到这些模型如何处理任何给定的数据…

springboot监控服务器信息,面试官:聊一聊SpringBoot服务监控机制

目录前言任何一个服务如果没有监控,那就是两眼一抹黑,无法知道当前服务的运行情况,也就无法对可能出现的异常状况进行很好的处理,所以对任意一个服务来说,监控都是必不可少的。就目前而言,大部分微服务应用…

0.《Apollo自动驾驶工程师技能图谱》

【新年礼物】开工第一天,送你一份自动驾驶工程师技能图谱! 布道团队 Apollo开发者社区 1月 2日 AI时代到来,人才的缺乏是阻碍行业大步发展的主要因素之一。Apollo平台发布以来,我们接触到非常多的开发者他们并不是专业自动驾驶领…

【HDU - 1116】【POJ - 1386】Play on Words(判断半欧拉图,欧拉通路)

题干: Some of the secret doors contain a very interesting word puzzle. The team of archaeologists has to solve it to open that doors. Because there is no other way to open the doors, the puzzle is very important for us. There is a large number…

11.Pipelines

本教程是ML系列的一部分。 在此步骤中,您将了解如何以及为何使用管道清理建模代码。 What Are Pipelines 管道是保持数据处理和建模代码有序的简单方法。 具体来说,管道捆绑了预处理和建模步骤,因此您可以像使用单个包一样使用整个捆绑包。…

ubuntu服务器创建共享文件夹,Ubuntu samba安装创建共享目录及使用

Ubuntu samba更新了很多版本更新,我本人认为Ubuntu samba是很好使的文件系统,在此向大家推荐。如今技术不断更新,各种使用文件都已经淘汰。我认为还是有很不错的如Ubuntu samba值得大家来运用。一. Ubuntu samba的安装:sudo apt-get insall s…

【POJ - 2337】Catenyms(欧拉图相关,欧拉通路输出路径,tricks)

题干: A catenym is a pair of words separated by a period such that the last letter of the first word is the same as the last letter of the second. For example, the following are catenyms: dog.gophergopher.ratrat.tigeraloha.alohaarachnid.dog A…

12.Cross-Validation

本教程是ML系列的一部分。 在此步骤中,您将学习如何使用交叉验证来更好地衡量模型性能。 What is Cross Validation 机器学习是一个迭代过程。 您将面临关于要使用的预测变量,要使用的模型类型,提供这些模型的参数等的选择。我们通过测量各…

服务器不显示u盘,服务器不读u盘启动

服务器不读u盘启动 内容精选换一换介绍使用Atlas 200 DK前需要准备的配件及开发服务器。Atlas 200 DK使用需要用户提前自购如表1所示配件。准备一个操作系统为Ubuntu X86架构的服务器,用途如下:为Atlas 200 DK制作SD卡启动盘。读卡器或者Atlas 200 DK会通…

【FZU - 2039】Pets(二分图匹配,水题)

题干: 有n个人,m条狗,然后会给出有一些人不喜欢一些狗就不会购买,问最多能卖多少狗。。 Input There is a single integer T in the first line of the test data indicating that there are T(T≤100) test cases. In the fir…

Leetcode刷题实战(1):Two Sum

Leetcode不需要过多介绍了,今天一边开始刷题一边开始总结: 官网链接如下:https://leetcode.com/problemset/all/ 题1描述: 1Two Sum38.80%Easy Given an array of integers, return indices of the two numbers such that they…

信息服务器为什么选择在贵州,为啥云服务器在贵州

为啥云服务器在贵州 内容精选换一换当用户已在ECS服务购买GPU加速型云服务器,并且想在该云服务器上运行应用时,可以通过纳管的方式将该云服务器纳入VR云渲游平台管理。登录控制台,在服务列表中选择“计算 > VR云渲游平台”。在左侧导航栏&…

LeetCode刷题实战(2):Add Two Numbers

题2描述: 2Add Two Numbers29.10%Medium You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order and each of their nodes contain a single digit. Add the two numbers and return it as a…

【BZOJ - 1305】dance跳舞(拆点网络流,建图,最大流,残留网络上跑最大流)

题干: 一次舞会有n个男孩和n个女孩。每首曲子开始时,所有男孩和女孩恰好配成n对跳交谊舞。每个男孩都不会和同一个女孩跳两首(或更多)舞曲。有一些男孩女孩相互喜欢,而其他相互不喜欢(不会“单向喜欢”&am…