杰瑞科技汇

python中gridsearchcv

什么是 GridSearchCV

GridSearchCV 是 Scikit-learn 库中的一个工具,它的全称是 网格搜索交叉验证,它的主要作用是帮助我们找到机器学习模型的最佳超参数组合。

python中gridsearchcv-图1
(图片来源网络,侵删)

可以把它想象成一个“参数自动调优器”。

核心概念拆解:

  • 超参数:这是在模型训练开始之前,需要由我们自己设定的参数,它们不像模型的权重(如线性回归中的系数)那样可以通过训练数据学习得到。

    • SVM 中的 C (惩罚系数) 和 kernel (核函数类型)。
    • 随机森林 中的 n_estimators (树的数量) 和 max_depth (树的最大深度)。
    • K近邻 中的 n_neighbors (邻居数量)。
    • 选择合适的超参数对模型性能至关重要。
  • 网格搜索:这是一种“暴力枚举”的策略,你提供一个参数的“网格”(即所有你想要尝试的参数值的列表),GridSearchCV 会遍历这个网格中所有可能的参数组合

  • 交叉验证:这是评估模型性能的一种稳健方法,它将数据集分成 K 份(K=5),轮流使用其中 K-1 份作为训练集,剩下 1 份作为验证集,重复 K 次,最终将 K 次的性能分数取平均,这样做可以避免因数据集划分的偶然性导致的评估偏差,得到一个更可靠的性能估计。

    python中gridsearchcv-图2
    (图片来源网络,侵删)

GridSearchCV 将“网格搜索”和“交叉验证”结合起来,系统地、自动化地为你找到在给定参数网格中,通过交叉验证评估得分最高的超参数组合。


为什么需要 GridSearchCV

手动调整超参数是一个繁琐且低效的过程,你可能会:

  1. 凭感觉试:随机选一组参数,训练,看效果。
  2. 依赖经验:但不同模型、不同数据集的最佳参数千差万别。
  3. 陷入局部最优:可能试了几组效果不错的参数,但错过了真正最优的组合。

GridSearchCV 的优势在于:

  • 自动化:你只需要定义好要尝试的参数范围,剩下的工作它都帮你完成。
  • 全面性:它会尝试所有组合,确保你找到的是在给定网格内的全局最优解(或次优解)。
  • 可靠性:通过交叉验证评估,结果比单次划分数据集的评估更可信。

如何使用 GridSearchCV?(代码示例)

下面我们通过一个完整的例子来学习如何使用 GridSearchCV,我们将使用 SVC (支持向量机) 模型,并为其寻找最佳的超参数 Cgamma

python中gridsearchcv-图3
(图片来源网络,侵删)

步骤 1:导入必要的库

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import classification_report

步骤 2:准备数据

我们使用 Scikit-learn 自带的鸢尾花数据集。

# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 将数据分为训练集和测试集
# 注意:GridSearchCV 只在训练集上进行,测试集要留到最后用来评估最终模型
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

步骤 3:定义模型和参数网格

这是最关键的一步。

  • 模型:我们选择 SVC
  • 参数网格:我们想要尝试的 Cgamma 的值。param_grid 是一个字典,键是参数名,值是包含所有可能值的列表。
# 1. 创建模型
svc = SVC()
# 2. 定义参数网格
# 我们想尝试 C=[0.1, 1, 10, 100] 和 gamma=['scale', 'auto', 0.001, 0.01]
param_grid = {
    'C': [0.1, 1, 10, 100],
    'gamma': ['scale', 'auto', 0.001, 0.01],
    'kernel': ['rbf'] # 我们只尝试 RBF 核,因为 gamma 主要与 RBF 核相关
}

步骤 4:创建并运行 GridSearchCV

# 3. 创建 GridSearchCV 对象
# estimator: 我们要优化的模型
# param_grid: 参数网格
# cv=5: 5折交叉验证
# scoring='accuracy': 评估指标是准确率
# n_jobs=-1: 使用所有可用的 CPU 核心来加速计算
# verbose=2: 打印详细的进度信息
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1, verbose=2)
# 4. 在训练数据上运行网格搜索
# 这会非常耗时,因为它会训练 4 (C的值) * 4 (gamma的值) * 5 (折数) = 80 个模型
print("开始运行 GridSearchCV...")
grid_search.fit(X_train, y_train)
print("GridSearchCV 运行完成!")

