首页 > 解决方案 > 当包装函数需要非浮点参数时,Keras Lambda 层出错

问题描述

我想根据文档在 Keras Lambda 层中包装一个 tensorflow 函数。但是,我的输入是复杂的 64。这是我用来复制此行为的代码的更完整示例:

import numpy as np
from keras.models import Model
from keras.layers import Input, Lambda
import tensorflow as tf
np.set_printoptions(precision=3, threshold=3, edgeitems=3)


def layer0(inp):
    z = inp[0] + inp[1]
    num = tf.cast(tf.real(z), tf.complex64)
    return z/num


if __name__ == "__main__":

    shape = (1,10,5)
    z1 = Input(shape=shape[1:], dtype=np.complex64)
    z2 = Input(shape=shape[1:], dtype=np.complex64)

    #s = Lambda(layer0, output_shape=shape)([z1, z2])
    s = Lambda(layer0)([z1, z2])

    model = Model(inputs=[z1,z2], outputs=s)

    z1_in = np.asarray(np.random.normal(size=shape) + np.random.normal(size=shape)*1j, 'complex64')
    z2_in = np.asarray(np.random.normal(size=shape) + np.random.normal(size=shape)*1j, 'complex64')

    s_out = model.predict([z1_in, z2_in])
    print(s_out)

这给出了以下错误:

Traceback (most recent call last):
  File "complex_lambda.py", line 32, in <module>
    s = Lambda(layer0)([z1, z2])
  File "complex_lambda.py", line 18, in layer0
    return z/num
TypeError: x and y must have the same dtype, got tf.float32 != tf.complex64

但是,如果我改用注释行: s = Lambda(layer0, output_shape=shape)([z1, z2])

代码运行得很好。似乎需要“output_shape=(...)”才能使 lambda 函数中的除法起作用。虽然此解决方案解决了单个输出变量的问题,但它在具有多个输出时不起作用。

标签: tensorflowlambdakeraslayer

解决方案


我无法复制您的问题。你用的是哪个版本的tensorflow?您使用的是keras包还是tensorflow.keras子模块?

无论如何,我认为您可以通过指定layerdtype来解决您的问题:Lambdas = Lambda(lambda x: tf.math.real(x[0] + x[1]), dtype='complex64')([z1, s2])


推荐阅读