杰瑞科技汇

Python LightGBM实例如何快速上手?

目录

  1. LightGBM 简介
  2. 安装
  3. 核心概念
    • Dataset 对象
    • Booster 对象
  4. 完整实例:从数据加载到模型预测
    • 步骤 1:导入库
    • 步骤 2:准备数据(使用 scikit-learn 的鸢尾花数据集)
    • 步骤 3:创建 LightGBM 数据集
    • 步骤 4:定义模型参数
    • 步骤 5:训练模型
    • 步骤 6:进行预测
    • 步骤 7:评估模型
  5. 进阶用法
    • 回归任务示例
    • 使用 scikit-learn API (更推荐)
    • 交叉验证
    • 特征重要性
    • 超参数调优 (使用 optuna)
  6. 模型保存与加载

LightGBM 简介

LightGBM (Light Gradient Boosting Machine) 是由微软开发的开源、分布式、高性能的梯度提升框架,它基于决策树算法,并且在速度和内存效率上进行了优化,非常适合处理大规模数据。

Python LightGBM实例如何快速上手?-图1
(图片来源网络,侵删)

主要优点:

  • 速度快:采用基于梯度的单边采样 (GOSS) 和互斥特征捆绑 (EFB) 等技术,训练速度远超传统 GBDT。
  • 内存占用低:EFB 技术大大减少了特征的数量,从而降低了内存消耗。
  • 精度高:在多个公开数据集上,LightGBM 的精度通常优于其他 boosting 算法。
  • 支持并行学习:支持特征维度和数据维度的并行化。
  • 支持 GPU 加速

主要缺点:

  • 容易过拟合:由于学习速度快,在数据量小或参数设置不当的情况下,容易出现过拟合。
  • 对敏感参数敏感:如 num_leaves, max_depth 等参数需要仔细调整。

安装

您可以使用 pip 来安装 LightGBM,如果您的系统支持 CUDA,建议安装 GPU 版本以获得极致的训练速度。

# 安装 CPU 版本
pip install lightgbm
# 安装 GPU 版本 (需要先安装 CUDA 和 cuDNN)
pip install lightgbm --install-option=--gpu

核心概念

LightGBM 主要有两个核心对象:DatasetBooster

Python LightGBM实例如何快速上手?-图2
(图片来源网络,侵删)

Dataset 对象

Dataset 是 LightGBM 专门用于存储高效训练数据的内部数据结构,它比直接使用 NumPy 数组或 Pandas DataFrame 更节省内存,并且训练速度更快。

创建 Dataset 的基本语法:

import lightgbm as lgb
# lgb.Dataset(data, label, ...)
train_data = lgb.Dataset(X_train, label=y_train)

Booster 对象

Booster 是训练好的模型对象,它包含了模型的全部信息,如树的结构、叶子节点权重等,我们不会直接创建 Booster,而是通过调用 train() 函数来返回一个训练好的 Booster 对象。


完整实例:从数据加载到模型预测

这个示例将演示如何使用 LightGBM 的原生 API(lgb.train)来解决一个分类问题

Python LightGBM实例如何快速上手?-图3
(图片来源网络,侵删)

步骤 1:导入库

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import lightgbm as lgb

步骤 2:准备数据

我们使用 scikit-learn 自带的鸢尾花数据集,这是一个多分类问题,有3个类别。

# 加载数据
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"训练集形状: {X_train.shape}")
print(f"测试集形状: {X_test.shape}")

步骤 3:创建 LightGBM 数据集

将训练数据转换为 LightGBM 的 Dataset 格式。

# 创建 LightGBM 数据集
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train) # 用于评估集

步骤 4:定义模型参数

参数是 LightGBM 的核心,直接影响模型性能,这里我们先设置一些基础参数。

