How to build a layer that merges two images in an order determined by a third argument

Problem description

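I want a layer that takes three tensors as input: two (n,m,k) tensors and one (1) tensor, i.e. a single number. The output should be an (n,m,2k) tensor, obtained by simply taking the first k channels as one image and the rest as the other. The catch is that the order in which we merge them, whether image one goes on top of image two or vice versa, should be determined by whether the third input is greater than 0.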

As I see it, this is a completely static layer without any trainable parameters, so I tried to do the order selection with a Lambda layer, as follows:

def image_scrambler(inp): #inp = [im1, im2, aux_input]
    im1, im2, aux_input = inp[0],inp[1],inp[2]
    assert aux_input==1 or aux_input==0
    if aux_input==0:
        return [im1, im2]
    else:
        return [im2,im1]
paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])

This does not work: it complains that the layer is dynamic and needs to be built with dynamic=True. When I try that, I get a RecursionError as follows:

---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-15-a40adb50e97d> in <module>
      7         return [im2,im1]
      8 aux_input = Input(shape=(1))
----> 9 paired_images = Lambda(image_scrambler,dynamic=True)([image_input, decoder, aux_input])

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
    791             # TODO(fchollet): consider py_func as an alternative, which
    792             # would enable us to run the underlying graph if needed.
--> 793             outputs = self._symbolic_call(inputs)
    794 
    795           if outputs is None:

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in _symbolic_call(self, inputs)
   2126   def _symbolic_call(self, inputs):
   2127     input_shapes = nest.map_structure(lambda x: x.shape, inputs)
-> 2128     output_shapes = self.compute_output_shape(input_shapes)
   2129 
   2130     def _make_placeholder_like(shape):

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\utils\tf_utils.py in wrapper(instance, input_shape)
    304     if input_shape is not None:
    305       input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306     output_shape = fn(instance, input_shape)
    307     # Return shapes from `fn` as TensorShapes.
    308     if output_shape is not None:

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\layers\core.py in compute_output_shape(self, input_shape)
    808       with context.eager_mode():
    809         try:
--> 810           return super(Lambda, self).compute_output_shape(input_shape)
    811         except NotImplementedError:
    812           raise NotImplementedError(

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in compute_output_shape(self, input_shape)
    552           try:
    553             if self._expects_training_arg:
--> 554               outputs = self(inputs, training=False)
    555             else:
    556               outputs = self(inputs)

... last 5 frames repeated, from the frame below ...

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
    791             # TODO(fchollet): consider py_func as an alternative, which
    792             # would enable us to run the underlying graph if needed.
--> 793             outputs = self._symbolic_call(inputs)
    794 
    795           if outputs is None:

RecursionError: maximum recursion depth exceeded while calling a Python object

So this doesn't actually tell me why it doesn't work; it just crashes.

If there is any way to make a less complicated approach work, I would rather not have to fiddle with a layer class that inherits from Layer.

Tags: python, keras, neural-network, keras-layer

Solution


Always use "tensor functions" rather than "Python functions":

import keras.backend as K
from keras.layers import Lambda

def image_scrambler(inp):  # inp = [im1, im2, aux_input]
    im1, im2, aux_input = inp[0], inp[1], inp[2]

    # Elementwise test on the tensor; no Python `if` on symbolic values
    is_greater = K.greater(aux_input, 0.5)
    return K.switch(is_greater,                   # this is a Keras "if"
                    K.concatenate([im2, im1]),    # result if true
                    K.concatenate([im1, im2]))    # result if false

paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])
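
To make the per-sample behaviour concrete, here is a minimal NumPy sketch of what the switch is meant to compute for a batch. It is only an illustration, not the Keras code itself: the helper name scramble_batch and the toy shapes are made up, and it assumes channels-last images with one 0/1 flag per sample.

```python
import numpy as np

def scramble_batch(im1, im2, aux):
    """im1, im2: (batch, n, m, k) arrays; aux: (batch, 1) array of 0/1 flags."""
    in_order = np.concatenate([im1, im2], axis=-1)  # used where aux is 0
    swapped = np.concatenate([im2, im1], axis=-1)   # used where aux is 1
    # Reshape the flag so it broadcasts over the n, m and 2k axes
    cond = (aux > 0.5).reshape(-1, 1, 1, 1)
    return np.where(cond, swapped, in_order)

# Toy batch of two samples: im1 is all zeros, im2 is all ones
im1 = np.zeros((2, 3, 3, 1))
im2 = np.ones((2, 3, 3, 1))
aux = np.array([[0.0], [1.0]])

out = scramble_batch(im1, im2, aux)
print(out.shape)     # (2, 3, 3, 2)
print(out[0, 0, 0])  # [0. 1.]  sample 0 keeps the order im1, im2
print(out[1, 0, 0])  # [1. 0.]  sample 1 is swapped to im2, im1
```

Both orderings are built for every sample and the flag only picks between them, which is why the selection can be expressed as a tensor operation instead of a Python if.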

I don't think the assertion is a good idea; you should do that when checking the data, not in the model.
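
One way to move that check out of the model is to validate the flag array before it is fed in. The following is a small sketch under the assumption that the flags arrive as a NumPy-compatible array; check_aux_flags is a hypothetical helper name, not part of Keras.

```python
import numpy as np

def check_aux_flags(aux):
    """Return aux as an array, raising if any entry is not exactly 0 or 1."""
    aux = np.asarray(aux)
    bad = ~np.isin(aux, (0, 1))
    if bad.any():
        raise ValueError(f"{int(bad.sum())} aux value(s) are not 0 or 1")
    return aux

flags = check_aux_flags([[0], [1], [1]])  # passes and returns a (3, 1) array
```

Calling this once on the dataset keeps the model graph free of Python-level assertions on symbolic tensors.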

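Although you said it is not trainable, are you perhaps expecting it to be trainable in some way? What determines the value of aux_input? If you expect it to be learned somewhere else, I doubt that will work. Maybe it should be a continuous value given by a sigmoid somewhere. Then it "may" have a chance of working, although the if part will spoil (but not break) backpropagation.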

