python - Keras 2.2.4 如何从 Keras 1.xx 复制 merge()
问题描述
我正在尝试使用 TensorFlow 后端将 Keras 1.xx 代码转换为 2.2.x。
我在 Keras 1.xx 中有以下内容,它接受以下输入:
org_image
3 个 RGB 颜色通道上的 256x256 图像shape=(256,256,3)
mask
1 B/W 颜色通道上的 256x256 蒙版shape=(256,256,1)
我希望将图像与蒙版结合起来,以获得缺少蒙版区域的新裁剪图像。为此,我首先取mask
using的倒数1 - mask
,其中1
是一个张量。然后我逐元素相乘org_image * (1 - mask)
以获得新裁剪的图像。Keras 1.xx 中的代码如下所示
from keras.layers import Input, merge
input_shape = (256,256,3)
org_img = Input(shape=input_shape)
mask = Input(shape=(input_shape[0], input_shape[1], 1))
input_img = merge([org_img, mask],
mode=lambda x: x[0] * (1 - x[1]),
output_shape=input_shape)
在 Keras 2.2.xa 中引入了重大更改,将merge()
函数替换为Add()
、Subtract()
、Multiply()
...等。前面merge()
的说服力mode=lambda x: x[0] * (1 - x[1])
等于mode=lambda [org_img, mask]: org_img * (1 - mask)
。
如何1 - mask
在 Keras 2.2.x 中复制?我需要导入tf.backend.ones
吗?
或者也许我需要tf.enable_eager_execution()
?
我对此很陌生,所以我知道很多事情都在我头上。如果有人能澄清我的误解在哪里,我将不胜感激,谢谢!
解决方案
Lambda
为自定义函数或 lambda 表达式使用层:
input_img = Lambda(lambda x: x[0] * (1 - x[1]), output_shape=input_shape)([org_img, mask])
output_shape
如果您使用 tensorflow 作为后端,则where是可选的。
其他有用的层:
Concatenate(axis=...)(list_of_inputs)
Add()(list_of_inputs)
Multiply()(list_of_inputs)
推荐阅读
- r - R:如何按他们的组对箱线图上的样本进行着色?
- android - libgdx 内存分配在动画上失败
- akka - Unable to resolve akka.pattern.AskTimeoutException: Ask timed out
- linux - 内核如何管理虚拟内存
- css - Django FilteredSelectMultiple 右半部分不渲染
- python - 有什么方法可以在 PIL 中指定颜色(用于绘制多边形)而不会出错?
- android - Kotlin 中的 URI 解析不起作用
- opencl - 在主机和内核中同时使用 OpenCL 缓冲区
- ruby-on-rails - application.rb 语法错误,意外 ',',期待 ')'
- python - 插件模板中的占位符未显示 - djangocms