首页 > 解决方案 > Keras 2.2.4 如何从 Keras 1.xx 复制 merge()

问题描述

我正在尝试使用 TensorFlow 后端将 Keras 1.xx 代码转换为 2.2.x。

我在 Keras 1.xx 中有以下内容,它接受以下输入:

我希望将图像与蒙版结合起来,以获得缺少蒙版区域的新裁剪图像。为此,我首先取maskusing的倒数1 - mask,其中1是一个张量。然后我逐元素相乘org_image * (1 - mask)以获得新裁剪的图像。Keras 1.xx 中的代码如下所示

from keras.layers import Input, merge

input_shape = (256,256,3)

org_img = Input(shape=input_shape)
mask = Input(shape=(input_shape[0], input_shape[1], 1))
input_img = merge([org_img, mask],
                   mode=lambda x: x[0] * (1 - x[1]),
                   output_shape=input_shape)

在 Keras 2.2.xa 中引入了重大更改,将merge()函数替换为Add()Subtract()Multiply()...等。前面merge()的说服力mode=lambda x: x[0] * (1 - x[1])等于mode=lambda [org_img, mask]: org_img * (1 - mask)

如何1 - mask在 Keras 2.2.x 中复制?我需要导入tf.backend.ones吗?

或者也许我需要tf.enable_eager_execution()

我对此很陌生,所以我知道很多事情都在我头上。如果有人能澄清我的误解在哪里,我将不胜感激,谢谢!

标签: pythontensorflowmergekerasdeprecated

解决方案


Lambda为自定义函数或 lambda 表达式使用层:

input_img = Lambda(lambda x: x[0] * (1 - x[1]), output_shape=input_shape)([org_img, mask])

output_shape如果您使用 tensorflow 作为后端,则where是可选的。

其他有用的层:

  • Concatenate(axis=...)(list_of_inputs)
  • Add()(list_of_inputs)
  • Multiply()(list_of_inputs)

推荐阅读