python - AWS SageMaker 训练脚本:如何传递自定义用户参数
问题描述
我正在使用 Scikit-learn 和 SageMaker Python SDK 训练分类器。
整个过程涉及三个连续的阶段:
- 使用训练和验证数据集的超参数调整作业
- 使用 1. 中建立的最佳超参数和整个数据集(从 1. 开始的训练 + 验证)进行训练工作
- 使用 2. 中提供的“预拟合”模型和用于校准的附加数据集训练校准模型。
我需要拆分过程的原因是保存在步骤 2 中创建的未校准模型。
对于这一步的每一步,我都会准备一个训练脚本,如下所述:https ://sagemaker.readthedocs.io/en/stable/using_sklearn.html#prepare-a-scikit-learn-training-script
这三个脚本非常相似,为了避免代码冗余,我想在这三种情况下使用一个带有额外逻辑的脚本。更准确地说,我想将额外的自定义参数传递给和对象的.fit
方法,以便能够根据使用情况(阶段 1. ,2. 或 3.)来操作脚本中的逻辑。 sagemaker.tuner.HyperparameterTuner
sagemaker.sklearn.estimator.SKLearn
我已经尝试过破解SM_CHANNEL_XXX
parser.add_argument('--myparam', type=str, default=os.environ.get('SM_CHANNEL_MYPRAM'))
while 调用.fit(inputs={'train': ..., 'test': ..., 'myparam': myvalue})
,但它需要一个有效的 s3 URI。
关于如何将额外的自定义参数传递给训练脚本的任何想法?
解决方案
您可以不在 fit 方法中传递超参数,而是在创建估计器之前直接传递。文档中的示例是:
sklearn_estimator = SKLearn('sklearn-train.py',
train_instance_type='ml.m4.xlarge',
framework_version='0.20.0',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-
rate': 0.1})
sklearn_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
这就是您将参数(从笔记本中)带入训练脚本以通过 parser.add_argument 访问的方式。如果您只有一个脚本,您可以在脚本中处理您的逻辑。但这不会将自定义参数添加到 sagemaker.tuner.HyperparameterTuner 的 .fit 方法中。
我使用以下序列来优化脚本中的参数,然后应用最佳参数(也只使用一个训练脚本)。也许您将此应用于您的案例。您应该能够在脚本中使用 joblib.dump 保存中间模型:
param_grid = [{'vect__ngram_range': [(1, 1)],
'vect__stop_words': [stop, None],
'clf__penalty': ['l1', 'l2'],
'clf__C': [1.0, 10.0, 100.0]},
{'vect__ngram_range': [(1, 1)],
'vect__stop_words': [stop, None],
'vect__use_idf':[False],
'vect__norm':[None],
'clf__penalty': ['l1', 'l2'],
'clf__C': [1.0, 10.0, 100.0]},
]
lr_tfidf = Pipeline([('vect', tfidf),
('clf', LogisticRegression(random_state=0))])
gs_lr_tfidf = GridSearchCV(lr_tfidf, param_grid,
scoring='accuracy',
cv=5,
verbose=1,
n_jobs=-1)
gs_lr_tfidf.fit(X_train, y_train)
推荐阅读
- ruby-on-rails - 将 Rails 应用程序部署到 Heroku 时出现 NoMethodError(在模型上调用 .create 时)
- flutter - 如何为 FCM 推送通知设置字幕/字幕?
- php - Gettig 以分层形式包含子类别的所有类别帖子
- python - 在 pandas 中使用 read_csv 和 to_csv 时如何保留 csv 数据
- prolog - 分支中的单例变量
- r - 如何检查一个数据框中的值是否存在于R中的另一个数据框中?
- openiddict - 即使使用 EnableRequestCaching(),是否有任何方法可以在控制器中找到获取自定义查询参数?
- wkwebview - 获取 WKWebView 的渲染内容
- javascript - 以正确的顺序放置反应钩子
- bash - 如何将 bash 脚本中的参数作为不带引号的参数传递?