首页 > 解决方案 > Tensorflow Estimator API 在分布式模式下不起作用

问题描述

这是我的测试代码`

from tensorflow.python.keras.layers import Conv1D, MaxPooling1D
from tensorflow.python.keras.models import Model
import logging
level = logging.getLevelName('INFO')
logging.getLogger().setLevel(level)
model = tf.keras.Sequential()
output = Dense(2, activation="softmax")
model.add(Dense(64, activation="relu", input_shape=(10,)))
model.add(output)
model.compile('rmsprop', 'categorical_crossentropy')
est_model = tf.keras.estimator.model_to_estimator(keras_model=model)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"dense_2_input": np.random.randint(10, size=(320, 10))},
        y=np.random.rand(320, 2),
        num_epochs=10000,
        shuffle=False)
est_model.train(train_input_fn)

我的 TF_CONFIG 是`

TF_CONFIG={
"cluster": {"chief": ["localhost:2223"], 
"worker": ["localhost:2221"], 
"ps": ["lcoalhost:2222"]}, 
"task": {"index": "0", "type": "chief"}
}

主管卡在日志记录上Restoring paramater from ...... ,没有端口在监听。

有什么建议吗?

标签: tensorflowtensorflow-estimator

解决方案


推荐阅读