步骤 5:查看结果

GridSearchCV 训练完成后,会包含很多有用的信息。

# 5. 查看最佳参数
print(f"最佳参数组合: {grid_search.best_params_}")
# 6. 查看最佳模型在交叉验证中的得分
print(f"最佳交叉验证得分: {grid_search.best_score_:.4f}")
# 7. 获取最佳模型
# best_estimator_ 是已经用最佳参数训练好的模型
best_svc = grid_search.best_estimator_
# 8. 在测试集上进行评估
# 注意:这里我们不再训练,直接用 best_svc 进行预测
y_pred = best_svc.predict(X_test)
print("\n在测试集上的分类报告:")
print(classification_report(y_test, y_pred))

完整代码与输出示例

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import classification_report
# 1. 准备数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 2. 定义模型和参数网格
svc = SVC()
param_grid = {
    'C': [0.1, 1, 10, 100],
    'gamma': ['scale', 'auto', 0.001, 0.01],
    'kernel': ['rbf']
}
# 3. 创建并运行 GridSearchCV
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1, verbose=2)
grid_search.fit(X_train, y_train)
# 4. 查看结果
print(f"最佳参数组合: {grid_search.best_params_}")
print(f"最佳交叉验证得分: {grid_search.best_score_:.4f}")
# 5. 在测试集上评估最佳模型
best_svc = grid_search.best_estimator_
y_pred = best_svc.predict(X_test)
print("\n在测试集上的分类报告:")
print(classification_report(y_test, y_pred))

可能的输出:

Fitting 5 folds for each of 16 candidates, totalling 80 fits
开始运行 GridSearchCV...
GridSearchCV 运行完成!
最佳参数组合: {'C': 1, 'gamma': 'scale', 'kernel': 'rbf'}
最佳交叉验证得分: 0.9750
在测试集上的分类报告:
              precision    recall  f1-score   support
           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11
    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

GridSearchCV 的关键参数

参数 描述
estimator 你想要优化的 Scikit-learn 模型对象。
param_grid 一个字典,定义了要搜索的参数和它们的候选值。
scoring 评估模型性能的指标,可以是字符串(如 'accuracy', 'f1_macro', 'roc_auc')或一个可调用对象,默认是 estimator 的默认评分方法。
cv 交叉验证的折数,可以是整数(如 5),也可以是交叉验证生成器对象,默认是 5
n_jobs 并行运行的作业数。-1 表示使用所有可用的 CPU 核心,可以大大缩短搜索时间。
verbose 控制输出的详细程度。0 不输出,1 偶尔输出,>1 输出更详细的信息。
refit 布尔值,如果为 True(默认),在找到最佳参数后,会用整个训练集(X_train, y_train)和最佳参数重新训练一个模型,并将该模型保存在 best_estimator_ 属性中。

重要注意事项与局限性

  1. 计算成本高GridSearchCV 的最大缺点是它尝试所有组合,如果你有 3 个参数,每个参数有 10 个候选值,并且使用 5 折交叉验证,那么它总共要训练 3 * 10 * 5 = 150 个模型,参数空间越大,时间呈指数级增长。
  2. 可能错过最优解:它只在你给定的网格中寻找最优解,如果你设置的候选值范围不包含真正的最佳参数,它就找不到,参数网格的设置需要一定的先验知识或进行初步探索。
  3. 替代方案
    • RandomizedSearchCV:这是 GridSearchCV 的一个流行替代方案,它不是尝试所有组合,而是在参数空间中进行随机采样,你只需要指定要尝试的组合次数 (n_iter),这在参数空间很大时非常高效,往往能用更少的计算量找到接近最优的解。
    • 更高级的优化库:如 Optuna, Hyperopt 等,它们提供了更智能的搜索策略(如贝叶斯优化),可以更高效地找到全局最优解。

GridSearchCV 是 Scikit-learn 中一个极其实用且强大的工具,是机器学习工作流程中模型调优环节的标准配置,虽然它计算成本高,但对于中小型数据集和参数空间来说,它提供了一种简单、全面且可靠的超参数优化方法,掌握它的使用,能显著提升你的模型性能。

分享:
扫描分享到社交APP
上一篇
下一篇