params = {
    'objective': 'multiclass',  # 多分类任务
    'num_class': 3,             # 类别数量
    'metric': 'multi_logloss',  # 评估指标
    'boosting_type': 'gbdt',    # 提升算法类型
    'num_leaves': 31,           # 一棵树上的叶子节点数
    'learning_rate': 0.05,      # 学习率
    'feature_fraction': 0.9,    # 每次迭代随机选择90%的特征
    'bagging_fraction': 0.8,    # 每次迭代随机选择80%的数据
    'bagging_freq': 5,          # 每5次迭代进行一次bagging
    'verbose': -1,              # 不打印训练过程中的信息
    'random_state': 42          # 随机种子,保证结果可复现
}

步骤 5:训练模型

使用 lgb.train() 函数来训练模型,我们可以在这里传入 valid_sets 来监控模型在验证集上的表现。

# 训练模型
print("开始训练模型...")
model = lgb.train(
    params,
    lgb_train,
    num_boost_round=100,       # 迭代次数
    valid_sets=[lgb_train, lgb_eval], # 训练集和验证集
    callbacks=[lgb.early_stopping(stopping_rounds=10), lgb.log_evaluation(10)]
)
  • num_boost_round: 模型迭代(即生成多少棵树)的次数。
  • valid_sets: 用于在训练过程中评估模型性能的数据集列表。
  • callbacks: 回调函数列表。
    • lgb.early_stopping: 如果验证集的指标在指定的 stopping_rounds 次迭代内没有提升,则提前停止训练,防止过拟合。
    • lgb.log_evaluation: 每隔指定的迭代次数打印一次评估指标。

步骤 6:进行预测

训练完成后,我们可以使用 predict() 方法对测试集进行预测。

# 进行预测
y_pred_prob = model.predict(X_test) # 输出的是每个样本属于各个类别的概率
print(f"预测概率示例 (前5个样本):\n{y_pred_prob[:5]}")
# 将概率转换为类别标签
y_pred = np.argmax(y_pred_prob, axis=1)
print(f"预测类别示例 (前5个样本):\n{y_pred[:5]}")

步骤 7:评估模型

使用 scikit-learn 的指标来评估模型在测试集上的表现。

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型准确率: {accuracy:.4f}")
# 打印详细的分类报告
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

进阶用法

回归任务示例

回归任务与分类任务非常相似,只需要修改 objectivemetric 参数即可。

from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
# 加载数据
housing = fetch_california_housing()
X_reg, y_reg = housing.data, housing.target
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)
# 定义回归参数
params_reg = {
    'objective': 'regression',  # 回归任务
    'metric': 'rmse',          # 评估指标为均方根误差
    'boosting_type': 'gbdt',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1,
    'random_state': 42
}
# 创建数据集
lgb_train_reg = lgb.Dataset(X_train_reg, y_train_reg)
lgb_eval_reg = lgb.Dataset(X_test_reg, y_test_reg, reference=lgb_train_reg)
# 训练模型
model_reg = lgb.train(
    params_reg,
    lgb_train_reg,
    num_boost_round=100,
    valid_sets=[lgb_eval_reg],
    callbacks=[lgb.early_stopping(10), lgb.log_evaluation(0)]
)
# 预测和评估
y_pred_reg = model_reg.predict(X_test_reg)
rmse = np.sqrt(mean_squared_error(y_test_reg, y_pred_reg))
print(f"回归模型 RMSE: {rmse:.4f}")

使用 scikit-learn API (更推荐)

对于习惯了 scikit-learn 生态的用户,LightGBM 提供了 LGBMClassifierLGBMRegressor 类,它们与 scikit-learn 的 fitpredict 接口完全兼容,这种方式更易于集成到现有的机器学习流程中。

from lightgbm import LGBMClassifier
# 使用 scikit-learn API
lgbm_sk = LGBMClassifier(
    objective='multiclass',
    num_class=3,
    num_leaves=31,
    learning_rate=0.05,
    n_estimators=100, # 对应 num_boost_round
    random_state=42,
    verbose=-1
)
# 训练
lgbm_sk.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(10)])
# 预测和评估
y_pred_sk = lgbm_sk.predict(X_test)
accuracy_sk = accuracy_score(y_test, y_pred_sk)
print(f"\n使用 scikit-learn API 的准确率: {accuracy_sk:.4f}")

