python - 如何按照第三个参数确定的顺序构建合并两个图像的图层
问题描述
我想要一个以三个张量作为输入的层:两个 (n,m,k) 张量和一 (1) 个张量,即一个单个数字。输出应该是一个 (n,m,2k) 张量,通过简单地将前 k 个通道作为一个图像而其余的作为另一个图像来实现。现在,问题在于我们合并它们的顺序——我们是把图像一放在图像二的上面还是相反——应该由第三个输入是否大于 0 来确定。
根据我的想法,这是一个完全静态的层,没有任何可训练的参数,所以我尝试使用 Lambda 层进行排序选择,如下所示:
def image_scrambler(inp): #inp = [im1, im2, aux_input]
im1, im2, aux_input = inp[0],inp[1],inp[2]
assert aux_input==1 or aux_input==0
if aux_input==0:
return [im1, im2]
else:
return [im2,im1]
paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])
这不起作用,因为它抗议该层是动态的,需要使用 dynamic=True 构建。当我尝试这样做时,我得到一个 RecursionError 如下:
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
<ipython-input-15-a40adb50e97d> in <module>
7 return [im2,im1]
8 aux_input = Input(shape=(1))
----> 9 paired_images = Lambda(image_scrambler,dynamic=True)([image_input, decoder, aux_input])
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
791 # TODO(fchollet): consider py_func as an alternative, which
792 # would enable us to run the underlying graph if needed.
--> 793 outputs = self._symbolic_call(inputs)
794
795 if outputs is None:
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in _symbolic_call(self, inputs)
2126 def _symbolic_call(self, inputs):
2127 input_shapes = nest.map_structure(lambda x: x.shape, inputs)
-> 2128 output_shapes = self.compute_output_shape(input_shapes)
2129
2130 def _make_placeholder_like(shape):
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\utils\tf_utils.py in wrapper(instance, input_shape)
304 if input_shape is not None:
305 input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306 output_shape = fn(instance, input_shape)
307 # Return shapes from `fn` as TensorShapes.
308 if output_shape is not None:
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\layers\core.py in compute_output_shape(self, input_shape)
808 with context.eager_mode():
809 try:
--> 810 return super(Lambda, self).compute_output_shape(input_shape)
811 except NotImplementedError:
812 raise NotImplementedError(
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in compute_output_shape(self, input_shape)
552 try:
553 if self._expects_training_arg:
--> 554 outputs = self(inputs, training=False)
555 else:
556 outputs = self(inputs)
... last 5 frames repeated, from the frame below ...
c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
791 # TODO(fchollet): consider py_func as an alternative, which
792 # would enable us to run the underlying graph if needed.
--> 793 outputs = self._symbolic_call(inputs)
794
795 if outputs is None:
RecursionError: maximum recursion depth exceeded while calling a Python object
所以这实际上并没有告诉我为什么它不起作用,它只是崩溃了。
如果有任何方法可以让不太复杂的方法工作,我宁愿不必摆弄从 Layer 继承的 layer 类。
解决方案
始终使用“张量函数”而不是“Python 函数”:
import keras.backend as K
def image_scrambler(inp): #inp = [im1, im2, aux_input]
im1, im2, aux_input = inp[0],inp[1],inp[2]
is_greater = K.greater(aux_input, 0.5)
return K.switch(is_greater, #this is a keras "if"
K.concatenate([img2, img1]), #result if true
K.concatenate([img1, img2])) #result if false
paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])
我认为断言不是一个好主意,你应该在检查数据时这样做,而不是在模型中。
尽管您说它不可训练,但您可能期望它以某种方式可训练?是什么决定了aux_input
will 的价值?如果您希望在其他地方学习它,我怀疑它会起作用。也许它应该是某个地方的 sigmoid 给出的连续值。然后它“可能”有机会工作,尽管 if 部分会破坏(但不会破坏)反向传播。
推荐阅读
- javascript - 如何在 Windows 中更新 npm?
- c# - 在枚举 c# 中封装字符串
- vba - 显示所有记录的消息框条件过滤器
- powershell - PowerShell:如何使用remove-item批量删除文件?
- c++ - std 库等价于 boost::upgrade_lock 和 boost::upgrade_to_unique_lock
- asp.net-mvc - 你调用的对象是空的。发送数据时出现异常哟通过viewmodel查看
- c++ - lambda捕获中的值变成了常量?
- apache-spark - 根据列值加入
- c++ - 如何包含两次具有静态变量定义的标题?
- javascript - 在所有页面中显示搜索查询结果