dask - 如何使用 dask 将 SVC 分发给不同的工作人员(在其他计算机上)
问题描述
我的 PC 上运行了一个调度程序,我想在不同的工作计算机上训练 10 个 SVC 实例。我摆弄着,但找不到解决办法
解决方案
我假设你想用不同的超参数训练这 10 个 SVC 并找到最好的一个(也就是你可以使用 gridsearchCV 进行的超参数优化)。我还假设您正在使用 scikit learn。
通常你会使用如下代码训练 SVC:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC
# Loading the Digits dataset
digits = datasets.load_digits()
# To apply an classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target
# Split the dataset in two equal parts
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.5, random_state=0)
# Set the parameters by cross-validation
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
'C': [1, 10, 100, 1000]},
{'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
scores = ['precision', 'recall']
for score in scores:
print("# Tuning hyper-parameters for %s" % score)
print()
clf = GridSearchCV(SVC(), tuned_parameters, cv=5,
scoring='%s_macro' % score)
clf.fit(X_train, y_train)
print("Best parameters set found on development set:")
print()
print(clf.best_params_)
print()
print("Grid scores on development set:")
print()
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
print("%0.3f (+/-%0.03f) for %r"
% (mean, std * 2, params))
print()
print("Detailed classification report:")
print()
print("The model is trained on the full development set.")
print("The scores are computed on the full evaluation set.")
print()
y_true, y_pred = y_test, clf.predict(X_test)
print(classification_report(y_true, y_pred))
print()
但它只会在一个线程上按顺序训练。
如果您安装 dask-ML,则可以利用替换网格搜索
conda install dask-searchcv -c conda-forge
更换
from sklearn.model_selection import GridSearchCV
经过
from dask_searchcv import GridSearchCV
应该足够了。
但是,在您的情况下,您不想使用线程调度程序,而是使用分布式调度程序。因此,您必须在开头添加以下代码
# Distribute grid-search across a cluster
from dask.distributed import Client
scheduler_address = '127.0.0.1:8786'
client = Client(scheduler_address)
最终代码应如下所示(未经测试)
from sklearn import datasets
from sklearn.model_selection import train_test_split
from dask_searchcv import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC
# Distribute grid-search across a cluster
from dask.distributed import Client
scheduler_address = '127.0.0.1:8786'
client = Client(scheduler_address)
# Loading the Digits dataset
digits = datasets.load_digits()
# To apply an classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target
# Split the dataset in two equal parts
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.5, random_state=0)
# Set the parameters by cross-validation
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
'C': [1, 10, 100, 1000]},
{'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
scores = ['precision', 'recall']
for score in scores:
print("# Tuning hyper-parameters for %s" % score)
print()
clf = GridSearchCV(SVC(), tuned_parameters, cv=5,
scoring='%s_macro' % score)
clf.fit(X_train, y_train)
print("Best parameters set found on development set:")
print()
print(clf.best_params_)
print()
print("Grid scores on development set:")
print()
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
print("%0.3f (+/-%0.03f) for %r"
% (mean, std * 2, params))
print()
print("Detailed classification report:")
print()
print("The model is trained on the full development set.")
print("The scores are computed on the full evaluation set.")
print()
y_true, y_pred = y_test, clf.predict(X_test)
print(classification_report(y_true, y_pred))
print()
推荐阅读
- arm - 如何使用 arm v7 neon 内在函数获得 Q 寄存器(int64x2_t)的绝对值?
- r - 如何在一个坐标系中绘制两个图表,结合一列的所有值?
- android - 如何在GitHub方形日历android中更改月份的文本颜色
- php - mysql - 如果存在则更新日期字段
- python - 将动画情节从 vscode 保存到我的电脑
- jenkins - 通过 Jenkins 执行 UFT 测试时的 UI 问题
- python - 在 Python 中格式化请求查询时出错
- c - 在这种情况下,是否会由于空闲堆块中的延迟数据而发生错误?
- mesibo - 调用 Mesibo API 的奇怪反应
- vb.net - 如何将文本框的值/字符串传递给 vb.net 中 XtraReport 中的 XtraLabel?