How do I specify content_type in a SageMaker XGBoost training job in Python?

Problem description

I'm trying to train a model with the sagemaker library. So far, my code looks like this:

import boto3
import sagemaker
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.workflow.airflow import training_config

# SageMaker Python SDK v1 API (get_image_uri and train_instance_* were renamed in v2)
container = get_image_uri(boto3.Session().region_name,
                          'xgboost',
                          repo_version='0.90-1')

estimator = sagemaker.estimator.Estimator(container,
                                          role='AmazonSageMaker-ExecutionRole-20190305TXXX',
                                          train_instance_count=1,
                                          train_instance_type='ml.m4.2xlarge',
                                          output_path='s3://antifraud/production/',
                                          hyperparameters={'num_round': '400',  # built-in XGBoost expects 'num_round'
                                                           'objective': 'binary:logistic',
                                                           'eval_metric': 'error@0.1'})

train_config = training_config(estimator=estimator,
                               inputs={'train': 's3://antifraud/production/train',
                                       'validation': 's3://antifraud/production/validation'})

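I get an error while the hyperparameters are being parsed. This command prints the configuration JSON to the console. I have been able to run the training job with boto3 by passing the configuration as JSON, and comparing the two I found that the JSON generated by my code is missing the content_type parameter. It should look like this: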

"InputDataConfig": [
    {
        "ChannelName": "train",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://antifraud/production/data/train",
                "S3DataDistributionType": "FullyReplicated" 
            }
        },
        "ContentType": "text/csv",
        "CompressionType": "None"
    },
    {
        "ChannelName": "validation",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://antifraud/production/validation",
                "S3DataDistributionType": "FullyReplicated"
            }
        },
        "ContentType": "text/csv",
        "CompressionType": "None"
    }
]
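For reference, the boto3 route mentioned in the question takes this list as the `InputDataConfig` argument of `create_training_job`. A minimal sketch that builds the two channels in plain Python, assuming CSV data in both channels; `csv_channel` is a hypothetical helper, not part of boto3 or the SageMaker SDK:

```python
import json


def csv_channel(name: str, s3_uri: str) -> dict:
    """Build one InputDataConfig entry with the same keys as the JSON above.

    csv_channel is a hypothetical helper written for illustration only.
    """
    return {
        "ChannelName": name,
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": s3_uri,
                "S3DataDistributionType": "FullyReplicated",
            }
        },
        # The field missing from the generated config: declare the channel as CSV.
        "ContentType": "text/csv",
        "CompressionType": "None",
    }


input_data_config = [
    csv_channel("train", "s3://antifraud/production/data/train"),
    csv_channel("validation", "s3://antifraud/production/validation"),
]

# This list would be passed as InputDataConfig=input_data_config to
# boto3.client("sagemaker").create_training_job(...), together with the
# other required arguments (not shown here).
print(json.dumps(input_data_config, indent=4))
```

According to the SageMaker XGBoost documentation, the training content type defaults to libsvm, which would explain why CSV data without an explicit `ContentType` fails to parse.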

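I have tried passing content_type='text/csv' as an argument to the container, the estimator, and train_config, and also as another key in the inputs dictionary, without success. How can I make this work?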

Tags: python, amazon-web-services, amazon-sagemaker

Solution


I solved it by using s3_input objects:

s3_input_train = sagemaker.s3_input(s3_data='s3://antifraud/production/data/{domain}-{product}-{today}/train_data.csv',
                                    content_type='text/csv')
s3_input_validation = sagemaker.s3_input(s3_data='s3://antifraud/production/data/{domain}-{product}-{today}/validation_data.csv',
                                         content_type='text/csv')

train_config = training_config(estimator=estimator,
                               inputs={'train': s3_input_train,
                                       'validation': s3_input_validation})
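In the paths above, {domain}, {product} and {today} sit inside plain string literals, so Python does not substitute them; presumably the original code filled them in elsewhere. A minimal sketch of building the per-channel URIs with f-strings, where the helper names, the example values and the ISO date format for {today} are all hypothetical:

```python
from datetime import date


def channel_uri(domain: str, product: str, day: date, split: str) -> str:
    """Fill the {domain}-{product}-{today} path template (hypothetical helper).

    The ISO date format for {today} is an assumption; adjust to your layout.
    """
    prefix = f"s3://antifraud/production/data/{domain}-{product}-{day.isoformat()}"
    return f"{prefix}/{split}_data.csv"


def channel_uris(domain: str, product: str, day: date) -> dict:
    # One URI per channel name used in the inputs dictionary.
    return {split: channel_uri(domain, product, day, split)
            for split in ("train", "validation")}


# With SageMaker Python SDK v1, each URI is then wrapped as in the answer:
#   sagemaker.s3_input(s3_data=uri, content_type='text/csv')
# (SDK v2 renamed this class to sagemaker.inputs.TrainingInput.)
uris = channel_uris("example-domain", "example-product", date(2020, 1, 31))
print(uris["train"])
```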
