Scikit-learn 简单介绍入门和常用API汇总 - 教程

news/2025/9/18 17:45:31/文章来源:https://www.cnblogs.com/yxysuanfa/p/19099276

Scikit-learn 简单介绍和入门示例

1. 概述

Scikit-learn(简称 sklearn)是 Python 生态中最流行的 机器学习库,主要用于传统 ML 任务。它基于 NumPy、SciPy 和 Matplotlib,提供了统一的 API,涵盖 数据预处理、特征工程、模型训练、评估与调优

定位:


2. 设计思想

Scikit-learn 遵循 模块化、统一接口、组合化 的原则。
主要接口规范:

  • fit(X, y=None):训练模型 / 学习参数
  • predict(X):预测
  • transform(X):数据变换(特征工程、降维)
  • fit_transform(X):训练并变换(常用于预处理)
  • score(X, y):评估模型
  • get_params() / set_params():超参数管理

统一接口的好处:无论是 SVM、决策树还是 PCA,调用方式都基本相同。


3. 模块全景

3.1 数据预处理

3.2 监督学习

  • 分类

    • 线性:LogisticRegression, SGDClassifier
    • 树模型:DecisionTreeClassifier, RandomForestClassifier, GradientBoostingClassifier
    • SVM:SVC
    • 朴素贝叶斯:GaussianNB
  • 回归

    • 线性:LinearRegression, Ridge, Lasso
    • 树模型:DecisionTreeRegressor, RandomForestRegressor
    • 支持向量回归:SVR

3.3 非监督学习

3.4 模型选择与评估

  • 交叉验证cross_val_score, KFold

  • 调参GridSearchCV, RandomizedSearchCV

  • 指标

    • 分类:准确率、精确率、召回率、F1、ROC-AUC
    • 回归:均方误差 (MSE)、R²

3.5 工程工具


4. 工作流程

