python - Feed input to an intermediate layer of a tensorflow.keras model
Problem description
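I am trying to implement a hydranet architecture with EfficientNetB0 from tensorflow.keras.applications. The goal of this architecture is to split the network into two parts (first part: backbone, second part: head). The input image should be fed through the backbone only once, and the backbone's output should be stored. That stored output should then be fed directly into the head (there can be more than one head, depending on the number of classes to classify). Ideally:
- I don't want to rebuild the whole model for each head.
- The backbone should be executed only once.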
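The intended data flow (run the backbone once, cache its output, feed that output to several heads) can be sketched as follows. This is a minimal sketch under assumptions that are not in the question: a tiny stand-in backbone instead of EfficientNetB0, and two hypothetical, freshly defined heads (`color_head` with 5 classes, `shape_head` with 3). The point it illustrates is that each head is its own functional model that starts from a `keras.Input` whose shape matches the backbone output.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Tiny stand-in backbone (assumption: replaces EfficientNetB0 to keep the sketch small).
backbone_in = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(8, 3, activation="relu")(backbone_in)
features = layers.GlobalAveragePooling2D()(x)
backbone = keras.Model(backbone_in, features, name="backbone")


def make_head(num_classes, name):
    # Each head starts from its own keras.Input with the backbone's output shape,
    # so it is a valid, standalone functional model.
    head_in = keras.Input(shape=tuple(features.shape[1:]))
    out = layers.Dense(num_classes, activation="softmax")(head_in)
    return keras.Model(head_in, out, name=name)


# Hypothetical heads with made-up class counts.
heads = {"color": make_head(5, "color_head"), "shape": make_head(3, "shape_head")}

images = np.random.rand(2, 32, 32, 3).astype("float32")
shared = backbone.predict(images, verbose=0)  # the backbone runs exactly once
predictions = {key: head.predict(shared, verbose=0) for key, head in heads.items()}
```

When the heads are new layers like these, they can also be combined into a single multi-output model by calling each head on the backbone's output tensor, e.g. `keras.Model(backbone_in, [h(features) for h in heads.values()])`. The harder case, which the rest of this thread is about, is when the head must be the second half of an existing pretrained network.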
I have already looked at this forum post: keras-give-input-to-intermediate-layer-and-get-final-output, but the solutions proposed there either require re-coding the head or do not work.
I tried the following:
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.applications import EfficientNetB0 as Net
from tensorflow.keras.models import Model
split_idx = 73
input_shape = (250, 250, 3) # use depth=3 because imagenet is trained on RGB images
model = Net(weights="imagenet", include_top=True)
# Approach 1:
# create the full network so we can train on it
model_backbone = keras.models.Model(inputs=model.input, outputs=model.layers[split_idx].output)
# create new model taking the output from backbone as input and creating final output of head
model_head = keras.models.Model(inputs=model.layers[split_idx].output,
                                outputs=model.layers[-1].output)
# Approach 2:
# create function for feeding input through backbone
# the function takes a normal input image as input and returns the output of the final backbone layer
create_backbone_output = K.function([model.layers[0].input], model.layers[split_idx].output)
# create function for feeding output of backbone through heads
create_heads_output = K.function([model.layers[split_idx].output],
                                 model.output)
However, when I try to run this, both approaches fail with a "Graph disconnected" error:
WARNING:tensorflow:Functional model inputs must come from `tf.keras.Input` (thus holding past layer
metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input
to "model_5" was not an Input tensor, it was generated by layer block3b_drop.
Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.
The tensor that caused the issue was: block3b_drop/Identity:0
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-33-64dd6f6430a1> in <module>
6 # create function for feeding output of backbone through heads
7 create_heads_output = K.function([model.layers[split_idx].output],
----> 8 model.output)
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\backend.py in
function(inputs,
outputs, updates, name, **kwargs)
4067 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
4068 from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
-> 4069 model = models.Model(inputs=inputs, outputs=outputs)
4070
4071 wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\tracking\base.py in
_method_wrapper(self, *args, **kwargs)
515 self._self_setattr_tracking = False # pylint: disable=protected-access
516 try:
--> 517 result = method(self, *args, **kwargs)
518 finally:
519 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\functional.py in
__init__(self, inputs, outputs, name, trainable, **kwargs)
118 generic_utils.validate_kwargs(kwargs, {})
119 super(Functional, self).__init__(name=name, trainable=trainable)
--> 120 self._init_graph_network(inputs, outputs)
121
122 @trackable.no_automatic_dependency_tracking
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\tracking\base.py in
_method_wrapper(self, *args, **kwargs)
515 self._self_setattr_tracking = False # pylint: disable=protected-access
516 try:
--> 517 result = method(self, *args, **kwargs)
518 finally:
519 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\functional.py in
_init_graph_network(self, inputs, outputs)
202 # Keep track of the network's nodes and layers.
203 nodes, nodes_by_depth, layers, _ = _map_graph_network(
--> 204 self.inputs, self.outputs)
205 self._network_nodes = nodes
206 self._nodes_by_depth = nodes_by_depth
~\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\functional.py in
_map_graph_network(inputs, outputs)
981 'The following previous layers '
982 'were accessed without issue: ' +
--> 983 str(layers_with_complete_input))
984 for x in nest.flatten(node.outputs):
985 computable_tensors.add(id(x))
ValueError: Graph disconnected: cannot obtain value for tensor
KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_4'),
name='input_4', description="created by layer 'input_4'") at layer "rescaling_3". The following
previous layers were accessed without issue: []
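I understand that the error comes from the fact that the tensor I pass in is not an input tensor. Is there any way around this?
Solution
1.) The model is not as sequential as you are treating it: the layer at split_idx + 1 is an Add operation that also takes the output of an earlier layer, so that tensor has to be added to the outputs of the first part and to the inputs of the second part.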
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
block3b_drop (Dropout)          (None, 28, 28, 40)   0           block3b_project_bn[0][0]
__________________________________________________________________________________________________
block3b_add (Add)               (None, 28, 28, 40)   0           block3b_drop[0][0]
                                                                 block3a_project_bn[0][0]
__________________________________________________________________________________________________
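The "Connected to" column above can also be checked programmatically. The sketch below lists every head layer that reads a tensor produced inside the backbone. It assumes the split is placed after the layer named `block3b_drop` (looked up by name rather than by index, since indices can differ between TF versions), uses `weights=None` so nothing is downloaded, and relies on the fact that a functional model passes the very same tensor objects from layer to layer, so they can be matched with `id()`.

```python
from tensorflow.keras.applications import EfficientNetB0

# weights=None: same graph as the ImageNet model, but no weight download.
model = EfficientNetB0(weights=None, include_top=True)

split_name = "block3b_drop"  # assumed last layer of the backbone
split_idx = [layer.name for layer in model.layers].index(split_name)


def as_list(t):
    # layer.input / layer.output are a single tensor or a list/tuple of tensors
    return list(t) if isinstance(t, (list, tuple)) else [t]


# id(tensor) -> name of the backbone layer that produced it
produced_by = {}
for layer in model.layers[: split_idx + 1]:
    for t in as_list(layer.output):
        produced_by[id(t)] = layer.name

# head layers that consume at least one tensor produced by the backbone
crossing = {}
for layer in model.layers[split_idx + 1:]:
    sources = [produced_by[id(t)] for t in as_list(layer.input) if id(t) in produced_by]
    if sources:
        crossing[layer.name] = sources

print(crossing)
```

For this split it should report only `block3b_add`, fed by `block3b_drop` and `block3a_project_bn`, which are exactly the two tensors the second part needs as inputs.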
2.) Create an Input for every tensor the second part needs, using the shapes of the corresponding outputs:
second_input1 = keras.Input(shape=model.layers[split_idx].output.shape[1:])
second_input2 = keras.Input(shape=model.get_layer(name='block3a_project_bn').output.shape[1:])
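3.) Then rewire the rest of the model on top of these inputs. You will need to add a few things yourself, but here are some snippets to get you started.
For a purely sequential model, the rewiring would be: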
tmp = [second_input1, second_input2]
for l in range(split_idx + 1, len(model.layers)):
    layer = model.layers[l]
    print(layer.name, layer.input)
    tmp = layer(tmp)
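To see the sequential version of this loop working on its own, here is a minimal sketch on a hypothetical toy model in which every layer has exactly one input (an assumption that does not hold for EfficientNet). The head starts from a fresh `keras.Input` and re-calls the remaining layers; because the same layer objects are reused, the head shares their weights, and head(backbone(x)) reproduces the full model's output.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical purely sequential toy model: every layer has exactly one input.
inp = keras.Input(shape=(8,))
x = layers.Dense(16, activation="relu", name="d1")(inp)
x = layers.Dense(16, activation="relu", name="d2")(x)
out = layers.Dense(3, name="out")(x)
model = keras.Model(inp, out)

split_idx = [layer.name for layer in model.layers].index("d1")
split_output = model.layers[split_idx].output

# Backbone: original input up to the split layer.
backbone = keras.Model(inp, split_output)

# Head: a fresh keras.Input, then re-call the remaining (already built) layers.
head_in = keras.Input(shape=tuple(split_output.shape[1:]))
tmp = head_in
for layer in model.layers[split_idx + 1:]:
    tmp = layer(tmp)  # reuses the layer object, so the weights are shared
head = keras.Model(head_in, tmp)

data = np.random.rand(4, 8).astype("float32")
full = np.asarray(model(data))
split = np.asarray(head(backbone(data)))
```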
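In your case this is not enough: you need to find the correct inputs for each layer, and the snippet below shows how to inspect them. Find the correct input, use it to produce the next output (keeping track of the outputs), and work your way through the graph like that.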
for l in model.layers:
    # multiple inputs
    if isinstance(l.input, list):
        for li, lv in enumerate(l.input):
            print('o ', li, lv.name)
    else:
        print('- ', l.input.name)
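Putting steps 1 to 3 together, the "track the outputs and work through the graph" idea can be sketched as a helper. `split_functional_model` is a hypothetical name, not a Keras API. It assumes every layer is called exactly once, that `model.layers` is in topological order (inputs first), and that it is called once per model (it re-calls the head layers, which gives them a second node). It collects every backbone tensor that a head layer reads, makes those the backbone's outputs, creates one `keras.Input` per such tensor, and re-calls the head layers on the new tensors via a dictionary keyed by `id()` of the original tensors. The demo uses a toy model with a skip connection across the split, the same situation as `block3b_add`.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def _as_list(t):
    # layer.input / layer.output are a single tensor or a list/tuple of tensors
    return list(t) if isinstance(t, (list, tuple)) else [t]


def _unwrap(ts):
    # pass a single tensor instead of a one-element list to keras.Model
    return ts[0] if len(ts) == 1 else ts


def split_functional_model(model, split_layer_name):
    """Hypothetical helper: split `model` after the layer named `split_layer_name`."""
    split_idx = [layer.name for layer in model.layers].index(split_layer_name)
    pre_layers = model.layers[: split_idx + 1]
    post_layers = model.layers[split_idx + 1:]

    # ids of every tensor produced inside the backbone
    pre_ids = {id(t) for layer in pre_layers for t in _as_list(layer.output)}

    # boundary tensors: backbone tensors read by at least one head layer
    boundary, seen = [], set()
    for layer in post_layers:
        for t in _as_list(layer.input):
            if id(t) in pre_ids and id(t) not in seen:
                seen.add(id(t))
                boundary.append(t)

    backbone = keras.Model(_unwrap(model.inputs), _unwrap(boundary), name="backbone")

    # one fresh keras.Input per boundary tensor; new_tensor maps id(old) -> new tensor
    head_inputs = [keras.Input(shape=tuple(t.shape[1:])) for t in boundary]
    new_tensor = {id(t): h for t, h in zip(boundary, head_inputs)}

    for layer in post_layers:
        old_in = layer.input              # read before re-calling the layer,
        old_out = _as_list(layer.output)  # which adds a second node to it
        if isinstance(old_in, (list, tuple)):
            new_out = layer([new_tensor[id(t)] for t in old_in])
        else:
            new_out = layer(new_tensor[id(old_in)])
        for old, new in zip(old_out, _as_list(new_out)):
            new_tensor[id(old)] = new

    head_outputs = [new_tensor[id(t)] for t in model.outputs]
    head = keras.Model(_unwrap(head_inputs), _unwrap(head_outputs), name="head")
    return backbone, head


# Demo: toy model whose Add layer reads from both sides of the split.
inp = keras.Input(shape=(8,))
a = layers.Dense(16, activation="relu", name="d1")(inp)
b = layers.Dense(16, name="d2")(a)
c = layers.Add(name="add")([a, b])
out = layers.Dense(3, name="out")(c)
toy = keras.Model(inp, out)

backbone, head = split_functional_model(toy, "d2")
data = np.random.rand(4, 8).astype("float32")
features = backbone(data)  # computed once; here two tensors (d1 and d2 outputs)
full = np.asarray(toy(data))
split = np.asarray(head(features))
```

For EfficientNetB0 the same call would be something like `split_functional_model(model, "block3b_drop")`, which should give a backbone with two outputs (`block3b_drop` and `block3a_project_bn`). This has not been verified against every TF/Keras version, and layers that were called with extra keyword arguments would need special handling.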
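Another, cheaper approach: save the model as JSON, add your input nodes there, and remove the unused nodes. Then load the new JSON file; in that case you don't need to rewire anything. Note that a model rebuilt from JSON starts with freshly initialized weights, so the pretrained weights still have to be copied over (for example layer by layer, matching by name, with get_weights/set_weights).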