tensorflow - 加载 tf.keras.Model 子类调用超类中的 step_function
问题描述
我正在使用 keras 实现 GAN,其中我覆盖了 train_step 函数来自定义我自己的训练循环。当我创建模型并对其进行拟合时,一切正常。但是,当我加载之前保存的模型并尝试对其进行拟合时,调用的 train_step 函数属于 keras.Model 超类,这会导致抛出以下错误:
ValueError: No gradients provided for any variable: ['sequential/conv2d/kernel:0', 'sequential_1/conv2d_1/kernel:0', 'sequential_1/batch_normalization/gamma:0', 'sequential_1/batch_normalization/beta:0', 'sequential_2/conv2d_2/kernel:0', 'sequential_2/batch_normalization_1/gamma:0', 'sequential_2/batch_normalization_1/beta:0', 'sequential_3/conv2d_3/kernel:0', 'sequential_3/batch_normalization_2/gamma:0', 'sequential_3/batch_normalization_2/beta:0', 'sequential_4/conv2d_4/kernel:0', 'sequential_4/batch_normalization_3/gamma:0', 'sequential_4/batch_normalization_3/beta:0', 'sequential_5/conv2d_5/kernel:0', 'sequential_5/batch_normalization_4/gamma:0', 'sequential_5/batch_normalization_4/beta:0', 'sequential_6/conv2d_6/kernel:0', 'sequential_6/batch_normalization_5/gamma:0', 'sequential_6/batch_normalization_5/beta:0', 'sequential_7/conv2d_7/kernel:0', 'sequential_7/batch_normalization_6/gamma:0', 'sequential_7/batch_normalization_6/beta:0', 'sequential_8/conv2d_transpose/kernel:0', 'sequential_8/batch_normalization_7/gamma:0', 'sequential_8/batch_normalization_7/beta:0', 'sequential_9/conv2d_transpose_1/kernel:0', 'sequential_9/batch_normalization_8/gamma:0', 'sequential_9/batch_normalization_8/beta:0', 'sequential_10/conv2d_transpose_2/kernel:0', 'sequential_10/batch_normalization_9/gamma:0', 'sequential_10/batch_normalization_9/beta:0', 'sequential_11/conv2d_transpose_3/kernel:0', 'sequential_11/batch_normalization_10/gamma:0', 'sequential_11/batch_normalization_10/beta:0', 'sequential_12/conv2d_transpose_4/kernel:0', 'sequential_12/batch_normalization_11/gamma:0', 'sequential_12/batch_normalization_11/beta:0', 'sequential_13/conv2d_transpose_5/kernel:0', 'sequential_13/batch_normalization_12/gamma:0', 'sequential_13/batch_normalization_12/beta:0', 'sequential_14/conv2d_transpose_6/kernel:0', 'sequential_14/batch_normalization_13/gamma:0', 'sequential_14/batch_normalization_13/beta:0', 'conv2d_transpose_7/kernel:0', 'conv2d_transpose_7/bias:0', 'sequential_15/conv2d_8/kernel:0', 'sequential_16/conv2d_9/kernel:0', 'sequential_16/batch_normalization_14/gamma:0', 'sequential_16/batch_normalization_14/beta:0', 'sequential_17/conv2d_10/kernel:0', 'sequential_17/batch_normalization_15/gamma:0', 'sequential_17/batch_normalization_15/beta:0', 'conv2d_11/kernel:0', 'batch_normalization_16/gamma:0', 'batch_normalization_16/beta:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0'].
训练脚本如下所示:
if __name__ == '__main__':
ds_train = create_dataset('flic', test=False, batch_size=32)
ds_test = create_dataset('flic', test=True, batch_size=32)
model_path = './model/gan'
if os.path.exists(model_path):
print('Model found in disk, restoring...')
model = load_model(model_path)
else:
print('Model not found in disk, creating new one...')
os.makedirs(model_path)
generator = build_generator(img_width=256, img_height=256, output_channels=2)
discriminator = build_discriminator(img_width=256, img_height=256, output_channels=2)
model = PatchGAN(generator=generator, discriminator=discriminator)
model.compile()
model.fit(
ds_train,
epochs=1,
steps_per_epoch=1,
callbacks=[
TensorBoard(),
LambdaCallback(on_epoch_end=lambda epoch, logs: save_model(model, model_path)),
LambdaCallback(on_epoch_end=lambda epoch, logs: preview_output(model.generator, ds_test))
]
)
GAN 代码如下所示:
class PatchGAN(tf.keras.Model):
def __init__(self,
generator,
discriminator,
lamb=100,
loss_function=BinaryCrossentropy(from_logits=True),
generator_optimizer=Adam(learning_rate=0.0003),
discriminator_optimizer=Adam(learning_rate=0.0003),
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.lamb = lamb
self.generator = generator
self.discriminator = discriminator
self.loss_function = loss_function
self.generator_optimizer = generator_optimizer
self.discriminator_optimizer = discriminator_optimizer
self._set_inputs(generator.inputs)
def call(self, inputs, training=None, mask=None):
return self.generator(inputs)
def train_step(self, data):
input_image, target = data
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
gen_output = self.generator(input_image, training=True)
disc_real_output = self.discriminator([input_image, target], training=True)
disc_generated_output = self.discriminator([input_image, gen_output], training=True)
gan_loss = self.loss_function(tf.ones_like(disc_generated_output), disc_generated_output)
l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
total_gen_loss = gan_loss + (self.lamb * l1_loss)
real_loss = self.loss_function(tf.ones_like(disc_real_output), disc_real_output)
generated_loss = self.loss_function(tf.zeros_like(disc_generated_output), disc_generated_output)
total_disc_loss = real_loss + generated_loss
generator_gradients = gen_tape.gradient(total_gen_loss, self.generator.trainable_variables)
discriminator_gradients = disc_tape.gradient(total_disc_loss, self.discriminator.trainable_variables)
self.generator_optimizer.apply_gradients(
zip(generator_gradients, self.generator.trainable_variables))
self.discriminator_optimizer.apply_gradients(
zip(discriminator_gradients, self.discriminator.trainable_variables))
return {
"d_loss": total_disc_loss,
"g_loss": total_gen_loss,
"gen_gan_loss": gan_loss,
"gen_l1_loss": l1_loss
}
def _downsample(filters, size, apply_batchnorm=True):
initializer = tf.random_normal_initializer(0., 0.02)
result = tf.keras.Sequential()
result.add(
tf.keras.layers.Conv2D(filters, size, strides=2,
padding='same', kernel_initializer=initializer,
use_bias=False))
if apply_batchnorm:
result.add(tf.keras.layers.BatchNormalization())
result.add(tf.keras.layers.LeakyReLU())
return result
def _upsample(filters, size, apply_dropout=False):
initializer = tf.random_normal_initializer(0., 0.02)
result = tf.keras.Sequential()
result.add(
tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
padding='same', kernel_initializer=initializer,
use_bias=False))
result.add(tf.keras.layers.BatchNormalization())
if apply_dropout:
result.add(tf.keras.layers.Dropout(0.5))
result.add(tf.keras.layers.ReLU())
return result
def build_generator(img_width, img_height, output_channels):
inputs = tf.keras.layers.Input(shape=[img_width, img_height, 1])
down_stack = [
_downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
_downsample(128, 4), # (bs, 64, 64, 128)
_downsample(256, 4), # (bs, 32, 32, 256)
_downsample(512, 4), # (bs, 16, 16, 512)
_downsample(512, 4), # (bs, 8, 8, 512)
_downsample(512, 4), # (bs, 4, 4, 512)
_downsample(512, 4), # (bs, 2, 2, 512)
_downsample(512, 4), # (bs, 1, 1, 512)
]
up_stack = [
_upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
_upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
_upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
_upsample(512, 4), # (bs, 16, 16, 1024)
_upsample(256, 4), # (bs, 32, 32, 512)
_upsample(128, 4), # (bs, 64, 64, 256)
_upsample(64, 4), # (bs, 128, 128, 128)
]
initializer = tf.random_normal_initializer(0., 0.02)
last = tf.keras.layers.Conv2DTranspose(output_channels, 4,
strides=2,
padding='same',
kernel_initializer=initializer,
activation='tanh') # (bs, 256, 256, 3)
x = inputs
# Downsampling through the model
skips = []
for down in down_stack:
x = down(x)
skips.append(x)
skips = reversed(skips[:-1])
# Upsampling and establishing the skip connections
for up, skip in zip(up_stack, skips):
x = up(x)
x = tf.keras.layers.Concatenate()([x, skip])
x = last(x)
return tf.keras.Model(inputs=inputs, outputs=x)
def build_discriminator(img_width, img_height, output_channels):
initializer = tf.random_normal_initializer(0., 0.02)
inp = tf.keras.layers.Input(shape=[img_width, img_height, 1], name='input_image')
tar = tf.keras.layers.Input(shape=[img_width, img_height, output_channels], name='target_image')
x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
down1 = _downsample(128, 4, False)(x) # (bs, 128, 128, 64)
down2 = _downsample(256, 4)(down1) # (bs, 64, 64, 128)
down3 = _downsample(256, 4)(down2) # (bs, 32, 32, 256)
zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
conv = tf.keras.layers.Conv2D(512, 4, strides=1,
kernel_initializer=initializer,
use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
last = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
return tf.keras.Model(inputs=[inp, tar], outputs=last)
有人知道为什么会这样吗?
解决方案
推荐阅读
- java - 使用杰克逊解开内部 json 对象
- javascript - 如何定位元素
- html - 我如何从两个不同的类(td.fonce 和 td.clair)中提取数据:
- scala - 在 scala 中处理嵌套的 YAML 文件
- c# - 如何获取在我的网站中导航的设备的唯一 ID
- visual-studio-code - 从 VScode 扩展调用“添加所有缺少的导入”
- go - 如何确定函数返回的值的数量
- ios - ScrollView 缩放问题
- c# - 远程服务器返回错误 (400) Bad Request, status ProtocolError
- javascript - 通过另一个列表框的值过滤javascript中列表框的值