tensorflow - 升级到 Tensorflow 2.5 现在在使用预训练的 Keras 应用程序模型时会出现 Lambda 层错误
问题描述
我按照本教程为我的问题构建了一个连体网络。我使用的是 Tensorflow 2.4.1,现在升级了
这段代码以前工作得很好
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
现在每个 resnet 层或 mobilenet 或高效网络(都尝试过)都会抛出这些错误:
WARNING:tensorflow:
The following Variables were used a Lambda layer's call (tf.nn.convolution_620), but
are not present in its tracked objects:
<tf.Variable 'stem_conv/kernel:0' shape=(3, 3, 3, 48) dtype=float32>
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.
它编译并且似乎适合。
但是我们必须在 2.5 中以不同的方式初始化模型吗?
感谢您的任何指点!
解决方案
这里没有必要恢复到TF2.4.1
. 我总是建议尝试使用最新版本,因为它解决了许多性能问题和新功能。
我能够执行上面的代码而没有任何问题TF2.5
。
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, Model
img_width, img_height = 224, 224
target_shape = (img_width, img_height, 3)
base_cnn = ResNet50(
weights="imagenet", input_shape=target_shape, include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
输出:
2.5.0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step
根据@Olli,重新启动和清除会话内核已经解决了这个问题。
推荐阅读
- arrays - 如何在数组中返回多个值 - Excel VBA
- microsoft-teams - 在 SSL 证书更新后,Microsoft Teams 似乎不会向服务器发送机器人请求
- java - 带有 MockMVC 的 Junit - 错误 - java.lang.IllegalArgumentException:实体不能为空
- html - 剑道图表 PDF/图像导出 - 特殊字符 html 解码问题
- macos - 如何在mac上使用oracle sql developer中的逻辑模型?
- amazon-web-services - 使用 Java CDK 在 AWS CodePipeline 中禁用阶段转换
- c# - 如何使用 Dapper 存储过程调用父子调用
- javascript - 检查第一个字符是否是数字 - Discord.js
- java - Spring: Does a method name in @Configuration class have any role
- oracle - 找不到 Oracle 客户端和网络组件?