tensorflow - 当包装函数需要非浮点参数时,Keras Lambda 层出错
问题描述
我想根据文档在 Keras Lambda 层中包装一个 tensorflow 函数。但是,我的输入是复杂的 64。这是我用来复制此行为的代码的更完整示例:
import numpy as np
from keras.models import Model
from keras.layers import Input, Lambda
import tensorflow as tf
np.set_printoptions(precision=3, threshold=3, edgeitems=3)
def layer0(inp):
z = inp[0] + inp[1]
num = tf.cast(tf.real(z), tf.complex64)
return z/num
if __name__ == "__main__":
shape = (1,10,5)
z1 = Input(shape=shape[1:], dtype=np.complex64)
z2 = Input(shape=shape[1:], dtype=np.complex64)
#s = Lambda(layer0, output_shape=shape)([z1, z2])
s = Lambda(layer0)([z1, z2])
model = Model(inputs=[z1,z2], outputs=s)
z1_in = np.asarray(np.random.normal(size=shape) + np.random.normal(size=shape)*1j, 'complex64')
z2_in = np.asarray(np.random.normal(size=shape) + np.random.normal(size=shape)*1j, 'complex64')
s_out = model.predict([z1_in, z2_in])
print(s_out)
这给出了以下错误:
Traceback (most recent call last):
File "complex_lambda.py", line 32, in <module>
s = Lambda(layer0)([z1, z2])
File "complex_lambda.py", line 18, in layer0
return z/num
TypeError: x and y must have the same dtype, got tf.float32 != tf.complex64
但是,如果我改用注释行:
s = Lambda(layer0, output_shape=shape)([z1, z2])
代码运行得很好。似乎需要“output_shape=(...)”才能使 lambda 函数中的除法起作用。虽然此解决方案解决了单个输出变量的问题,但它在具有多个输出时不起作用。
解决方案
我无法复制您的问题。你用的是哪个版本的tensorflow
?您使用的是keras
包还是tensorflow.keras
子模块?
无论如何,我认为您可以通过指定layerdtype
来解决您的问题:Lambda
s = Lambda(lambda x: tf.math.real(x[0] + x[1]), dtype='complex64')([z1, s2])
推荐阅读
- angularjs - angularjs-route 未加载 angular --prod 标志(angular-cli)
- android - Change hint programmaticaly with kotlin for material.textfield
- c++ - C++::使用模板技巧访问多个私有成员
- php - 从 foreach xml 值中的 foreach 在数组中创建数组
- robotframework - 网页上的文字可见
- python - Abaqus python脚本:选择共享相同坐标的两个连接器之一
- r - 尝试在 R 中生成相关矩阵的“未使用参数”错误消息
- python - Plotly:如何查找和注释两条线之间的交点?
- timestamp - Sectigo 时间戳服务器总是使用 SHA384 而不是 SHA1
- python - 如何从列表中创建 Pandas 数据框?