首页 > 解决方案 > 升级到 Tensorflow 2.5 现在在使用预训练的 Keras 应用程序模型时会出现 Lambda 层错误

问题描述

我按照本教程为我的问题构建了一个连体网络。我使用的是 Tensorflow 2.4.1,现在升级了

这段代码以前工作得很好

base_cnn = resnet.ResNet50(
    weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable

现在每个 resnet 层或 mobilenet 或高效网络(都尝试过)都会抛出这些错误:

WARNING:tensorflow:
The following Variables were used a Lambda layer's call (tf.nn.convolution_620), but
are not present in its tracked objects:
  <tf.Variable 'stem_conv/kernel:0' shape=(3, 3, 3, 48) dtype=float32>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.

它编译并且似乎适合。

但是我们必须在 2.5 中以不同的方式初始化模型吗?

感谢您的任何指点!

标签: tensorflowmachine-learningkerasdeep-learningkeras-layer

解决方案


这里没有必要恢复到TF2.4.1. 我总是建议尝试使用最新版本,因为它解决了许多性能问题和新功能。

我能够执行上面的代码而没有任何问题TF2.5

import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, Model


img_width, img_height = 224, 224
target_shape = (img_width, img_height, 3)


base_cnn = ResNet50(
    weights="imagenet", input_shape=target_shape, include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable

输出:

2.5.0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step

根据@Olli,重新启动和清除会话内核已经解决了这个问题。


推荐阅读