首页 > 解决方案 > 将 SageMaker 管道模式与 tfrecords 的 s3 目录一起使用

问题描述

当我sagemaker.tensorflow.TensorFlow.fit()使用Pipe而不是File作为input_mode. 我相应地将 TensorFlow 替换DatasetPipemodedataset. 模式下的训练File成功完成。

我的数据由两个 s3 存储桶组成,每个存储桶中有多个 tfrecord 文件。尽管仔细阅读了文档,但我对Pipemodedataset在这种情况下如何使用 没有信心 - 特别是如何设置channel.

这是我的 Sagemaker 笔记本设置:

hyperparameters = {
    "batch-size": 1,
    "pipe_mode": 1,
}

estimator_config = {
    "entry_point": "tensorflow_train.py",
    "source_dir": "source",
    "framework_version": "2.3",
    "py_version": "py37",
    "instance_type": "ml.p3.2xlarge",
    "instance_count": 1,
    "role": sagemaker.get_execution_role(),
    "hyperparameters": hyperparameters,
    "output_path": f"s3://{bucket_name}",
    "input_mode": "Pipe",
}

tf_estimator = TensorFlow(**estimator_config)

s3_data_channels = {
    "training": f"s3://{bucket_name}/data/training",
    "validation": f"s3://{bucket_name}/data/validation",
}

tf_estimator.fit(s3_data_channels)

如果我要在aws s3 ls上运行s3_data_channels,我会得到一个 tfrecord 文件列表。

这是我设置数据集的方式(根据是否pipe_mode选择查看 if / else 语句:

import tensorflow as tf

if __name__ == "__main__":

    arg_parser = argparse.ArgumentParser()
    ...
    arg_parser.add_argument("--pipe_mode", type=int, default=0)

    arg_parser.add_argument("--train_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAINING"))
    arg_parser.add_argument(
        "--validation_dir", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION")
    )
    arg_parser.add_argument("--model_dir", type=str)
    args, _ = arg_parser.parse_known_args()

    AUTOTUNE = tf.data.experimental.AUTOTUNE

    if args.pipe_mode == 1:
        from sagemaker_tensorflow import PipeModeDataset
        train_ds = PipeModeDataset(channel="training", record_format='TFRecord')
        val_ds = PipeModeDataset(channel="validation", record_format='TFRecord')

    else:
        train_files = tf.data.Dataset.list_files(args.train_dir + '/*tfrecord')
        val_files = tf.data.Dataset.list_files(args.validation_dir + '/*tfrecord')
        train_ds = tf.data.TFRecordDataset(filenames=train_files, num_parallel_reads=AUTOTUNE)
        val_ds = tf.data.TFRecordDataset(filenames=val_files, num_parallel_reads=AUTOTUNE)

    train_ds = (
        train_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
        .batch(args.batch_size)
        .prefetch(AUTOTUNE)
    )

    val_ds = (
        val_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
        .batch(args.batch_size)
        .prefetch(AUTOTUNE)
    )
    ...

标签: tensorflowdeep-learningamazon-sagemakertensorflow-datasets

解决方案


推荐阅读