交叉验证

LightGBM 内置了高效的交叉验证功能,无需手动划分数据。

# 使用与分类任务相同的 params
cv_results = lgb.cv(
    params,
    lgb_train,
    num_boost_round=100,
    nfold=5,             # 5折交叉验证
    stratified=True,     # 对于分类任务,使用分层抽样
    callbacks=[lgb.early_stopping(10), lgb.log_evaluation(0)]
)
# cv_results 是一个字典,包含了每一折的评估结果
# 打印最佳迭代次数和对应的分数
print(f"最佳迭代次数: {len(cv_results['multi_log-mean-mean'])}")
print(f"最佳 logloss: {min(cv_results['multi_log-mean-mean']):.4f}")

特征重要性

训练好的模型可以方便地查看特征重要性。

# 获取特征重要性
importance = model.feature_importance(importance_type='gain') # 'gain' 是基于平均增益
feature_names = iris.feature_names
# 创建 DataFrame 并排序
importance_df = pd.DataFrame({'feature': feature_names, 'importance': importance})
importance_df = importance_df.sort_values(by='importance', ascending=False)
print("\n特征重要性 (基于增益):")
print(importance_df)
# 可视化
import matplotlib.pyplot as plt
importance_df.plot.bar(x='feature', y='importance', legend=False)'Feature Importance')
plt.ylabel('Importance (Gain)')
plt.show()

超参数调优 (使用 optuna)

手动调参非常耗时,可以使用自动调参工具如 optuna

import optuna
def objective(trial):
    # 定义搜索空间
    params = {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'boosting_type': 'gbdt',
        'num_leaves': trial.suggest_int('num_leaves', 10, 100),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'feature_fraction': trial.suggest_float('feature_fraction', 0.7, 1.0),
        'bagging_fraction': trial.suggest_float('bagging_fraction', 0.7, 1.0),
        'bagging_freq': trial.suggest_int('bagging_freq', 1, 10),
        'verbose': -1,
        'random_state': 42
    }
    # 5折交叉验证
    cv_results = lgb.cv(
        params,
        lgb_train,
        num_boost_round=100,
        nfold=5,
        stratified=True,
        callbacks=[lgb.early_stopping(10), lgb.log_evaluation(0)]
    )
    # 返回需要最小化的指标
    return min(cv_results['multi_log-mean-mean'])
# 创建研究并开始优化
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20) # 尝试20次
print("\n最佳参数:")
print(study.best_params)
print(f"最佳分数: {study.best_value:.4f}")

模型保存与加载

训练好的模型可以被保存下来,以便后续使用或部署。

# 保存模型
model.save_model('my_lgbm_model.txt')
print("\n模型已保存为 my_lgbm_model.txt")
# 加载模型
loaded_model = lgb.Booster(model_file='my_lgbm_model.txt')
# 使用加载的模型进行预测
y_pred_loaded = loaded_model.predict(X_test)
y_pred_loaded = np.argmax(y_pred_loaded, axis=1)
accuracy_loaded = accuracy_score(y_test, y_pred_loaded)
print(f"加载模型的准确率: {accuracy_loaded:.4f}")

这份 Python LightGBM 实例教程涵盖了从入门到进阶的主要内容:

  1. 基础流程:展示了如何加载数据、创建 Dataset、定义参数、训练、预测和评估。
  2. API 选择:介绍了原生 lgb.train API 和更易于集成的 scikit-learn API (LGBMClassifier/LGBMRegressor)。
  3. 关键功能:演示了交叉验证、特征重要性分析等实用功能。
  4. 高级技巧:介绍了使用 optuna 进行超参数自动调优的方法。
  5. 模型持久化:说明了如何保存和加载训练好的模型。

LightGBM 是一个非常强大且灵活的工具,掌握它对于从事数据科学和机器学习工作非常有帮助,建议您从官方文档中获取更多信息,并结合自己的项目进行实践。

分享:
扫描分享到社交APP
上一篇
下一篇