tensorflow - 将 SageMaker 管道模式与 tfrecords 的 s3 目录一起使用
问题描述
当我sagemaker.tensorflow.TensorFlow.fit()
使用Pipe
而不是File
作为input_mode
. 我相应地将 TensorFlow 替换Dataset
为Pipemodedataset
. 模式下的训练File
成功完成。
我的数据由两个 s3 存储桶组成,每个存储桶中有多个 tfrecord 文件。尽管仔细阅读了文档,但我对Pipemodedataset
在这种情况下如何使用 没有信心 - 特别是如何设置channel
.
这是我的 Sagemaker 笔记本设置:
hyperparameters = {
"batch-size": 1,
"pipe_mode": 1,
}
estimator_config = {
"entry_point": "tensorflow_train.py",
"source_dir": "source",
"framework_version": "2.3",
"py_version": "py37",
"instance_type": "ml.p3.2xlarge",
"instance_count": 1,
"role": sagemaker.get_execution_role(),
"hyperparameters": hyperparameters,
"output_path": f"s3://{bucket_name}",
"input_mode": "Pipe",
}
tf_estimator = TensorFlow(**estimator_config)
s3_data_channels = {
"training": f"s3://{bucket_name}/data/training",
"validation": f"s3://{bucket_name}/data/validation",
}
tf_estimator.fit(s3_data_channels)
如果我要在aws s3 ls
上运行s3_data_channels
,我会得到一个 tfrecord 文件列表。
这是我设置数据集的方式(根据是否pipe_mode
选择查看 if / else 语句:
import tensorflow as tf
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
...
arg_parser.add_argument("--pipe_mode", type=int, default=0)
arg_parser.add_argument("--train_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAINING"))
arg_parser.add_argument(
"--validation_dir", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION")
)
arg_parser.add_argument("--model_dir", type=str)
args, _ = arg_parser.parse_known_args()
AUTOTUNE = tf.data.experimental.AUTOTUNE
if args.pipe_mode == 1:
from sagemaker_tensorflow import PipeModeDataset
train_ds = PipeModeDataset(channel="training", record_format='TFRecord')
val_ds = PipeModeDataset(channel="validation", record_format='TFRecord')
else:
train_files = tf.data.Dataset.list_files(args.train_dir + '/*tfrecord')
val_files = tf.data.Dataset.list_files(args.validation_dir + '/*tfrecord')
train_ds = tf.data.TFRecordDataset(filenames=train_files, num_parallel_reads=AUTOTUNE)
val_ds = tf.data.TFRecordDataset(filenames=val_files, num_parallel_reads=AUTOTUNE)
train_ds = (
train_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
.batch(args.batch_size)
.prefetch(AUTOTUNE)
)
val_ds = (
val_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
.batch(args.batch_size)
.prefetch(AUTOTUNE)
)
...
解决方案
推荐阅读
- php - 如何将易受攻击的 sql 查询转换为参数化查询 PHP
- python - 刻度在颜色栏上未正确显示
- python - 跨多个进程的共享字典未更新
- python - 如何修复python中的“名称'self'未定义”错误?
- visual-studio - Angular 项目无法在具有 Visual Studio 2019 的 docker 上运行
- javascript - 去抖功能单元测试
- c# - 如何访问嵌套命名空间以避免在 .NET 中完全限定
- python - 如果在 Python 3.7 中运行超过一次,函数将不再起作用
- python - 按值排序字典,然后按键
- c++ - 确定每次更改源文件时重新配置 CMake 的原因