什么是 GridSearchCV?
GridSearchCV 是 Scikit-learn 库中的一个工具,它的全称是 网格搜索交叉验证,它的主要作用是帮助我们找到机器学习模型的最佳超参数组合。

可以把它想象成一个“参数自动调优器”。
核心概念拆解:
-
超参数:这是在模型训练开始之前,需要由我们自己设定的参数,它们不像模型的权重(如线性回归中的系数)那样可以通过训练数据学习得到。
- SVM 中的
C(惩罚系数) 和kernel(核函数类型)。 - 随机森林 中的
n_estimators(树的数量) 和max_depth(树的最大深度)。 - K近邻 中的
n_neighbors(邻居数量)。 - 选择合适的超参数对模型性能至关重要。
- SVM 中的
-
网格搜索:这是一种“暴力枚举”的策略,你提供一个参数的“网格”(即所有你想要尝试的参数值的列表),
GridSearchCV会遍历这个网格中所有可能的参数组合。 -
交叉验证:这是评估模型性能的一种稳健方法,它将数据集分成 K 份(K=5),轮流使用其中 K-1 份作为训练集,剩下 1 份作为验证集,重复 K 次,最终将 K 次的性能分数取平均,这样做可以避免因数据集划分的偶然性导致的评估偏差,得到一个更可靠的性能估计。
(图片来源网络,侵删)
GridSearchCV 将“网格搜索”和“交叉验证”结合起来,系统地、自动化地为你找到在给定参数网格中,通过交叉验证评估得分最高的超参数组合。
为什么需要 GridSearchCV?
手动调整超参数是一个繁琐且低效的过程,你可能会:
- 凭感觉试:随机选一组参数,训练,看效果。
- 依赖经验:但不同模型、不同数据集的最佳参数千差万别。
- 陷入局部最优:可能试了几组效果不错的参数,但错过了真正最优的组合。
GridSearchCV 的优势在于:
- 自动化:你只需要定义好要尝试的参数范围,剩下的工作它都帮你完成。
- 全面性:它会尝试所有组合,确保你找到的是在给定网格内的全局最优解(或次优解)。
- 可靠性:通过交叉验证评估,结果比单次划分数据集的评估更可信。
如何使用 GridSearchCV?(代码示例)
下面我们通过一个完整的例子来学习如何使用 GridSearchCV,我们将使用 SVC (支持向量机) 模型,并为其寻找最佳的超参数 C 和 gamma。

