tensorflow2.0 - Ray Tensorflow-gpu 2.0 RecursionError
问题描述
系统信息
操作系统平台和发行版(例如,Linux Ubuntu 16.04):Ubuntu 18.04
Ray 安装自(源代码或二进制文件):binary
射线版本:0.7.3
Python版本:3.7
TensorFlow 版本:tensorflow-gpu 2.0.0rc0
重现的确切命令:
# Importing packages
from time import time
import gym
import tensorflow as tf
import ray
# Creating our initial model
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, input_shape=(24,), activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
# Setting parameters
episodes = 64
env_name = 'BipedalWalker-v2'
# Initializing ray
ray.init(num_cpus=8, num_gpus=1)
# Creating our ray function
@ray.remote
def play(weights):
actor = tf.keras.Sequential([
tf.keras.layers.Dense(64, input_shape=(24,), activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
actor = actor.set_weights(weights)
env = gym.make('BipedalWalker-v2').env
env._max_episode_steps=1e20
obs = env.reset()
for _ in range(1200):
action = actor.predict_classes(obs).flatten()[0]
action = env.action_space.sample()
obs, rt, done, info = env.step(action)
return rt
# Testing ray
start = time()
weights = model.get_weights()
weights = ray.put(weights)
results = ray.get([play.remote(weights) for i in range(episodes)])
ray.shutdown()
print('Ray done after:',time()-start)
描述问题
我正在尝试使用 Ray 使用 Tensorflow 2.0-gpu Keras 演员并行化 OpenAI 健身房环境的推出。每次我尝试使用 @ray.remote 实例化 Keras 模型时,它都会引发递归深度达到错误。我正在关注 Ray 概述的文档,建议传递权重而不是模型。我不确定我在这里做错了什么,有什么想法吗?
源代码/日志
文件“/home/jacob/anaconda3/envs/tf-2.0-gpu/lib/python3.7/site-packages/tensorflow/init.py”,第 50 行,在 getattr 模块 = self._load()
_load module = _importlib.import_module(self.name) 中的文件“/home/jacob/anaconda3/envs/tf-2.0-gpu/lib/python3.7/site-packages/tensorflow/init.py”,第 44 行
RecursionError:超出最大递归深度
解决方案
核心问题似乎是 cloudpickle(Ray 使用它来序列化远程函数并将它们发送到工作进程)无法 pickletf.keras.Sequential
类。例如,我可以重现该问题如下
import cloudpickle # cloudpickle.__version__ == '1.2.1'
import tensorflow as tf # tf.__version__ == '2.0.0-rc0'
def f():
tf.keras.Sequential
cloudpickle.loads(cloudpickle.dumps(f)) # This fails.
最后一行失败了
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
<ipython-input-23-25cc307e6227> in <module>
----> 1 cloudpickle.loads(cloudpickle.dumps(f))
~/anaconda3/lib/python3.6/site-packages/tensorflow/__init__.py in __getattr__(self, item)
48
49 def __getattr__(self, item):
---> 50 module = self._load()
51 return getattr(module, item)
52
~/anaconda3/lib/python3.6/site-packages/tensorflow/__init__.py in _load(self)
42 def _load(self):
43 """Import the target module and insert it into the parent's namespace."""
---> 44 module = _importlib.import_module(self.__name__)
45 self._parent_module_globals[self._local_name] = module
46 self.__dict__.update(module.__dict__)
... last 2 frames repeated, from the frame below ...
~/anaconda3/lib/python3.6/site-packages/tensorflow/__init__.py in __getattr__(self, item)
48
49 def __getattr__(self, item):
---> 50 module = self._load()
51 return getattr(module, item)
52
RecursionError: maximum recursion depth exceeded while calling a Python object
有趣的是,这成功了tensorflow==1.14.0
,但我想 keras 在 2.0 中已经改变了很多。
解决方法
作为一种解决方法,您可以尝试f
在单独的模块或 Python 文件中定义,例如
# helper_file.py
import tensorflow as tf
def f():
tf.keras.Sequential
然后在您的主脚本中使用它,如下所示。
import helper_file
import ray
ray.init(num_cpus=1)
@ray.remote
def use_f():
helper_file.f()
ray.get(use_f.remote())
这里的不同之处在于,当 cloudpickle 尝试序列化时use_f
,它实际上不会查看helper_file
. 当某个工作进程尝试反序列use_f
化时,该工作进程将导入helper_file
。这种额外的间接性似乎使 cloudpickle 更可靠地工作。这与使用 tensorflow 或任何库腌制函数时发生的事情相同。Cloudpickle 不会序列化整个库,它只是告诉反序列化过程导入相关库。
注意:要使其在多台机器上工作,helper_file.py
必须存在并位于每台机器上的 Python 路径上(实现此目的的一种方法是将其作为 Python 模块安装在每台机器上)。
我证实这似乎解决了您示例中的问题。修复后,我遇到了
File "<ipython-input-4-bb51dc74442c>", line 3, in play
File "/Users/rkn/Workspace/ray/helper_file.py", line 15, in play
action = actor.predict_classes(obs).flatten()[0]
AttributeError: 'NoneType' object has no attribute 'predict_classes'
但这看起来是一个单独的问题。
推荐阅读
- python - 当我点击 + 运算符并立即我想点击 * 这将变成 * 但它没有发生
- sicp - SICP 第 2 版中的练习 2.42。我的解决方案不自然吗?我的解决方案好吗?
- azure - Azure 搜索获取多条记录
- php - 如何在 Magento 2 中使用 Web API 发送 XML 格式的响应?
- python - 如何在 Plotly Python 中编辑 hovertext 标签?
- javascript - 如何使用 v-model 值为 vue-multiselect 选择复选框
- flutter - Flutter,如何解决“RenderBox 未布置:...NEEDS-PAINT”
- angular - 错误:./node_modules/angular-auth-oidc-client/fesm2015/angular-auth-oidc-client.js 4790:36-60 "export 'ɵɵngDeclareInjectable'
- python-3.x - 如何使用python在不同的价格箱中绘制customer_id的计数
- python - 将数据绘制为来自 PosgreSQL 数据库的折线图