python - Keras `multi_gpu_model` 使用导致错误 `yolo_head` 未定义
问题描述
我有一个 keras_yolo python 实现,我试图让学习跨多个 GPU 工作,而 multi_gpu_mode 选项听起来是一个不错的起点。
但是,我的问题是相同的代码在单个 CPU/GPU 设置中工作得很好,但由于 NameError 失败:名称 'yolo_head' 在作为 multi_gpu_mode 模型运行时未定义。完整的堆栈:
parallel_model = multi_gpu_model(model, cpu_relocation=True)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/utils/multi_gpu_utils.py", line 200, in multi_gpu_model
model = clone_model(model)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/models.py", line 251, in clone_model
return _clone_functional_model(model, input_tensors=input_tensors)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/models.py", line 152, in _clone_functional_model
layer(computed_tensors, **kwargs))
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/engine/base_layer.py", line 457, in __call__
output = self.call(inputs, **kwargs)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/layers/core.py", line 687, in call
return self.function(inputs, **arguments)
File "/mnt/data/DeepLeague/YAD2K/yad2k/models/keras_yolo.py", line 199, in yolo_loss
pred_xy, pred_wh, pred_confidence, pred_class_prob = yolo_head(
这是定义的链接yolo_head
:https ://github.com/farzaa/DeepLeague/blob/c87fcd89d9f9e81421609eb397bf95433270f0e2/YAD2K/yad2k/models/keras_yolo.py#L66
我还没有深入研究multi_gpu_model
代码以了解复制在后台是如何工作的,并且希望避免这样做。
解决方案
问题是因为 Keras 中使用的 lambda 中的自定义导入必须在引用它的函数中显式导入。
例如。在这种情况下,yolo_head
必须在 'yolo_loss' 的功能级别上“重新导入”,如下所示:
def yolo_loss(args, anchors, num_classes, rescore_confidence=False, print_loss=False):
from yad2k.models.keras_yolo import yolo_head
推荐阅读
- service-worker - Service Worker:如何缓存远程 Web 服务的调用,以便当我再次调用它时,它将从缓存中加载以使我的网站加载更快?
- javascript - 三元运算符,希望在 JavaScript 中使用多个三元运算符
- php - 将下拉选择发布为链接
- typescript - 如何为 html 元素列表设置类型?
- javascript - Discord JS - 音频播放器状态
- python - Numpy比较两个数组,每个条目取较小的一个
- github - 在子模块路径中运行操作?
- reactjs - 创建反应库 - 资产包
- python - 如何使用来自另一个字典的聚合值创建新字典
- java - Apache camel onException 在原始消息中添加错误详细信息