步骤 1:导入必要的库
import numpy as np from sklearn import datasets from sklearn.model_selection import train_test_split, GridSearchCV from sklearn.svm import SVC from sklearn.metrics import classification_report
步骤 2:准备数据
我们使用 Scikit-learn 自带的鸢尾花数据集。
# 加载数据 iris = datasets.load_iris() X = iris.data y = iris.target # 将数据分为训练集和测试集 # 注意:GridSearchCV 只在训练集上进行,测试集要留到最后用来评估最终模型 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
步骤 3:定义模型和参数网格
这是最关键的一步。
- 模型:我们选择
SVC。 - 参数网格:我们想要尝试的
C和gamma的值。param_grid是一个字典,键是参数名,值是包含所有可能值的列表。
# 1. 创建模型
svc = SVC()
# 2. 定义参数网格
# 我们想尝试 C=[0.1, 1, 10, 100] 和 gamma=['scale', 'auto', 0.001, 0.01]
param_grid = {
'C': [0.1, 1, 10, 100],
'gamma': ['scale', 'auto', 0.001, 0.01],
'kernel': ['rbf'] # 我们只尝试 RBF 核,因为 gamma 主要与 RBF 核相关
}
步骤 4:创建并运行 GridSearchCV
# 3. 创建 GridSearchCV 对象
# estimator: 我们要优化的模型
# param_grid: 参数网格
# cv=5: 5折交叉验证
# scoring='accuracy': 评估指标是准确率
# n_jobs=-1: 使用所有可用的 CPU 核心来加速计算
# verbose=2: 打印详细的进度信息
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1, verbose=2)
# 4. 在训练数据上运行网格搜索
# 这会非常耗时,因为它会训练 4 (C的值) * 4 (gamma的值) * 5 (折数) = 80 个模型
print("开始运行 GridSearchCV...")
grid_search.fit(X_train, y_train)
print("GridSearchCV 运行完成!")
步骤 5:查看结果
GridSearchCV 训练完成后,会包含很多有用的信息。
# 5. 查看最佳参数
print(f"最佳参数组合: {grid_search.best_params_}")
# 6. 查看最佳模型在交叉验证中的得分
print(f"最佳交叉验证得分: {grid_search.best_score_:.4f}")
# 7. 获取最佳模型
# best_estimator_ 是已经用最佳参数训练好的模型
best_svc = grid_search.best_estimator_
# 8. 在测试集上进行评估
# 注意:这里我们不再训练,直接用 best_svc 进行预测
y_pred = best_svc.predict(X_test)
print("\n在测试集上的分类报告:")
print(classification_report(y_test, y_pred))
完整代码与输出示例
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import classification_report
# 1. 准备数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 2. 定义模型和参数网格
svc = SVC()
param_grid = {
'C': [0.1, 1, 10, 100],
'gamma': ['scale', 'auto', 0.001, 0.01],
'kernel': ['rbf']
}
# 3. 创建并运行 GridSearchCV
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1, verbose=2)
grid_search.fit(X_train, y_train)
# 4. 查看结果
print(f"最佳参数组合: {grid_search.best_params_}")
print(f"最佳交叉验证得分: {grid_search.best_score_:.4f}")
# 5. 在测试集上评估最佳模型
best_svc = grid_search.best_estimator_
y_pred = best_svc.predict(X_test)
print("\n在测试集上的分类报告:")
print(classification_report(y_test, y_pred))
可能的输出:
Fitting 5 folds for each of 16 candidates, totalling 80 fits
开始运行 GridSearchCV...
GridSearchCV 运行完成!
最佳参数组合: {'C': 1, 'gamma': 'scale', 'kernel': 'rbf'}
最佳交叉验证得分: 0.9750
在测试集上的分类报告:
precision recall f1-score support
0 1.00 1.00 1.00 10
1 1.00 1.00 1.00 9
2 1.00 1.00 1.00 11
accuracy 1.00 30
macro avg 1.00 1.00 1.00 30
weighted avg 1.00 1.00 1.00 30
GridSearchCV 的关键参数
| 参数 | 描述 |
|---|---|
estimator |
你想要优化的 Scikit-learn 模型对象。 |
param_grid |
一个字典,定义了要搜索的参数和它们的候选值。 |
scoring |
评估模型性能的指标,可以是字符串(如 'accuracy', 'f1_macro', 'roc_auc')或一个可调用对象,默认是 estimator 的默认评分方法。 |
cv |
交叉验证的折数,可以是整数(如 5),也可以是交叉验证生成器对象,默认是 5。 |
n_jobs |
并行运行的作业数。-1 表示使用所有可用的 CPU 核心,可以大大缩短搜索时间。 |
verbose |
控制输出的详细程度。0 不输出,1 偶尔输出,>1 输出更详细的信息。 |
refit |
布尔值,如果为 True(默认),在找到最佳参数后,会用整个训练集(X_train, y_train)和最佳参数重新训练一个模型,并将该模型保存在 best_estimator_ 属性中。 |
重要注意事项与局限性
- 计算成本高:
GridSearchCV的最大缺点是它尝试所有组合,如果你有 3 个参数,每个参数有 10 个候选值,并且使用 5 折交叉验证,那么它总共要训练3 * 10 * 5 = 150个模型,参数空间越大,时间呈指数级增长。 - 可能错过最优解:它只在你给定的网格中寻找最优解,如果你设置的候选值范围不包含真正的最佳参数,它就找不到,参数网格的设置需要一定的先验知识或进行初步探索。
- 替代方案:
RandomizedSearchCV:这是GridSearchCV的一个流行替代方案,它不是尝试所有组合,而是在参数空间中进行随机采样,你只需要指定要尝试的组合次数 (n_iter),这在参数空间很大时非常高效,往往能用更少的计算量找到接近最优的解。- 更高级的优化库:如
Optuna,Hyperopt等,它们提供了更智能的搜索策略(如贝叶斯优化),可以更高效地找到全局最优解。
GridSearchCV 是 Scikit-learn 中一个极其实用且强大的工具,是机器学习工作流程中模型调优环节的标准配置,虽然它计算成本高,但对于中小型数据集和参数空间来说,它提供了一种简单、全面且可靠的超参数优化方法,掌握它的使用,能显著提升你的模型性能。
