python - TensorFlow-1 分布式底层代码和 Ray
问题描述
我正在尝试分发我使用 Ray 和 Tensorflow 1 构建的深度强化学习系统的训练。同时,我正在使用 ray,因为我有很多代码可以并行化与训练没有直接关系的逻辑,我想使用 tf. 并行化训练(即不同 GPU 上不同工作人员的梯度减少)。分发实用程序,主要是因为它可以使用 NCCL 通信库,我认为与其他方法相比,它会提高训练速度。
问题是我不想重构我的 tensorflow 代码(在低级别用旧的 Tensorflow 1 编写,带有自定义训练循环,我没有使用任何像 Keras 这样的 API),但我不知道如何使用,tf.distribute
即 MirrorStrategy,使用 Tensorflow 1 代码分发训练。
我在 Tensorflow 1 中找到了本指南tf.distribute
,但即使在自定义循环中,他们也在使用 Keras API 进行模型和优化器构建。我正在尝试尽可能遵循本指南,以构建一个使用我在主项目中使用的库/API 的简单示例,但我无法使其正常工作。
示例代码是这样的:
import numpy as np
import tensorflow.compat.v1 as tf
import ray
tf.disable_v2_behavior()
@ray.remote(num_cpus=1, num_gpus=0)
class Trainer:
def __init__(self, local_data):
tf.disable_v2_behavior()
self.current_w = 1.0
self.local_data = local_data
self.strategy = tf.distribute.MirroredStrategy()
with self.strategy.scope():
self.w = tf.Variable(((1.0)), dtype=tf.float32)
self.x = tf.placeholder(shape=(None, 1), dtype=tf.float32)
self.y = self.w * self.x
self.grad = tf.gradients(self.y, [self.w])
def train_step_opt():
def grad_fn():
grad = tf.gradients(self.y, [self.w])
return grad
per_replica_grad = self.strategy.experimental_run_v2(grad_fn)
red_grad = self.strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_grad, axis=None)
minimize = self.w.assign_sub(red_grad[0])
return minimize
self.minimize = self.strategy.experimental_run_v2(train_step_opt)
def train_step(self):
with self.strategy.scope():
with tf.Session() as sess:
sess.run(self.minimize, feed_dict={self.x: self.local_data})
self.current_w = sess.run(self.w)
return self.current_w
ray.init()
data = np.arange(4) + 1
data = data.reshape((-1, 1))
data_w = [data[None, i] for i in range(4)]
trainers = [Trainer.remote(d) for d in data_w]
W = ray.get([t.train_step.remote() for t in trainers])[0]
print(W)
它应该简单地计算线性函数在不同过程中的导数,将所有导数减少为单个值并将其应用于唯一参数“w”。
当我运行它时,我收到以下错误:
Traceback (most recent call last):
File "dtfray.py", line 49, in <module>
r = ray.get([t.train_step.remote() for t in trainers])
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
return func(*args, **kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/ray/worker.py", line 1456, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::Trainer.train_step() (pid=25401, ip=10.128.0.46)
File "python/ray/_raylet.pyx", line 439, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
File "dtfray.py", line 32, in __init__
self.minimize = self.strategy.experimental_run_v2(train_step_opt)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 957, in experimental_run_v2
return self.run(fn, args=args, kwargs=kwargs, options=options)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 951, in run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2290, in call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 770, in _call_for_each_replica
fn, args, kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 201, in _call_for_each_replica
coord.join(threads)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/six.py", line 703, in reraise
raise value
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 998, in run
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
File "/home/Adrian/miniconda3/envs/sukan_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
return func(*args, **kwargs)
File "dtfray.py", line 22, in train_step_opt
tf.distribute.get_replica_context().merge_call()
TypeError: merge_call() missing 1 required positional argument: 'merge_fn'
解决方案
如以下源代码部分中突出显示的那样:
def train_step_opt():
def grad_fn():
grad = tf.gradients(self.y, [self.w])
return grad
per_replica_grad = self.strategy.experimental_run_v2(grad_fn)
red_grad = self.strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_grad, axis=None)
minimize = self.w.assign_sub(red_grad[0])
return minimize
self.minimize = self.strategy.experimental_run_v2(train_step_opt)
您必须在 line之后为 train_step_opt 再减少一次 tf.distribute.MirroredStrategy() 的结果 self.minimize = self.strategy.experimental_run_v2(train_step_opt)
,因为您调用self.strategy.experimental_run_v2()
了两次,一次 ontrain_step_opt
然后 on grad_fn
。
此外,您可以查看TF Github repo 的 mirrored_run.py 文件的第 178 行到第 188 行的以下部分,因为get_replica_context() 是针对跨副本上下文触发的:
在_MirroredReplicaThread ( ) 线程上设置
fn
启动事件时。执行等待直到被设置,这表明要么完成要么调用 a。如果 是完整的,则设置为 True。否则,来自所有暂停线程的参数将被分组并执行。然后将的结果 设置为。当事件再次重置时,每个此类调用都会返回该 线程的 。执行简历。should_run
MRT
MRT.has_paused
fn
get_replica_context().merge_call()
fn
MRT.done
get_replica_context().merge_call
merge_fn
get_replica_context().merge_call
MRT.merge_result
get_replica_context().merge_call
MRT.merge_result
MRT.should_run
fn
推荐阅读
- javascript - 在 vue.js 组件中声明 typescript 接口道具
- json - 拆分两个字符串并配对对应的部分
- julia - 为什么 julia 需要很长时间才能导入包?
- node.js - 如何使用 JWT 为我的 Google Cloud Functions 实现 POST 请求
- mysql - Dbeaver 自动将字母放在我的查询后面
- jenkins - 从 Jenkins 控制台输出中提取错误描述
- android - 用户成功登录后需要 android 应用程序启动主要活动
- python - 所有元素在 Kivy + KivyMD 中渲染两次
- ios - 使用更长的字符串本地化 ios 应用程序
- java - 如何正确使用带有 libGDX 的 open gl 调用