一个典型的 Scikit-learn 项目通常包括:

  1. 数据准备:加载、清洗、划分(train_test_split
  2. 预处理:标准化、特征选择、降维
  3. 建模:选择分类/回归/聚类算法
  4. 评估:使用交叉验证与指标函数
  5. 调参GridSearchCVRandomizedSearchCV
  6. 部署:模型持久化 (joblib.dump/load)

5. 简单示例

鸢尾花分类 为例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
# 数据加载
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 构建流水线:标准化 + SVM 分类器
pipeline = Pipeline([
("scaler", StandardScaler()),
("clf", SVC(kernel="linear"))
])
# 训练与预测
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
# 评估
print("Accuracy:", accuracy_score(y_test, y_pred))

6. 优势与不足

优势

  • 统一 API,学习成本低
  • 覆盖大多数传统 ML 算法
  • 性能不错,底层部分用 Cython/C++ 优化
  • 文档和社区生态完善

不足

  • 不支持大规模分布式数据
  • 不包含深度学习模型
  • 在线学习能力有限(部分算法支持 partial_fit

7. 适用场景


8.入门示例yanz演示

阶段一:入门(快速上手)

目标:掌握 Scikit-learn 的基本 API,能完成简单的分类/回归任务。

学习要点
  • 熟悉 fit / predict / transform 接口
  • 使用 train_test_split 划分数据
  • 调用常见模型:LinearRegression, LogisticRegression, SVC, KNeighborsClassifier
  • 使用 accuracy_scoremean_squared_error 等指标
综合示例:鸢尾花分类
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# 数据集
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 模型
clf = LogisticRegression(max_iter=200)
clf.fit(X_train, y_train)
# 预测与评估
y_pred = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

阶段二:提升(实战技巧)

目标:掌握 Pipeline、特征工程、模型调参,能在真实数据集上完成较复杂的任务。

学习要点
综合示例:房价预测(加州房价数据集)
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error
# 数据集
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 流水线:标准化 + 随机森林
pipe = Pipeline([
("scaler", StandardScaler()),
("rf", RandomForestRegressor(random_state=42))
])
# 超参数搜索
param_grid = {
"rf__n_estimators": [50, 100],
"rf__max_depth": [10, 20, None]
}
grid = GridSearchCV(pipe, param_grid, cv=3, scoring="neg_mean_squared_error")
grid.fit(X_train, y_train)
# 预测与评估
y_pred = grid.predict(X_test)
print("Best params:", grid.best_params_)
print("MSE:", mean_squared_error(y_test, y_pred))

阶段三:高级(综合应用)

目标:能 系统性构建机器学习项目,包括数据预处理、特征选择、模型集成、结果可视化与解释。

学习要点
综合示例:信用卡客户违约预测(分类任务)
import joblib
import pandas as pd
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
# 假设已加载信用卡客户数据 (X: 特征, y: 是否违约)
data = pd.read_csv("credit_card.csv")
X = data.drop("default", axis=1)
y = data["default"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# 流水线:预处理 + 特征选择 + 集成学习
pipe = Pipeline([
("scaler", StandardScaler()),
("select", SelectKBest(score_func=f_classif, k=10)),
("clf", VotingClassifier(estimators=[
("rf", RandomForestClassifier(random_state=42)),
("gb", GradientBoostingClassifier(random_state=42))
], voting="soft"))
])
# 随机搜索调参
param_dist = {
"clf__rf__n_estimators": [100, 200],
"clf__gb__learning_rate": [0.05, 0.1]
}
search = RandomizedSearchCV(pipe, param_dist, cv=3, scoring="f1", n_iter=4, random_state=42)
search.fit(X_train, y_train)
# 预测与评估
y_pred = search.predict(X_test)
print("Best params:", search.best_params_)
print(classification_report(y_test, y_pred))
# 模型保存
joblib.dump(search.best_estimator_, "credit_model.pkl")

总结

  • 入门:掌握 API + 简单模型(Logistic/SVC/LinearRegression)
  • 提升:学会 Pipeline、特征工程、调参(GridSearchCV/RandomizedSearchCV)
  • 高级:综合应用,能做完整 ML 项目(特征选择 + 集成学习 + 模型解释 + 部署)

Scikit-learn 常用 API 汇总

1️、数据集工具

from sklearn import datasets
from sklearn.model_selection import train_test_split
  • datasets.load_iris():鸢尾花分类
  • datasets.load_digits():手写数字识别
  • datasets.fetch_california_housing():加州房价数据
  • datasets.make_classification():生成分类数据
  • datasets.make_regression():生成回归数据
  • train_test_split(X, y, test_size=0.2, random_state=42):划分训练/测试集

2️、数据预处理

from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
  • 标准化StandardScaler().fit_transform(X)
  • 归一化MinMaxScaler().fit_transform(X)
  • 正则化Normalizer().fit_transform(X)
  • 独热编码OneHotEncoder().fit_transform(X)
  • 标签编码LabelEncoder().fit_transform(y)

3️、特征工程

from sklearn.feature_selection import SelectKBest, f_classif, RFE
from sklearn.decomposition import PCA
  • 特征选择SelectKBest(score_func=f_classif, k=10).fit_transform(X, y)
  • 递归特征消除RFE(estimator, n_features_to_select=5).fit_transform(X, y)
  • 主成分分析PCA(n_components=2).fit_transform(X)

4️、常用模型

分类

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
  • 逻辑回归:LogisticRegression()
  • 支持向量机:SVC(kernel="linear")
  • 决策树:DecisionTreeClassifier()
  • 随机森林:RandomForestClassifier(n_estimators=100)
  • 梯度提升:GradientBoostingClassifier()
  • 朴素贝叶斯:GaussianNB()
  • 最近邻:KNeighborsClassifier(n_neighbors=5)

回归

from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
  • 线性回归:LinearRegression()
  • 岭回归:Ridge(alpha=1.0)
  • Lasso 回归:Lasso(alpha=0.1)
  • 支持向量回归:SVR(kernel="rbf")
  • 随机森林回归:RandomForestRegressor(n_estimators=100)

聚类 / 非监督学习

from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
  • KMeans 聚类:KMeans(n_clusters=3)
  • DBSCAN:DBSCAN(eps=0.5, min_samples=5)
  • 层次聚类:AgglomerativeClustering(n_clusters=3)
  • 高斯混合模型:GaussianMixture(n_components=3)

5️、模型评估

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import mean_squared_error, r2_score
  • 分类指标

    • accuracy_score(y_true, y_pred)
    • precision_score(y_true, y_pred)
    • recall_score(y_true, y_pred)
    • f1_score(y_true, y_pred)
    • roc_auc_score(y_true, y_prob)
    • classification_report(y_true, y_pred)
  • 回归指标

    • mean_squared_error(y_true, y_pred)
    • r2_score(y_true, y_pred)
  • 混淆矩阵confusion_matrix(y_true, y_pred)


6️、模型选择与调参

from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
  • 交叉验证:cross_val_score(model, X, y, cv=5)

  • K 折:KFold(n_splits=5)

  • 分层 K 折:StratifiedKFold(n_splits=5)

  • 网格搜索:

    GridSearchCV(estimator, param_grid, cv=3, scoring="accuracy")
  • 随机搜索:

    RandomizedSearchCV(estimator, param_distributions, cv=3, n_iter=10)

7️、工程工具

from sklearn.pipeline import Pipeline
import joblib
  • 流水线

    pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", SVC())
    ])
  • 模型保存joblib.dump(model, "model.pkl")

  • 模型加载model = joblib.load("model.pkl")


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/907306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AUTOSAR网络管理

汽车行业的网络管理一般有两种,一种是AutoSar另一种是OSEK,为啥汽车要网络管理,其实是为了降低车辆电池消耗,当车辆不工作时所有总线上的ECU通讯模块或整个ECU处于低功耗状态。网络管理一般用在电池供电的ECU,比如…

写用例注意点

写用例注意点: 1、测试标题 明确测试点 2、写用例的前几条用例都是主要场景的用例先写 微信个人能发微信红包 微信群发能发拼手气红包 微信群发能发拼手气红包 微信群发能发专属气红包 3、测试标题尽量写内容不要写案…

12 路低延迟推流!米尔 RK3576 赋能智能安防 360 环视

在智慧城市建设加速与社区安防需求升级的双重驱动下,“360 无死角监控 + 实时响应” 已成为安防领域的核心诉求。传统监控方案常受限于摄像头接入数量不足、编解码效率低、推流延迟高三大痛点,难以覆盖社区、园区等复…

Xilinx DDR3仿真 DBG

Xilinx DDR3仿真 DBG emmmm,其实这个错误不用去管,直接运行也不影响仿真的进行。 https://blog.csdn.net/qq_42959171/article/details/139726943

A公司一面:类加载的过程是怎么样的? 双亲委派的优点和缺点? 产生fullGC的情况有哪些? spring的动态代理有哪些?区别是什么? 如何排查CPU使用率过高?

A公司一面:类加载的过程是怎么样的? 双亲委派的优点和缺点? 产生fullGC的情况有哪些? spring的动态代理有哪些?区别是什么? 如何排查CPU使用率过高?摘要 A公司的面经JVM的类加载的过程是怎么样的? 双亲委派模型…

redis-hash类型参数基本命令

redis-hash类型参数基本命令redis存储数据的value可以是hash类型的,也称之为hash表,字典等。hash表就是一个map,由key-value组成。 我们把hash表的key称为field,值称为value。注意:redis的hash表的field和value都…

Alternating Subsequence

CF1343C Alternating Subsequence 题目描述 回忆一下,如果序列 \(b\) 是序列 \(a\) 的一个子序列,那么 \(b\) 可以通过从 \(a\) 中删除零个或多个元素(不改变剩余元素的顺序)得到。例如,如果 \(a=[1, 2, 1, 3, 1,…

白鲸开源“创客北京2025”再摘殊荣,聚焦Agentic AI时代数据基础设施建设

近日,“创客北京2025”创新创业大赛海淀区级赛圆满落幕,经过最终比拼,北京白鲸开源科技有限公司凭借 「Agentic AI时代下的数据基础设施平台」(白鲸数据集成调度平台/WhaleStudio) 脱颖而出,荣获企业组二等奖。近…

深入解析:大模型-Transformer原理与实战篇

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

python基础-公共操作

数据类型间公共支持的操作符运算: + ,* ,in , not in‘+’ :支持的容器类型 字符串、列表、元组 ,实现两个容器的合并‘*’ : 支持的容器类型 字符串、列表、元组, 赋值容器内容str1 = q str1* 5 =qq…

天翼云第九代弹性云主机:让每一次计算快人一步

随着数字化转型进程不断深入,云计算已成为推动千行百业智能化升级的核心引擎。弹性计算服务凭借其灵活扩展、高可用和高性能等特点,正持续为企业提供关键基础设施支持。面对日益复杂的业务场景与持续增长的计算需求,…

若依(RuoYi)框架漏洞总结

0x01 特征 绿若依 icon_hash=”706913071”蓝若依 icon_hash=” -1231872293”0x02 漏洞 弱口令 用户:admin ruoyi druid 密码:123456 admin druid admin123 admin888若依前台默认shiro key命令执行漏洞…

第一次个人项目作业_论文查重

第一次项目作业这个作业属于哪个课程 https://edu.cnblogs.com/campus/gdgy/Class34Grade23ComputerScience这个作业要求在哪里 https://edu.cnblogs.com/campus/gdgy/Class34Grade23ComputerScience/homework/13477这…

2025年版《中科院期刊分区表》与2023年版对比表,附名单可直接查阅

2025年版《中科院期刊分区表》与2023年版相比,主要有以下几个变化‌: ‌1、发布时间提前‌:2025年版分区表从12月提前至3月发布,与投稿周期同步,学者可以尽早锁定期刊最新分区,避免“投稿后降区”的风险‌。 ‌2…

对马岛之魂

护身符 稻荷神护身符----增加资源的获取 aa

2019年双因素认证最佳实践指南

本文深入探讨2019年双因素认证的正确实现方式,对比TOTP与WebAuthn技术优劣,分析用户行为模式,并提供实际部署建议,帮助开发者构建更安全的认证系统。2019年正确实现双因素认证 - Trail of Bits博客 自3月起,Trail…

oracle 删除重复数据

delete hpas_index_data_source swhere s.id in (select idFROM (SELECT t1.*,ROW_NUMBER() OVER(PARTITION BY t1.indexid, t1.doctor_id, t1.start_date, t1.vals ORDER BY t1.rowid) as rnFROM hpas_index_data_sou…

Account Kit(华为账号服务)再进化,开发者接入效率飙升!

Hi 各位开发者朋友~👋 为持续优化开发体验,提升集成效率,Account Kit接入体验再升级,助力构建更流畅、更安全的登录体验,让开发效率火力全开!😎 【体验升级】华为账号相关权益申请入口统一迁移至AGC华为账号…

Codeforces Round 1051 (Div. 2) D题启发(DP

题目简述 需要找到所有最长单调递减子序列长度不超过2的子列个数,做法是dp。 状态记录 我们不必理会题解中乱七八糟的定义,只需要知道他事实上就是模拟了当我们拿到一个已知数列时贪心的过程,把我们计算最长单调递减…

[踩坑劝退]批量生成 grafana dashboard 的技术

[踩坑劝退]批量生成 grafana dashboard 的技术作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢!cnblogs博客 zhihu Github 公众号:一本正经的瞎扯最近想要一个功能:把 VictoriaMetrics 采集的数据,自动变…