首页 > 解决方案 > 修改 Keras 类以包含调用函数的问题

问题描述

我想训练一个拥有庞大数据集的 VAE,并决定使用为时尚 MNIST 制作的 VAE代码以及使用我在 github 上找到的文件名进行批量加载的流行修改。我的研究合作笔记本在这里和dataset的一个示例部分。

但是 VAE 类的编写方式没有根据 keras文档应该存在的调用函数。我收到错误NotImplementedError:子类化Model类时,您应该实现一个call方法。

class VAE(tf.keras.Model):
"""a basic vae class for tensorflow
Extends:
    tf.keras.Model
"""

def __init__(self, **kwargs):
    super(VAE, self).__init__()
    self.__dict__.update(kwargs)

    self.enc = tf.keras.Sequential(self.enc)
    self.dec = tf.keras.Sequential(self.dec)

def encode(self, x):
    mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
    return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

def reparameterize(self, mean, logvar):
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * 0.5) + mean

def reconstruct(self, x):
    mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
    return self.decode(mu)

def decode(self, z):
    return self.dec(z)

def compute_loss(self, x):

    q_z = self.encode(x)
    z = q_z.sample()
    x_recon = self.decode(z)
    p_z = ds.MultivariateNormalDiag(
      loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]
      )
    kl_div = ds.kl_divergence(q_z, p_z)
    latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))
    recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))

    return recon_loss, latent_loss

def compute_gradients(self, x):
    with tf.GradientTape() as tape:
        loss = self.compute_loss(x)
    return tape.gradient(loss, self.trainable_variables)

@tf.function
def train(self, train_x):
    gradients = self.compute_gradients(train_x)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

编码器和解码器分别定义并编译为

N_Z = 8
filt_base = 32
DIMS = (128,128,3)

encoder = [
tf.keras.layers.InputLayer(input_shape=DIMS),
tf.keras.layers.Conv2D(
    filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=N_Z*2),
]

decoder = [
tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),
tf.keras.layers.Reshape(target_shape=(8, 8, 128)),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
),
]

optimizer = tf.keras.optimizers.Adam(1e-3)

model = VAE(
  enc = encoder,
  dec = decoder,
  optimizer = optimizer,
)
model.compile(optimizer=optimizer)

并尝试使用fit_generator 函数训练模型

num_epochs = 50
model.fit_generator(generator=my_training_batch_generator,
                                      steps_per_epoch=(num_training_samples // batch_size),
                                      epochs=num_epochs,
                                      verbose=1,
                                      validation_data=my_validation_batch_generator,
                                      validation_steps=(num_validation_samples // batch_size),
                                      use_multiprocessing=True,
                                      workers=16,
                                      max_queue_size=32)

我是机器学习的新手,任何解决问题的帮助都将不胜感激。我认为问题在于 VAE 类中的 def 火车线。

一个可选的要求是,如果可以进行培训,以便我可以看到每个 epoch 之后的重建,将不胜感激。为此,我在研究协作笔记本中已经有一个plot_reconstruction函数需要调用。

标签: pythonmachine-learningkeras

解决方案


阿保罗31,

特别是在您的代码中,我建议向call()VAE 类添加函数:

def call(self, x):
    q_z = self.encode(x)
    z = q_z.sample()
    x_recon = self.decode(z)

我还建议对您的任务使用更标准的方法,尤其是作为初学者:

  1. 用于 tf.keras.preprocessing.image_dataset_from_directory()图像加载。教程在这里

  2. 使用自定义Model.train_step()来计算 VAE 损失,而不是 VAE 类中的多个函数。这里的例子。


推荐阅读