杰瑞科技汇

Python的BaseEstimator核心作用与使用场景?

BaseEstimatorscikit-learn 库中一个非常重要的基类,它本身不实现任何机器学习算法,而是提供了一套标准化的接口工具,让你的自定义类可以无缝地与 scikit-learn 的生态系统(如 GridSearchCV, Pipeline, cross_val_score 等)协同工作。

Python的BaseEstimator核心作用与使用场景?-图1
(图片来源网络,侵删)

继承 BaseEstimator 是让你的自定义模型“获得” scikit-learn 兼容性的“金钥匙”。


为什么需要 BaseEstimator

当你从头开始编写一个自己的机器学习模型时,你可能会遇到以下问题:

  • 参数管理混乱:你的模型可能有多个超参数(比如学习率、树的深度等),如何以一种统一、标准化的方式存储和管理它们?
  • 无法与 scikit-learn 工具链集成:scikit-learn 的核心功能,如超参数调优 (GridSearchCV)、交叉验证 (cross_val_score) 和模型管道 (Pipeline),都依赖于一个标准的接口,如果你的模型不遵循这个接口,这些强大的工具就无法使用。
  • 代码不规范:缺乏一个统一的框架,容易写出难以维护和复用的代码。

BaseEstimator 解决了所有这些问题,它通过要求你的类实现特定的方法(fit, predict 等)并提供标准的参数管理机制,让你的自定义模型立刻成为 scikit-learn 生态系统中的一等公民。


BaseEstimator 提供的核心功能

BaseEstimator 主要提供了两大核心功能:

Python的BaseEstimator核心作用与使用场景?-图2
(图片来源网络,侵删)

a) 标准化的参数管理

这是 BaseEstimator 最强大的特性之一,它通过 __init__ 方法的特殊实现,允许你像这样定义参数:

class MyCustomEstimator(BaseEstimator):
    def __init__(self, param_1=5, param_2="default_string"):
        self.param_1 = param_1
        self.param_2 = param_2

当你这样写时,BaseEstimator 会自动帮你做两件事:

  1. 实例化时赋值:当你创建 MyCustomEstimator(param_1=10) 时,self.param_1 会被自动设置为 10。
  2. 获取所有参数:它提供了一个 get_params() 方法,可以返回一个包含所有参数及其当前值的字典。

示例:

model = MyCustomEstimator(param_1=10, param_2="hello")
print(model.get_params())

输出:

Python的BaseEstimator核心作用与使用场景?-图3
(图片来源网络,侵删)
{'param_1': 10, 'param_2': 'hello'}

这个 get_params() 方法是 GridSearchCV 等工具进行超参数搜索的基础。GridSearchCV 会调用 get_params() 来获取模型的所有可调参数,然后尝试不同的组合。

b) 标准的 API 接口

虽然 BaseEstimator 本身不实现 fitpredict,但它通常是另一个基类 Mixin 的组合,即 ClassifierMixin, RegressorMixin, TransformerMixin 等。

  • BaseEstimator: 提供参数管理。
  • ClassifierMixin: 提供分类器相关的工具,最重要的是实现了 score() 方法,该方法默认使用准确率作为评估指标。
  • RegressorMixin: 提供回归器相关的工具,同样实现了 score() 方法,该方法默认使用 R² 分数作为评估指标。

一个标准的自定义模型通常会这样继承:

from sklearn.base import BaseEstimator, ClassifierMixin
class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, some_param=1.0):
        self.some_param = some_param
    def fit(self, X, y):
        # 在这里实现模型的训练逻辑
        # ...
        self.is_fitted_ = True  # 标记模型已训练
        return self
    def predict(self, X):
        # 在这里实现模型的预测逻辑
        # ...
        return predictions

通过继承 ClassifierMixin,你的 MyClassifier 自动就拥有了 score() 方法,可以直接用于评估。


如何使用 BaseEstimator 创建一个自定义模型(完整示例)

下面我们创建一个非常简单的自定义分类器:一个基于阈值的分类器,如果某个特征的平均值大于给定阈值,就预测为类别 1,否则为类别 0。

步骤:

  1. 导入 BaseEstimatorClassifierMixin
  2. 创建一个新类,继承这两个基类。
  3. __init__ 中定义模型的超参数(这里是 threshold)。
  4. 实现 fit 方法(在这个简单模型中,我们不需要训练,但为了兼容性,我们仍然需要实现它)。
  5. 实现 predict 方法(这里是核心逻辑)。
  6. (可选但推荐)实现 get_feature_names_out 方法,以支持 Pipeline
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
class MeanThresholdClassifier(BaseEstimator, ClassifierMixin):
    """
    一个简单的自定义分类器。
    它计算输入特征 X 的平均值,如果平均值大于给定的阈值,
    则预测为类别 1,否则为类别 0。
    """
    def __init__(self, threshold=0.5):
        """
        初始化分类器。
        Parameters
        ----------
        threshold : float, default=0.5
            用于分类的阈值。
        """
        self.threshold = threshold
    def fit(self, X, y):
        """
        训练模型,对于这个简单的模型,训练过程是空的,
        但我们仍然需要这个方法来满足 scikit-learn 的 API。
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            训练数据。
        y : array-like of shape (n_samples,)
            目标值。
        Returns
        -------
        self : object
            返回实例本身。
        """
        # scikit-learn 提供的验证工具,可以检查 X 和 y 的格式是否正确
        X, y = check_X_y(X, y)
        # 存储训练数据的维度,用于在 predict 时验证
        self.n_features_in_ = X.shape[1]
        # 标记模型已训练
        self.is_fitted_ = True
        return self
    def predict(self, X):
        """
        使用训练好的模型进行预测。
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            输入数据。
        Returns
        -------
        y_pred : array-like of shape (n_samples,)
            预测的类别标签。
        """
        # 检查模型是否已经训练过
        check_is_fitted(self)
        # 检查输入数据的维度是否与训练时一致
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Number of features does not match input")
        # 核心预测逻辑
        mean_value = np.mean(X)
        predictions = np.full(shape=X.shape[0], fill_value=0, dtype=int)
        if mean_value > self.threshold:
            predictions[:] = 1
        return predictions
    def get_feature_names_out(self, input_features=None):
        """
        为了与 Pipeline 兼容,推荐实现此方法。
        """
        return [f"feature_{i}" for i in range(self.n_features_in_)]

测试我们的自定义模型

我们可以像使用 scikit-learn 内置模型一样使用它。

# 1. 创建一些样本数据
X_train = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([[10, 11]]) # 这个样本的平均值是 10.5,远大于阈值
# 2. 实例化我们的自定义分类器
# 默认阈值是 0.5
clf = MeanThresholdClassifier(threshold=5.0)
# 3. 训练模型
clf.fit(X_train, y_train)
print(f"模型是否已训练: {clf.is_fitted_}")
# 4. 进行预测
prediction = clf.predict(X_test)
print(f"输入数据: {X_test}, 预测结果: {prediction[0]}") # 应该预测为 1
# 5. 使用 scikit-learn 的工具
from sklearn.model_selection import cross_val_score
# 交叉验证会自动调用 fit 和 predict
scores = cross_val_score(clf, X_train, y_train, cv=2)
print(f"交叉验证分数: {scores}")
print(f"平均交叉验证分数: {scores.mean()}")
# 6. 检查参数
print(f"模型参数: {clf.get_params()}")

输出:


模型是否已训练: True
输入数据: [[10 11]], 预测结果:
分享:
扫描分享到社交APP
上一篇
下一篇