首页 > 解决方案 > “DBSCAN”对象没有使用 GridSearchCV 和管道的属性“预测”

问题描述

我正在使用scikit-learn库并从中构建管道

这是我构建的管道的最后(也是主要)部分:

preprocessor_steps = [('data_transformer', data_transformer),
                      ('reduce_dim', TruncatedSVD())]
preprocessor = Pipeline(steps=preprocessor_steps)

clustering_steps = [('preprocessor', preprocessor),
                    ('cluster', DummyEstimator())]
clustering = Pipeline(steps=clustering_steps)

data_transformer具有 OneHotEncoder、KNNImputer 等步骤。

现在我有GridSearchCV

param_grid = [{
      'cluster': [KMeans()],
      'cluster__n_clusters': range(1, 11),
      'cluster__init': ['k-means++', 'random']
    },  
    {
      'cluster': [DBSCAN()],
      'cluster__eps': [0.5, 0.7, 1],
    }]

grid_search = GridSearchCV(estimator=clustering, param_grid=param_grid, 
                           scoring='accuracy', verbose=2, n_jobs=1,
                           error_score='raise')
  
grid_search.fit(X_train, y_train)

它适用于KMeans的所有超参数,但对于DBSCAN失败。它给出了一个错误:

AttributeError: 'DBSCAN' object has no attribute 'predict'

我认为这是因为 DBSCAN 有“fit_predict”而不是“predict”。我不想改变我的布局(比如从 GridSearchCV 找到最佳管道),因为我有更多的参数和算法要比较。

标签: pythonscikit-learnpipelinedbscangridsearchcv

解决方案


我遇到了同样的问题AgglomerativeClustering 并解决了这个问题,我像这样使用 Wrapper:

class AgglomerativeClusteringWrapper(AgglomerativeClustering):
    def predict(self,X):
      return self.labels_.astype(int)

因此,您可以更改为 DBSCAN,一切都会正常工作。


推荐阅读