python - How do I set STANDALONE_CLIENT mode when trying to run distributed training with an Estimator?
Problem description
Code:
job_name = FLAGS.job_name
task_index = FLAGS.task_index
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
chief_hosts = FLAGS.chief_hosts.split(",")
evaluator_hosts = FLAGS.evaluator_hosts.split(",")

tf.logging.info('Chief host is :%s' % chief_hosts)
tf.logging.info('PS hosts are: %s' % ps_hosts)
tf.logging.info('Worker hosts are: %s' % worker_hosts)
tf.logging.info('eval hosts are: %s' % evaluator_hosts)

cluster = {'chief': chief_hosts, "ps": ps_hosts,
           "worker": worker_hosts}
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster,
     'task': {'type': job_name, 'index': task_index}})

dist_strategy = tf.contrib.distribute.MirroredStrategy(
    num_gpus=FLAGS.n_gpus,
    cross_device_ops=AllReduceCrossDeviceOps('nccl', num_packs=FLAGS.n_gpus),
    # cross_device_ops=AllReduceCrossDeviceOps('hierarchical_copy'),
)
log_every_n_steps = 8
run_config = RunConfig(
    train_distribute=dist_strategy,
    eval_distribute=dist_strategy,
    log_step_count_steps=log_every_n_steps,
    model_dir=FLAGS.output_dir,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps)

model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=FLAGS.init_checkpoint,
    learning_rate=FLAGS.learning_rate,
    num_train_steps=FLAGS.num_train_steps,
    num_warmup_steps=FLAGS.num_warmup_steps,
    use_tpu=FLAGS.use_tpu,
    use_one_hot_embeddings=FLAGS.use_tpu)

# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = Estimator(
    model_fn=model_fn,
    params={},
    config=run_config)
Log:
INFO:tensorflow:RunConfig initialized for Distribute Coordinator with INDEPENDENT_WORKER mode
Error:
ValueError: Only `STANDALONE_CLIENT` mode is supported when you call `estimator.train`
Versions:
Linux OS
TF 1.15 or 1.14
Solution
TensorFlow 1.15
For multi-GPU, multi-machine training, use tf.estimator.train_and_evaluate instead of estimator.train, and remove tf.contrib.distribute.MirroredStrategy.
For multi-GPU, single-machine training, use estimator.train.
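The multi-machine path might be sketched as follows. This is a minimal sketch assuming TF 1.15, not a confirmed fix for the asker's exact setup. The helper names `build_tf_config` and `train_multi_machine`, the example host addresses, and the `train_input_fn`, `eval_input_fn`, and `max_steps` parameters are hypothetical and do not appear in the question. TensorFlow is imported lazily so the `TF_CONFIG` helper can be checked on its own.

```python
import json
import os


def build_tf_config(chief_hosts, ps_hosts, worker_hosts, job_name, task_index):
    """Return the TF_CONFIG JSON string for one process of the cluster.

    Mirrors the cluster layout used in the question (chief, ps, worker).
    """
    cluster = {"chief": chief_hosts, "ps": ps_hosts, "worker": worker_hosts}
    task = {"type": job_name, "index": task_index}
    return json.dumps({"cluster": cluster, "task": task})


def train_multi_machine(model_fn, train_input_fn, eval_input_fn,
                        model_dir, max_steps):
    """Multi-machine path: train_and_evaluate, with no MirroredStrategy."""
    import tensorflow as tf  # assumed TF 1.15; imported lazily on purpose

    # TF_CONFIG must already be set here: RunConfig reads it when constructed.
    # No train_distribute / eval_distribute is passed, per the answer above.
    run_config = tf.estimator.RunConfig(model_dir=model_dir)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn, params={}, config=run_config)

    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=max_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == "__main__":
    # Hypothetical addresses; every process runs this with its own
    # job_name and task_index.
    os.environ["TF_CONFIG"] = build_tf_config(
        chief_hosts=["host0:2222"],
        ps_hosts=["host1:2222"],
        worker_hosts=["host2:2222", "host3:2222"],
        job_name="worker",
        task_index=0,
    )
    print(os.environ["TF_CONFIG"])
```

Each machine would start the same script with a different job name and task index, and then call `train_multi_machine` with its own model and input functions. On a single machine, `TF_CONFIG` can be left unset and `estimator.train` called directly, as the answer states.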