python - 使用 Joblib 时返回 scikit-learn 对象
问题描述
我有一个 numpy 数组,我正在使用 sklearn 沿第一个轴转换数组。我还想将转换器对象保存在 dict 中,以便稍后在代码中使用。这是我的代码:
scalers_dict = {}
for i in range(train_data_numpy.shape[1]):
for j in range(train_data_numpy.shape[2]):
scaler = QuantileTransformer(n_quantiles=60000, output_distribution='uniform')
train_data_numpy[:,i,j] = scaler.fit_transform(train_data_numpy[:,i,j].reshape(-1,1)).reshape(-1)
scalers_dict[(i,j)] = scaler
我的 train_data_numpy 是 shape (60000, 28,28)
。问题是这需要很长时间来处理(train_data_numpy 是 MNIST 数据集)。我有一个 16 核的 AMD Ryzen 5950X,我想并行化这段代码。
例如,我知道我可以写这样的东西:
Parallel(n_jobs=16)(delayed(QuantileTransformer(n_quantiles=60000, output_distribution='uniform').fit_transform)(train_data_numpy[:,i,j].reshape(-1,1)) for j in range(train_data_numpy.shape[2]))
但这不会返回缩放器对象,而且我不知道如何利用 Joblib 来完成这项任务。
解决方案
您可以使用在Dask Library之上实现的Dask-ML,但它与.scikit-learn
安装:
conda install -c conda-forge dask-ml
or
pip install dask-ml
例子
import time
from sklearn.datasets import make_classification
from sklearn.preprocessing import QuantileTransformer as skQT
from dask_ml.preprocessing import QuantileTransformer as daskQT
# toy big dataset for testing
X, y = make_classification(n_samples=1000000, n_features=100, random_state=2021)
# Comparison
scaler = skQT()
start_ = time.time()
scaler.fit_transform(X)
end_ = time.time() - start_
print("No Parallelism -- Time Elapsed: {}".format(end_))
# Using Dask ML
scaler = daskQT()
start_ = time.time()
scaler.fit_transform(X)
end_ = time.time() - start_
print("With Parallelism -- Time Elapsed: {}".format(end_))
结果
No Parallelism -- Time Elapsed: 18.680
With Parallelism -- Time Elapsed: 2.982
我的设备规格:
Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
Number of Cores: 12
推荐阅读
- php - Google Suite Rest APi 更新用户照片 PHP
- java - 通过 Spring-boot 中的 REST 端点读取环境变量值
- azure - 在使用多个身份验证方案时防止 IDX10501 错误 (Microsoft.Identity.Web)
- flutter - 执行无头飞镖代码的 Flutter 插件
- json - 在 Spark scala 中将数据框列的数组展平为单独的列和相应的值
- docker - 等待启动 Docker Stack,直到挂载文件系统
- flutter - 未处理的异常:在收到完整标头之前连接已关闭。在 IOClient.send 和 BaseClient._sendUnstreamed 上出现异常
- javascript - VSCode如何修复打字稿突出显示错误{}类型的参数不可分配给参数
- c# - 如何显示列表的值
在使用 XAML 绑定的 ListView 中? - sass - 以数字开头的 ID 上的 Dart sass 错误