python - [Custom layer] ValueError: tf.function-decorated function tried to create variables on non-first call
Problem description
Question
I am trying to implement a transformer encoder layer.
I strongly believe that the custom transformer layer below is what is causing the error above. Apart from that layer, I have also written three simpler custom layers, which I have left out here to keep things tidy.
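For context, here is a minimal, hypothetical sketch (not part of the original code) of the pattern that raises this ValueError: tf.function only allows variables to be created during its first trace, so a layer that instantiates sub-layers inside call() fails as soon as the function is traced again.

import tensorflow as tf
from tensorflow.keras import layers

class BadLayer(layers.Layer):
    def call(self, x):
        # A brand-new Dense layer (with fresh variables) is built on every call.
        return layers.Dense(8)(x)

bad = BadLayer()

@tf.function
def forward(x):
    return bad(x)

forward(tf.random.normal((2, 4)))  # first trace: creating variables is allowed
forward(tf.random.normal((2, 6)))  # retrace with a new shape: variable creation raises the ValueError

The same mechanism is triggered by model.fit(), whose training step is wrapped in a tf.function.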
Link to the code
Transformer layer
class transformer(layers.Layer):
    def __init__(self, num_heads, transformer_layers, patch_size):
        super(transformer, self).__init__()
        self.num_heads = num_heads
        self.transformer_layers = transformer_layers
        self.patch_size = patch_size

    def call(self, encoded_patches):
        for _ in range(self.transformer_layers):
            x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
            attention_output = layers.MultiHeadAttention(
                num_heads=self.num_heads, key_dim=projection_dim, dropout=0.1
            )(x1, x1)
            x2 = attention_output + encoded_patches
            x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
            x3 = mlp(x3, transformer_units, 0.2)
            encoded_patches = x3 + x2
        representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        resize_img = layers.Reshape(
            [image_size // self.patch_size, image_size // self.patch_size, 64]
        )(representation)
        return resize_img
Model
def PVT():
    # Inputs
    input = layers.Input(shape=input_shape)
    augment = data_augmentation(input)
    # Stage 1
    patches_1 = Patch_1(patch_size_1)(augment)
    patches_1 = PatchEncoder(num_patches=(image_size // patch_size_1) ** 2, projection_dim=projection_dim)(patches_1)
    input_2 = transformer(num_heads, transformer_layers, patch_size_1)(patches_1)  # Output 1
    # Stage 2
    patches_2 = Patch_2(patch_size_2)(input_2)
    patches_2 = PatchEncoder(num_patches=(image_size // patch_size_2) ** 2, projection_dim=projection_dim)(patches_2)
    input_3 = transformer(num_heads, transformer_layers, patch_size_2)(patches_2)  # Output 2
    # Stage 3
    patches_3 = Patch_3(patch_size_3)(input_3)
    patches_3 = PatchEncoder(num_patches=(image_size // patch_size_3) ** 2, projection_dim=projection_dim)(patches_3)
    input_4 = transformer(num_heads, transformer_layers, patch_size_3)(patches_3)  # Output 3
    # Stage 4
    patches_4 = Patch_4(patch_size_4)(input_4)
    patches_4 = PatchEncoder(num_patches=(image_size // patch_size_4) ** 2, projection_dim=projection_dim)(patches_4)
    input_5 = transformer(num_heads, transformer_layers, patch_size_4)(patches_4)  # Output 4
    representation = layers.Flatten()(input_5)
    representation = layers.Dropout(0.5)(representation)
    # Classify outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    model = keras.Model(inputs=input, outputs=logits)
    return model
Compiler and optimizer
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay, beta_1=0.9, beta_2=0.999
    )
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )
    history = model.fit(
        x=xtrain,
        y=ytrain,
        batch_size=batch_size,
        epochs=5,
        validation_split=0.1
    )
    model.save('model-5.h5')
    return history

pvt = PVT()
history = run_experiment(pvt)
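As a side note, even once the tf.function error is resolved, model.save('model-5.h5') will most likely fail as well, because Keras can only serialize custom layers that expose their constructor arguments through get_config(). A minimal sketch for the transformer layer shown above (the other custom layers would need the same treatment):

import tensorflow as tf
from tensorflow.keras import layers

class transformer(layers.Layer):
    # ... __init__ and call exactly as shown above ...

    def get_config(self):
        # Lets Keras re-create the layer with the same constructor
        # arguments when the model is saved and loaded.
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
            "transformer_layers": self.transformer_layers,
            "patch_size": self.patch_size,
        })
        return config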
Notes
I have checked several sources about this error, but I am still confused about what it actually means. Before you point me to other resources, I assure you that I have already gone through them all, so I sincerely ask for a clean solution here.
Link to the code
Model summary (for reference)
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
data_augmentation (Sequentia (None, 72, 72, 3)         7
_________________________________________________________________
patch_1_2 (Patch_1)          (None, None, 48)          0
_________________________________________________________________
patch_encoder_8 (PatchEncode (None, 324, 64)           23872
_________________________________________________________________
transformer_8 (transformer)  (None, 18, 18, 64)        0
_________________________________________________________________
patch_2_2 (Patch_2)          (None, None, 4096)        0
_________________________________________________________________
patch_encoder_9 (PatchEncode (None, 81, 64)            267392
_________________________________________________________________
transformer_9 (transformer)  (None, 9, 9, 64)          0
_________________________________________________________________
patch_3_2 (Patch_3)          (None, None, 16384)       0
_________________________________________________________________
patch_encoder_10 (PatchEncod (None, 16, 64)            1049664
_________________________________________________________________
transformer_10 (transformer) (None, 4, 4, 64)          0
_________________________________________________________________
patch_4_2 (Patch_4)          (None, None, 65536)       0
_________________________________________________________________
patch_encoder_11 (PatchEncod (None, 4, 64)             4194624
_________________________________________________________________
transformer_11 (transformer) (None, 2, 2, 64)          0
=================================================================
Total params: 5,535,559
Trainable params: 5,535,552
Non-trainable params: 7
Solution
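The error comes from where the sub-layers are created. Every LayerNormalization, MultiHeadAttention and Reshape (plus whatever Dense/Dropout layers the mlp helper builds) is instantiated inside transformer.call(). During model.fit() the training step runs inside a tf.function; variables may only be created during its first trace, and the next trace that reaches call() tries to create them again, which is exactly the "tried to create variables on non-first call" error. This also matches the summary above, where every transformer layer reports 0 parameters: sub-layers built inside call() are not tracked as weights of the layer.

The usual fix is to create every sub-layer once, in __init__ (or build), and only apply them in call(). Below is a minimal sketch of the transformer layer restructured that way. It is not the author's original code: projection_dim and image_size are passed in explicitly, and make_mlp is a hypothetical stand-in for the mlp helper, taking the hidden sizes from transformer_units.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def make_mlp(hidden_units, dropout_rate):
    # Hypothetical stand-in for the original mlp() helper:
    # one Dense + Dropout block per hidden size.
    block = keras.Sequential()
    for units in hidden_units:
        block.add(layers.Dense(units, activation=tf.nn.gelu))
        block.add(layers.Dropout(dropout_rate))
    return block


class transformer(layers.Layer):
    def __init__(self, num_heads, transformer_layers, patch_size,
                 projection_dim, image_size, mlp_units, **kwargs):
        super().__init__(**kwargs)
        self.transformer_layers = transformer_layers
        self.patch_size = patch_size
        # Build every sub-layer exactly once, so their variables already
        # exist before the training step is traced by tf.function.
        self.norm1 = [layers.LayerNormalization(epsilon=1e-6) for _ in range(transformer_layers)]
        self.attn = [layers.MultiHeadAttention(num_heads=num_heads,
                                               key_dim=projection_dim,
                                               dropout=0.1)
                     for _ in range(transformer_layers)]
        self.norm2 = [layers.LayerNormalization(epsilon=1e-6) for _ in range(transformer_layers)]
        self.mlps = [make_mlp(mlp_units, 0.2) for _ in range(transformer_layers)]
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)
        self.reshape = layers.Reshape(
            [image_size // patch_size, image_size // patch_size, 64]
        )

    def call(self, encoded_patches):
        # Only apply the pre-built sub-layers here; nothing new is created.
        for i in range(self.transformer_layers):
            x1 = self.norm1[i](encoded_patches)
            attention_output = self.attn[i](x1, x1)
            x2 = attention_output + encoded_patches
            x3 = self.norm2[i](x2)
            x3 = self.mlps[i](x3)
            encoded_patches = x3 + x2
        representation = self.final_norm(encoded_patches)
        return self.reshape(representation)

With this layout, each stage in PVT() would be called as something like transformer(num_heads, transformer_layers, patch_size_1, projection_dim, image_size, transformer_units)(patches_1), and the model summary should then report the attention and MLP parameters under each transformer layer instead of 0.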