python - Keras multi_gpu_model 返回错误“tensorflow_core._api.v2.config”没有属性“experimental_list_devices”
问题描述
我在我的两个 gpu 上训练我的 unet 模型时遇到问题,
该模型是一个简单的 U-net 实现,我知道它是有效的,因为它的 testet 不是 multi_gpu_model
train_generator = zip(image_generator, mask_generator)
with tf.device("/cpu:0"):
# initialize the model
model = unet((512,512,3))
# make the model parallel
model = multi_gpu_model(model, gpus=2)
model.compile(optimizer='adam', loss="mean_squared_error")
model.fit_generator(train_generator, steps_per_epoch=250, epochs=10)
:output
File "C:/Users/PycharmProjects/U-net/U-net.py", line 29, in <module>
model = multi_gpu_model(model, gpus=2)
File "C:\Users\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\utils\multi_gpu_utils.py", line 150, in multi_gpu_model
available_devices = _get_available_devices()
File "C:\Users\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\utils\multi_gpu_utils.py", line 16, in _get_available_devices
return K.tensorflow_backend._get_available_gpus() + ['/cpu:0']
File "C:\Users\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\backend\tensorflow_backend.py", line 506, in _get_available_gpus
_LOCAL_DEVICES = tf.config.experimental_list_devices()
AttributeError: module 'tensorflow_core._api.v2.config' has no attribute 'experimental_list_devices'
我也试过 tf.distributed.mirroredStrategy() 但也没有运气
任何帮助将不胜感激
解决方案
experimental_list_devices 在 tf 2.1 中已弃用,使用 tf.config.list_logical_devices 替换。
def _get_available_gpus():
"""Get a list of available gpu devices (formatted as strings).
# Returns
A list of available GPU devices.
"""
global _LOCAL_DEVICES
if _LOCAL_DEVICES is None:
if _is_tf_1():
devices = get_session().list_devices()
_LOCAL_DEVICES = [x.name for x in devices]
else:
devices = tf.config.list_logical_devices()
_LOCAL_DEVICES = [x.name for x in devices]
return [x for x in _LOCAL_DEVICES if 'device:gpu' in x.lower()]
该链接有助于解决您的问题LINK
推荐阅读
- python - 设计通知系统数据库的正确方法是什么?
- javascript - 如何按字符串值过滤 mongoose 中的数组,然后使用 res.render 传递它...?
- typescript - RxJS 和 Typescript 捕获相同类型的错误
- python - SQLAlchemy:从引擎获取数据库名称
- java - 运行 java -jar "Spring Application" 时出现问题
- ruby-on-rails - 为什么 rails 资源不生成编辑和新路径?
- sql - 在 Postgres 中组合依赖于不同事务隔离级别的代码
- dev-c++ - DevC++:我如何正确编程迭代 STL 集合中的所有元素?
- c++ - 在函数调用中定义一个匿名结构
- java - 当我尝试在 Spring Boot 和 MVC 中使用获取映射时出现异常