neural-network - VAE with Conv2D layers returns "InvalidArgumentError: Incompatible shapes"
Problem description
import random
import os
from tensorflow.keras.layers import Flatten, Dense, Conv2D, MaxPooling2D, Input, UpSampling2D, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf
seed_n=777
random.seed(seed_n)
tf.random.set_seed(seed_n)
ds_train = tf.keras.preprocessing.image_dataset_from_directory(
    "./images/train",
    labels='inferred',
    # label_mode='int',
    class_names=None,
    color_mode='rgb',
    image_size=(224, 224),
    shuffle=True,
    seed=seed_n,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False
)
ds_test = tf.keras.preprocessing.image_dataset_from_directory(
    "./images/test",
    labels='inferred',
    # label_mode=None,
    class_names=None,
    color_mode='rgb',
    image_size=(224, 224),
    shuffle=True,
    seed=seed_n,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False
)
def autoencoder(z_dim):
    # inputs = Input(shape=[224,224,3])
    inputs = Input((224, 224, 3))
    x = inputs
    x = Conv2D(filters=8, kernel_size=(3, 3), strides=2, padding="same", activation="relu")(x)
    x = Conv2D(filters=8, kernel_size=(3, 3), strides=1, padding="same", activation="relu")(x)
    x = Conv2D(filters=8, kernel_size=(3, 3), strides=2, padding="same", activation="relu")(x)
    x = Conv2D(filters=8, kernel_size=(3, 3), strides=1, padding="same", activation="relu")(x)
    x = Flatten()(x)
    x = Dense(z_dim, activation="relu")(x)
    x = Dense(7*7*64, activation="relu")(x)
    x = Reshape((7, 7, 64))(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding="same", activation="relu")(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(filters=32, kernel_size=(3, 3), strides=1, padding="same", activation="relu")(x)
    x = UpSampling2D((16, 16))(x)
    # The number of filters should equal the number of colour channels
    out = Conv2D(filters=3, kernel_size=(3, 3), strides=1, padding="same", activation="sigmoid")(x)
    return Model(inputs=inputs, outputs=out, name="autoencoder")
z_dim = 1000
autoencoder = autoencoder(z_dim)
autoencoder.compile(loss="mse", optimizer=tf.keras.optimizers.RMSprop(learning_rate=3e-4))
autoencoder.fit(ds_train, validation_data=ds_test, epochs=1, callbacks=[EarlyStopping(monitor="val_loss", patience=2)])
The network architecture is as follows:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 112, 112, 8)       224
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 112, 112, 8)       584
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 56, 56, 8)         584
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 56, 56, 8)         584
_________________________________________________________________
flatten_1 (Flatten)          (None, 25088)             0
_________________________________________________________________
dense_2 (Dense)              (None, 1000)              25089000
_________________________________________________________________
dense_3 (Dense)              (None, 3136)              3139136
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 7, 7, 64)          36928
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 14, 14, 64)        0
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 32)        18464
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 224, 224, 32)      0
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 224, 224, 3)       867
=================================================================
Total params: 28,286,371
Trainable params: 28,286,371
Non-trainable params: 0
_________________________________________________________________
I get the following error:
InvalidArgumentError: Incompatible shapes: [3,224,224,3] vs. [3,1]
[[node gradient_tape/mean_squared_error/BroadcastGradientArgs (defined at <ipython-input-14-a033834688e2>:1) ]] [Op:__inference_train_function_3138]
Function call stack:
train_function
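I suspect the error lies in the batch size or the loss function. However, I can't get it to work unless I set batch_size = 1. The input/output shapes in the model summary look fine to me, but I'm clearly missing something...
Edit: batch_size=1 sometimes works and sometimes doesn't... I just tried a larger dataset (7730 images for training, 499 for validation) and got a different error: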
ValueError: No gradients provided for any variable: ['conv2d_7/kernel:0', 'conv2d_7/bias:0', 'conv2d_8/kernel:0', 'conv2d_8/bias:0', 'conv2d_9/kernel:0', 'conv2d_9/bias:0', 'conv2d_10/kernel:0', 'conv2d_10/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0', 'dense_3/kernel:0', 'dense_3/bias:0', 'conv2d_11/kernel:0', 'conv2d_11/bias:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0', 'conv2d_13/kernel:0', 'conv2d_13/bias:0'].
Solution
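A likely cause, judging only from the shapes in the error message: with labels='inferred', image_dataset_from_directory yields (images, labels) pairs, and fit() treats the second element as the training target. The MSE loss then compares the reconstruction of shape [3,224,224,3] with integer class labels of shape [3,1]. With batch_size = 1 the label has shape [1,1], which broadcasts against [1,224,224,3], so the shape error disappears but the loss is meaningless. The "No gradients provided for any variable" error would be consistent with a dataset that yields images only (for example with label_mode=None), since there is then no target at all. An autoencoder needs the image itself as the target. The sketch below is one way to arrange that, not a verified fix for your data. The helper names to_input_target and load_autoencoder_dataset are hypothetical, and the division by 255 assumes 8-bit pixel values, so that the targets match the [0, 1] range of the sigmoid output layer.

```python
import tensorflow as tf


def to_input_target(images):
    """Scale pixels to [0, 1] and return the batch as both input and target.

    Hypothetical helper: assumes pixel values in the 0-255 range.
    """
    images = tf.cast(images, tf.float32) / 255.0
    return images, images


def load_autoencoder_dataset(directory, seed=777):
    """Load images without labels and pair every batch with itself.

    Hypothetical helper; `directory` is whatever folder holds your images.
    """
    ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        label_mode=None,        # yield images only, no class labels
        color_mode="rgb",
        image_size=(224, 224),
        shuffle=True,
        seed=seed,
    )
    # Each element becomes (scaled_images, scaled_images)
    return ds.map(to_input_target)


# Usage with the paths from the question (not executed here):
# ds_train = load_autoencoder_dataset("./images/train")
# ds_test = load_autoencoder_dataset("./images/test")
# autoencoder.fit(ds_train, validation_data=ds_test, epochs=1)
```

Before calling fit(), printing ds_train.element_spec should show two image tensors of shape (None, 224, 224, 3) rather than an image tensor paired with a label tensor.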