amazon-sagemaker - 如何使用 Estimator 在 sagemaker 中保存 model.tar.gz 文件
问题描述
我无法使用以下代码将模型工件保存在 S3 存储桶中。我成功地将结果保存在输出数据路径中,并且培训工作已成功完成。我正在使用以下代码。
任何人都可以请确认我们如何使用下面的代码将 model_artifacts 保存在 model-dir 中。
# train.py code
#!/usr/bin/env python
from __future__ import print_function
import os
import sys
import pandas as pd
prefix = '/opt/ml/'
input_dir = prefix + 'input/data'
output_data_dir = os.path.join(prefix, 'output/data')
model_dir = os.path.join(prefix, 'model')
channel_name='training'
training_path = os.path.join(input_dir, channel_name)
# The function to execute the training.
def train():
print('Starting the training.')
# Take the set of files and read them all into a single pandas dataframe
input_files = [ os.path.join(training_path, file) for file in os.listdir(training_path) ]
raw_data = [ pd.read_csv(file, header=None) for file in input_files ]
input_data = pd.concat(raw_data)
print(pd.DataFrame(input_data))
output_data = input_data.to_csv(os.path.join(output_data_dir, 'output.csv'), header=False, index=False)
if __name__ == '__main__':
train()
# Below are the S3 input and output paths :
output_path = "s3://{}/{}".format(bucket, prefix_output)
S3_input = "s3://{}/{}".format(bucket, prefix)
#Estimator Code
test_estimator = sagemaker.estimator.Estimator(ecr_image, # ECR image arn,
role=role, # execution role
instance_count=1, # no. of sagemaker instances
instance_type='ml.m4.xlarge', # instance type
output_path=output_path, # output path to store model outputs
base_job_name='sagemaker-job1', # job name prefix
sagemaker_session=session # session
)
# Launch instance and start training
test_estimator.fit({'training':S3_input})
这段代码缺少什么?
解决方案
Sagemaker 会自动保存到output_path
模型目录中的所有内容,因此/opt/ml/model
. 如果训练作业成功完成,最后 Sagemaker 会获取该文件夹中的所有内容,创建一个model.tar.gz
并上传到output_path
与您的训练作业同名的文件夹中(sagemaker 创建此文件夹)。您还可以使用环境变量SM_OUTPUT_DATA_DIR
,它默认指向/opt/ml/output/data
并放置非模型训练工件(例如评估结果),Sagemaker 将从该文件夹创建一个存档,output.tar.gz
并将其上传model.tar.gz
到 S3 上的同一文件夹中。
我不明白您对“结果”的确切含义,但是无论您想放入该存档中的任何内容,都取决于您将其保存在您的model_dir
. 因此,例如,我如何将模型保存在 json 和 H5 中,第一个将在output.tar.gz
存档中,后者在model.tar.gz
output_artifacts = os.environ.get('SM_OUTPUT_DATA_DIR')
with open(os.path.join(output_artifacts,"model.json"), "w") as json_file:
json_file.write(model_json)
model_directory = os.environ.get('SM_MODEL_DIR')
model.save(os.path.join(model_directory, 'model.h5'))
推荐阅读
- tomcat - IIS Apache Tomcat 重定向
- javascript - Django 静态文件不加载或不工作?
- docker - Docker - Spring Boot 应用程序 - 无法访问本地主机上的 MySql 服务器
- angularjs - angular JS + Spring rest + 文件附件
- html - 将边距应用于居中
- docker - 如何将命令行参数传递给 java spring docker 容器
- django - 如何正确捕获和嵌套 ValidationErrors
- javascript - 使用按钮向表格添加行
- reactjs - 为什么我的应用程序状态的更改不会导致我的子组件重新渲染?
- class - 类的对象没有在 Flutter 的循环中实例化