tensorflow - 如何使用 TensorFlow Estimator 在 Amazon SageMaker 中获得可重现的结果?
问题描述
我目前正在使用 AWS SageMaker Python SDK为我的数据训练 EfficientNet 模型 ( https://github.com/qubvel/efficientnet )。具体来说,我使用 TensorFlow 估计器,如下所示。此代码在 SageMaker 笔记本实例中
import sagemaker
from sagemaker.tensorflow.estimator import TensorFlow
### sagemaker version = 1.50.17, python version = 3.6
estimator = TensorFlow("train.py", py_version = "py3", framework_version = "2.1.0",
role = sagemaker.get_execution_role(),
train_instance_type = "ml.m5.xlarge",
train_instance_count = 1,
image_name = 'xxx.dkr.ecr.xxx.amazonaws.com/xxx',
hyperparameters = {list of hyperparameters here: epochs, batch size},
subnets = [xxx],
security_group_ids = [xxx]
estimator.fit({
'class_1': 's3_path_class_1',
'class_2': 's3_path_class_2'
})
train.py 的代码包含通常的训练过程,从 S3 获取图像和标签,将它们转换为 EfficientNet 输入的正确数组形状,然后拆分为训练集、验证集和测试集。为了获得可重现的结果,我在调用 EfficientNet 模型本身之前使用以下 reset_random_seeds 函数(如果 Keras 结果不可重现,比较模型和选择超参数的最佳做法是什么? )。
### code of train.py
import os
os.environ['PYTHONHASHSEED']=str(1)
import numpy as np
import tensorflow as tf
import efficientnet.tfkeras as efn
import random
### tensorflow version = 2.1.0
### tf.keras version = 2.2.4-tf
### efficientnet version = 1.1.0
def reset_random_seeds():
os.environ['PYTHONHASHSEED']=str(1)
tf.random.set_seed(1)
np.random.seed(1)
random.seed(1)
if __name__ == "__main__":
### code for getting training data
### ... (I have made sure that the training input is the same every time i re-run the code)
### end of code
reset_random_seeds()
model = efn.EfficientNetB5(include_top = False,
weights = 'imagenet',
input_shape = (80, 80, 3),
pooling = 'avg',
classes = 3)
model.compile(optimizer = 'Adam', loss = 'categorical_crossentropy')
model.fit(X_train, Y_train, batch_size = 64, epochs = 30, shuffle = True, verbose = 2)
### Prediction section here
但是,每次我运行笔记本实例时,我总是得到与上次运行不同的结果。当我将 train_instance_type 切换为“local”时,每次运行笔记本时总是得到相同的结果。因此,是否是我选择的训练实例类型导致的不可重现的结果?此实例 (ml.m5.xlarge) 有 4 个 vCPU、16 个内存 (GiB) 和没有 GPU。如果是这样,如何在这个训练实例下获得可重现的结果?
解决方案
推荐阅读
- c# - C#在数组列表中搜索特定值并返回相关值
- reactjs - dockerized react - 如何自动启用热重载
- reactjs - 将元素列表推送到 React 中的组件
- oracle - 如何使用立即执行添加列
- python - 将字符串转换为日期对象(csv 条目中不包含时间戳)
- javascript - 通过 CKEditor 4 中的 API 上传图片
- asp.net-core - 在 asp.net core web api 控制器中读取表单数据时出现 405 http 错误
- json - FHIR 覆盖资源
- angular - 在插入项上刷新数据表
- ios - 登录后如何在 SwiftUI 中显示新视图