python - ValueError:应在动态构造图的输入列表上调用合并层
问题描述
我的要求与替换占位符 for tensorflow v2中的要求相同。我从@AlexisBRENON(不是兼容模式下的代码)获取了为 Tensorflow 2 重构的代码并逐字运行。我收到一个ValueError: A merge layer should be called on a list of inputs
指向multiply
此片段部分的错误:
make_dict[outg[0]] = tf.keras.layers.add([
make_dict[outg[0]],
tf.keras.layers.multiply(
[[outg[1]], make_dict[queue[0]]],
)],
)
那里可能出了什么问题?从那时起 API 是否发生了变化?
我尝试包裹[[outg[1]], make_dict[queue[0]]]
,tf.keras.layers.concatenate
因为它是我能找到的最接近合并层的东西。它导致了另一个错误ValueError: A Concatenate layer should be called on a list of at least 2 inputs
,即使列表中看起来有两个项目。
我正在使用 TensorFlow 2.4.1
解决方案
为了使旧答案起作用,我们需要进行以下更改。
multiply在 tensorflow 版本中设计2.4.1
为仅接收张量作为列表中的输入。因此,我们需要将计算图的恒定权重更改为有效张量。我不确定它在原始答案中是如何工作的,这是现在的工作示例。
唯一真正的改变outg[1]
是K.constant([outg[1]])
import tensorflow as tf
from keras import backend as K
def construct_graph(graph_dict, inputs, outputs):
queue = inputs[:]
make_dict = {}
for key, val in graph_dict.items():
if key in inputs:
# Use keras.Input instead of placeholders
make_dict[key] = tf.keras.Input(name=key, shape=(), dtype=tf.dtypes.float32)
else:
make_dict[key] = None
# Breadth-First search of graph starting from inputs
print(make_dict)
while len(queue) != 0:
cur = graph_dict[queue[0]]
for outg in cur["outgoing"]:
#print(outg)
if make_dict[outg[0]] is not None: # If discovered node, do add/multiply operation
#print(type(make_dict[outg[0]]))
print(outg[1])
make_dict[outg[0]] = tf.keras.layers.add([
make_dict[outg[0]],
tf.keras.layers.multiply(
[ K.constant([outg[1]]), make_dict[ queue[0] ] ]
)],
)
else: # If undiscovered node, input is just coming in multiplied and add outgoing to queue
make_dict[outg[0]] = tf.keras.layers.multiply(
[make_dict[queue[0]], K.constant([outg[1]]) ]
)
for outgo in graph_dict[outg[0]]["outgoing"]:
queue.append(outgo[0])
queue.pop(0)
# Returns one data graph for each output
model_inputs = [make_dict[key] for key in inputs]
model_outputs = [make_dict[key] for key in outputs]
return [tf.keras.Model(inputs=model_inputs, outputs=o) for o in model_outputs]
def make_graph():
graph_def = {
"B": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"C": {
"incoming": [],
"outgoing": [("A", 1.0)]
},
"A": {
"incoming": [("B", 2.0), ("C", -1.0)],
"outgoing": [("D", 3.0)]
},
"D": {
"incoming": [("A", 2.0)],
"outgoing": []
}
}
outputs = construct_graph(graph_def, ["B", "C"], ["A"])
print("Builded models:", outputs)
for o in outputs:
o.summary(120)
print("Output:", o((1.0, 1.0)))
make_graph()
给出输出
Model: "model"
________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
========================================================================================================================
B (InputLayer) [(None,)] 0
________________________________________________________________________________________________________________________
C (InputLayer) [(None,)] 0
________________________________________________________________________________________________________________________
multiply_33 (Multiply) (None,) 0 B[0][0]
________________________________________________________________________________________________________________________
multiply_34 (Multiply) (None,) 0 C[0][0]
________________________________________________________________________________________________________________________
add (Add) (None,) 0 multiply_33[0][0]
multiply_34[0][0]
========================================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
________________________________________________________________________________________________________________________
Output: tf.Tensor([2.], shape=(1,), dtype=float32)
推荐阅读
- r - 删除 paste0 输出中的单个反斜杠
- sql - 将多列插入映射的 1-1 单列
- reactjs - React Native onChangeText 就像 ReactJS 中的 onChange
- java - 灰熊 PRESERVE_HEADER_CASE
- javascript - 附加时节点从循环中消失
- angular - 在 Angular 7 中,如何访问发出事件的组件?
- c# - MVC 登录表单提交挂起而不点击代码
- go - 在 ElasticSearch 上使用特定搜索类型进行分页
- flutter - 颤动Transform.translate动画从和到小部件的中心
- c# - 当尝试从 AppointmentItem 获取 GlobalAppointmentId 时,它有时会返回 null