python - 在 python 中使用 sklearn 自己的估计器进行网格搜索 CV
问题描述
我正在尝试构建自己的估计器(回归器)并将其用于插补(KnnImputation)。我在使用网格搜索“GridSearchCV”时遇到问题。有什么想法有什么问题吗?
我的代码:
class KnnImputation(BaseEstimator, RegressorMixin):
def __init__(self, k=5, distance='euclidean'):
self.k = k
self.distance = distance
def get_params(self, deep=False):
return {'k': self.k, 'distance': self.distance}
def set_params(self, **parameters):
self.k = parameters['k']
self.distance = parameters['distance']
def fit(self, X, y):
self.xTrain = X.values
self.yTrain = y.values
return self
def predict(self, X):
........
return yPred
# scorer:
scorer = make_scorer(mean_squared_error)
kf = KFold(n_splits=10, shuffle=False, random_state=23)
NN = KnnImputation()
gridSearchNN = GridSearchCV(NN, param_grid=params, scoring=scorer, n_jobs=1,
cv=kf.split(xTrain, yTrain), verbose=1)
gridSearchNN.fit(X=xTrain, y=yTrain)
我的错误:
....
File "C:\Users\...........\dataImputation.py", line 85, in knnImputationMethod
gridSearchNN.fit(X=xTrain, y=yTrain)
File "C:\Users\.....\Anaconda3\lib\site-packages\sklearn\model_selection\_search.py", line 740, in fit
self.best_estimator_.fit(X, y, **fit_params)
AttributeError: 'NoneType' object has no attribute 'fit'
解决方案
从sklearn
源代码中sklearn.model_selection._search
,我们在方法中有以下代码fit
:
if self.refit:
self.best_estimator_ = clone(base_estimator).set_params(
**self.best_params_)
refit_start_time = time.time()
if y is not None:
self.best_estimator_.fit(X, y, **fit_params)
这里重要的是这条线:
self.best_estimator_ = clone(base_estimator).set_params(**self.best_params_)
克隆由base_estimator
对象组成,它只是您的KNNImputation
类。set_params()
然后在该克隆的估计器上调用实例方法。然后变量self.best_estimator
指向 的返回值set_params()
。
在您提供的代码中,该set_params()
方法没有return
语句,因此它返回None
. 因此,调用self.best_estimator_.fit()
等同于None.fit()
,这显然是行不通的。self
您需要通过在set_params()
函数内返回来启用方法链接。
相关代码将是:
def set_params(self, **parameters):
self.k = parameters['k']
self.distance = parameters['distance']
return self
TL;博士:
您需要set_params
通过返回来启用方法链接self
。
推荐阅读
- javascript - 使用 rangy 创建拖放
- c++ - C++ 得到 -243030403 和 \300\371 数字
- php - HTML实体到十六进制
- python-3.x - 发电机 | 最后一次 yield 后触发数据库 INSERT
- javascript - Javascript 不呈现条件语句 - Django 项目
- javascript - scrollIntoView 在延迟图像加载时无法正常工作
- javascript - jQuery AJAX POST 到 PHP 有效,XMLHttpRequest 无效
- algorithm - 具有两种成本的有向无环图中的最短路径
- python - 使用 Python 解析频率分布图的嵌套行文本文档
- docker - Docker 挂载点是如何决定的?