python - GridSearchCV/RandomizedSearchCV 的结果无法通过使用相同参数运行单个模型来重现
问题描述
我正在运行 RandomizedSearchCV 5 倍以找到最佳参数。我有一个X_test
用来预测的保留集 ( )。我的部分代码是:
svc= SVC(class_weight=class_weights, random_state=42)
Cs = [0.01, 0.1, 1, 10, 100, 1000, 10000]
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
param_grid = {'C': Cs,
'gamma': gammas,
'kernel': ['linear', 'rbf', 'poly']}
my_cv = TimeSeriesSplit(n_splits=5).split(X_train)
rs_svm = RandomizedSearchCV(SVC(), param_grid, cv = my_cv, scoring='accuracy',
refit='accuracy', verbose = 3, n_jobs=1, random_state=42)
rs_svm.fit(X_train, y_train)
y_pred = rs_svm.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
clfreport = classification_report(y_test, y_pred)
print (rs_svm.best_params_)
现在,我有兴趣使用具有选定参数的独立运行模型(无随机搜索CV)重现此结果:
from sklearn.model_selection import TimeSeriesSplit
tcsv=TimeSeriesSplit(n_splits=5)
for train_index, test_index in tcsv.split(X_train):
train_index_ = int(train_index.shape[0])
test_index_ = int(test_index.shape[0])
X_train_, y_train_ = X_train[0:train_index_],y_train[0:train_index_]
X_test_, y_test_ = X_train[test_index_:],y_train[test_index_:]
class_weights = compute_class_weight('balanced', np.unique(y_train_), y_train_)
class_weights = dict(enumerate(class_weights))
svc= SVC(C=0.01, gamma=0.1, kernel='linear', class_weight=class_weights, verbose=True,
random_state=42)
svc.fit(X_train_, y_train_)
y_pred_=svc.predict(X_test)
cm = confusion_matrix(y_test, y_pred_)
clfreport = classification_report(y_test, y_pred_)
据我了解, clfreports 应该是相同的,但我在这次运行后的结果是:
有没有人有任何建议为什么会发生这种情况?
解决方案
鉴于您使用 RandomizedSearchCV 查找最佳超参数的第一个代码片段,您无需再次进行任何拆分;因此,在您的第二个片段中,您应该使用找到的超参数和使用整个训练集的类权重进行拟合,然后在您的测试集上进行预测:
class_weights = compute_class_weight('balanced', np.unique(y_train), y_train)
class_weights = dict(enumerate(class_weights))
svc= SVC(C=0.01, gamma=0.1, kernel='linear', class_weight=class_weights, verbose=True, random_state=42)
svc.fit(X_train, y_train)
y_pred_=svc.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
clfreport = classification_report(y_test, y_pred)
使用验证集、训练集和测试集之间的顺序讨论可能有助于阐明程序......
推荐阅读
- angular - 私有标识符仅在面向 ECMAScript 2015 及更高版本(Angular 9)时可用
- kotlin - 为什么我会收到 TornadoFX DataGrid 类型错误?
- lua - lua 5.3 lpeg:Cmt、Cb 和 / 运算符
- java - 没有 complexType 的 Apache cxf soap wsdl
- python - 在 Mask RCNN 中计算 mAP,实例分割
- c - 如何在C中比较两个没有秒的日期
- windows - Windows 命令列出过期帐户并排除设置为从不的帐户
- javascript - 无法在剧作家中捕获 response.json()
- python - 手动将 StreamHandler 添加到记录器会中断 IPython 提示
- laravel - 如何在 Laravel 中防止重复发票编号