python - 来自 sklearn 的 MLPclassifier 在其他机器上执行时显示不同的精度?
问题描述
看起来在不同设备上运行具有相同输入的 sklearn MLPclassifier 会给出不同的准确度结果,即使设置了全局种子也是如此。
MWE:
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
np.random.seed(1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=np.random.RandomState(0))
nn = MLPClassifier(hidden_layer_sizes=(100,100),
activation='relu',
solver='adam',
alpha=0.001,
batch_size=50,
learning_rate_init=0.01,
max_iter=1000,
random_state=np.random.RandomState(0))
nn.fit(X_train, y_train)
y_train_pred = nn.predict(X_train)
acc_train = np.sum(y_train == y_train_pred, axis=0) / X_train.shape[0]
y_test_pred = nn.predict(X_test)
acc_test = np.sum(y_test == y_test_pred, axis=0) / X_test.shape[0]
results.append([acc_train,acc_test])
如何保证再现性(独立于执行设备)?
解决方案
我无法重现这一点。
如果有问题,这可能需要有关不同机器的更多信息。分别调用的结果是什么python -c 'import sklearn; sklearn.show_versions()'
?
以下代码在 Ubuntu/Red Hat 上给出了相同的结果scikit-learn==0.24.2
(我尝试使用不同的:numpy==1.19.1/1.20.2
和scipy==1.5.2/1.6.3
)。
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = MLPClassifier(
hidden_layer_sizes=(100, 100),
activation="relu",
solver="adam",
alpha=0.001,
batch_size=50,
learning_rate_init=0.01,
max_iter=1000,
random_state=0,
)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))
print(clf.score(X_test, y_test))
输出:
0.9272300469483568
0.9370629370629371
推荐阅读
- python - Tensorflow:使用 argmax 对张量进行切片
- powershell - 如何使用 powershell 和 azure 函数创建 VM
- php - str_ireplace() 不只替换最后一个匹配
- sql-server - 如何在sqlserver中获取重复记录
- html - 删除内部的border-bottom nth-child
- javascript - Ionic 3 导航控制器后退按钮未显示
- android - Google Places API 错误 - ApiException: 9008: PLACES_API_INVALID_APP
- android - Iframe 中的 Ionic 3 网站专注于导致页面重定向的输入
- java - 什么参数需要传递给 DAO 方法调用
- node.js - 使用节点获取 dynamoDB 中没有一个属性的所有记录