python - Tf 2.0 : RuntimeError: GradientTape.gradient 只能在非持久性磁带上调用一次
问题描述
在tensorflow 2.0 指南中的 tf 2.0 DC Gan 示例中,有两个梯度磁带。见下文。
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
正如你可以清楚地看到有两个渐变磁带。我想知道使用单个磁带有什么区别并将其更改为以下
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
这给了我以下错误:
RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.
我想知道为什么需要两盘磁带。到目前为止,关于 tf2.0 API 的文档很少。任何人都可以解释或指出正确的文档/教程吗?
解决方案
从以下文档:GradientTape
默认情况下, GradientTape 持有的资源会在调用 GradientTape.gradient() 方法后立即释放。要在同一计算中计算多个梯度,请创建一个持久梯度磁带。这允许对 gradient() 方法的多次调用,因为当磁带对象被垃圾回收时资源被释放。
可以使用创建持久渐变with tf.GradientTape(persistent=True) as tape
并且可以/应该手动删除del tape
(此@zwep,@Crispy13 的学分)。
推荐阅读
- android - 如何使用 Kotlin 工具在 Android Studio 中指定从一种形式转换为另一种形式
- c++ - 将 HyperLedger Fabric 与 C++ 应用程序一起使用
- react-native - 具有动态资源的图像源上的 500 错误反应本机
- scala - 如何在 Idea IntelliJ 编辑器中使用 Scala Jackson 库执行基本 Json 馈送器
- nativescript - imageSourceModule.fromFile(path) 在 nativescript 中返回 null
- python - 在 Python 中循环遍历数据框的更优雅的方式
- deployment - 如何在 siddhi 应用程序中使用 deployment.yaml 变量?
- angular - Angular 7 - 创建 Angular 应用程序时出错
- amazon-web-services - How do I get the instance name of a Lightsail instance
- javascript - Testcafe - 在测试用例之外测试命令行参数