python - 多输入 GAN 返回错误 ValueError: Graph disconnected:
问题描述
整个周末都在尝试解决这个问题。我希望有人能帮忙。
我有一个模型可以采用普通数组并在 GAN 中对其进行处理,它可以工作,但是一旦我将其更改为多输入,我就开始得到:
ValueError: Graph disconnected:
我的原始代码:
# Build stacked GAN model
gan_input = Input(shape=Xtrain.shape[1])
H = generator(gan_input)
gd_input=Concatenate()([gan_input,H])
gan_V = discriminator(gd_input)
GAN = Model(gan_input, [gan_V,H])
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()
然后我将其修改为多输入:
gan_dataframe_input = Input(shape=Xtrain[1][:-2].shape) #new testing
numpy_input = Input(shape=Xtrain[1][-1].shape)
gan_input = layers.concatenate([gan_dataframe_input, numpy_input])
print(gan_input)
print(mergedLayer)
H = generator([gan_dataframe_input,numpy_input]) <<--two shapes being imputed
gd_input=Concatenate()([gan_input,H]) <<--merged layer + above two shapes being imputed
gan_V = discriminator(gd_input)
GAN = Model(gan_input, [gan_V,H]) <<--this line returns an error
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()
堆栈跟踪:
KerasTensor(type_spec=TensorSpec(shape=(None, 736), dtype=tf.float32, name=None), name='concatenate_28/concat:0', description="created by layer 'concatenate_28'")
KerasTensor(type_spec=TensorSpec(shape=(None, 736), dtype=tf.float32, name=None), name='concatenate_27/concat:0', description="created by layer 'concatenate_27'")
WARNING:tensorflow:Functional model inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to "model_34" was not an Input tensor, it was generated by layer concatenate_28.
Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.
The tensor that caused the issue was: concatenate_28/concat:0
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-94-ac83091846e6> in <module>()
69 gd_input=Concatenate()([gan_input,H])
70 gan_V = discriminator(gd_input)
---> 71 GAN = Model(gan_input, [gan_V,H])
72 GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
73 GAN.summary()
4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in _map_graph_network(inputs, outputs)
988 'The following previous layers '
989 'were accessed without issue: ' +
--> 990 str(layers_with_complete_input))
991 for x in nest.flatten(node.outputs):
992 computable_tensors.add(id(x))
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 659), dtype=tf.float32, name='input_71'), name='input_71', description="created by layer 'input_71'") at layer "concatenate_28". The following previous layers were accessed without issue: []
奇怪地看着完整的轨迹,在我在图层上打印数据后,数组中的项目数似乎没有对齐?(659,) 是其中一个输入的大小,而另一个是 (77,)。我不确定我在这里做错了什么。有什么建议么?
解决方案
当您构建多输入/多输出模型时,您必须编译并将模型输入和输出作为数组提供,而不是像您那样连接它们。此外,模型的输入必须始终为tf.keras.layers.Input
。所以正确的代码是
gan_dataframe_input = Input(shape=Xtrain[1][:-2].shape) #new testing
numpy_input = Input(shape=Xtrain[1][-1].shape)
gan_input = layers.concatenate([gan_dataframe_input, numpy_input])
print(gan_input)
print(mergedLayer)
H = generator([gan_dataframe_input,numpy_input]) <<--two shapes being imputed
gd_input=Concatenate()([gan_input,H]) <<--merged layer + above two shapes being imputed
gan_V = discriminator(gd_input)
GAN = Model([gan_dataframe_input, numpy_input ], [gan_V,H]) <<--this line is modified
GAN.compile(loss=['categorical_crossentropy','mse'], optimizer=opt) #Complete GAN have both loss functions
GAN.summary()
推荐阅读
- bash - 如果语句未正确评估
- php - 为 Ansible Composer 模块定义 PHP 版本
- ios - 如何在 swift 中使打开 collectionView 的动画效果(如视频)
- c# - 错误:操作必须使用可更新查询
- node.js - 如何在 postgres 中进行多个查询
- ios - 当检测到的图像在 AR 中消失时,如何停止 videoNode?
- python - 如何使它成为帕斯卡三角形。表示第一行是 1,然后是第二行是 1,1,第三行是 1,2,1
- amazon-web-services - 从客户端调用 API 时出现错误代码 500
- omnet++ - Omnetpp.ini - 如何为主机参数创建循环
- android - 如何为 Facebook 登录设置登台和生产应用程序?