python - 修改 Keras 类以包含调用函数的问题
问题描述
我想训练一个拥有庞大数据集的 VAE,并决定使用为时尚 MNIST 制作的 VAE代码以及使用我在 github 上找到的文件名进行批量加载的流行修改。我的研究合作笔记本在这里和dataset的一个示例部分。
但是 VAE 类的编写方式没有根据 keras文档应该存在的调用函数。我收到错误NotImplementedError:子类化Model
类时,您应该实现一个call
方法。
class VAE(tf.keras.Model):
"""a basic vae class for tensorflow
Extends:
tf.keras.Model
"""
def __init__(self, **kwargs):
super(VAE, self).__init__()
self.__dict__.update(kwargs)
self.enc = tf.keras.Sequential(self.enc)
self.dec = tf.keras.Sequential(self.dec)
def encode(self, x):
mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return eps * tf.exp(logvar * 0.5) + mean
def reconstruct(self, x):
mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
return self.decode(mu)
def decode(self, z):
return self.dec(z)
def compute_loss(self, x):
q_z = self.encode(x)
z = q_z.sample()
x_recon = self.decode(z)
p_z = ds.MultivariateNormalDiag(
loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]
)
kl_div = ds.kl_divergence(q_z, p_z)
latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))
recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))
return recon_loss, latent_loss
def compute_gradients(self, x):
with tf.GradientTape() as tape:
loss = self.compute_loss(x)
return tape.gradient(loss, self.trainable_variables)
@tf.function
def train(self, train_x):
gradients = self.compute_gradients(train_x)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
编码器和解码器分别定义并编译为
N_Z = 8
filt_base = 32
DIMS = (128,128,3)
encoder = [
tf.keras.layers.InputLayer(input_shape=DIMS),
tf.keras.layers.Conv2D(
filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=N_Z*2),
]
decoder = [
tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),
tf.keras.layers.Reshape(target_shape=(8, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
),
]
optimizer = tf.keras.optimizers.Adam(1e-3)
model = VAE(
enc = encoder,
dec = decoder,
optimizer = optimizer,
)
model.compile(optimizer=optimizer)
并尝试使用fit_generator 函数训练模型
num_epochs = 50
model.fit_generator(generator=my_training_batch_generator,
steps_per_epoch=(num_training_samples // batch_size),
epochs=num_epochs,
verbose=1,
validation_data=my_validation_batch_generator,
validation_steps=(num_validation_samples // batch_size),
use_multiprocessing=True,
workers=16,
max_queue_size=32)
我是机器学习的新手,任何解决问题的帮助都将不胜感激。我认为问题在于 VAE 类中的 def 火车线。
一个可选的要求是,如果可以进行培训,以便我可以看到每个 epoch 之后的重建,将不胜感激。为此,我在研究协作笔记本中已经有一个plot_reconstruction函数需要调用。
解决方案
推荐阅读
- python - Multiprocessing.Process 不并行运行进程
- bash - 如何仅匹配模式的前 N 个实例,然后在每个模式之后打印行直到空行?
- python - 有没有办法删除电子邮件地址中@之后的字符?
- python - PyCharm:如何在本地禁用缺少文档字符串检查
- c# - 用户或管理员的 ASP.Net Core WebAPI 授权策略
- excel - “第一个星期一”公式的Excel变体
- c++ - Visual Studio 2015 中的外部“C”显式类型错误,DLL 测试代码
- python - 如何让用户名成为计算机可以使用和读取的变量?
- visual-studio-code - VS Code API 获取右括号的位置
- twitter-bootstrap - 为什么单击到其他选项卡后第一个导航选项卡的内容会被替换?