python - GridSearchCV 估计我的模型很慢
问题描述
我在 MLP 分类器上使用 GridSearchCV,这是我的代码...
normalized_features.shape # (50000,784)
len(labels) # 50000
X_train, X_test, Y_train, Y_test = train_test_split(normalized_features, labels, test_size=0.2)
mlp = MLPClassifier(max_iter=100)
parameter_space = {
'hidden_layer_sizes': [(50,50,50), (50,100,50), (100,)],
'activation': ['tanh', 'relu'],
'solver': ['sgd', 'adam'],
'alpha': [0.0001, 0.05],
'learning_rate': ['constant','adaptive'],
}
这是我被击中的阶段,已经两个多小时了,仍然继续加载并抛出警告
clf = GridSearchCV(mlp, parameter_space, n_jobs=-1, cv=10)
clf.fit(X_train, Y_train)
警告:
/usr/local/lib/python3.6/dist-packages/joblib/externals/loky/process_executor.py:706:UserWarning:当一些工作被分配给执行程序时,工作人员停止了。这可能是由于工作人员超时时间过短或内存泄漏造成的。“超时或内存泄漏。”,用户警告
谁能帮我解决这个问题,让我知道我哪里出错了!先感谢您。
解决方案
正如@mujjiga 提到的,Gridsearch 将为72
您的每个折叠尝试不同的参数组合,如果您10
折叠了,则将被训练的总型号为720
.
您可能想要使用 RandomizedSearch,它可以通过较少的实验为您提供与 GridSearch 方法相似的结果。所以你可以减少你的训练时间。
您可以在scikit-learnrandomizedsearch
中找到实现。
您还可以通过此链接更详细地阅读网格搜索和随机搜索之间的比较。
推荐阅读
- matlab - 如何在 MatLab 中的两个曲面的交点处获取切片数据
- php - 在 mysql DB 中插入复选框值
- python - 确定文件相对于目录的路径,包括符号链接
- javascript - 在没有源代码或构建过程的 Electron 应用程序的生产构建中打开 Chromium DevTools
- ios - 如何在 Swift 中更改视频视图的大小
- google-apps-script - 如何将行移动到命名范围的最后一行(onEdit)
- java - # 在 XML 命名空间中似乎导致 java API 中的异常——为什么?
- optimization - 如何在 CPLEX ILOG 上创建的模型中实现本地搜索算法?
- gmp - 为什么我不能将这个科学记数法读入 GMP mpf_t?
- ios - iOS Swift 代码 - 当我有断点时才工作