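tensorflow - TensorFlow: How do I pass gradients through an integer tensor?
Problem description
I'm having trouble keeping gradients flowing through a custom operation in a W-GAN built with TensorFlow and Keras. The W-GAN learns to generate an image "x" (for simplicity, assume batch_size=1). That image is then passed to the function below, which uses the learned image to index a filter-bank look-up table (LUT). I get the following error: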
Traceback (most recent call last):
  File "/home/PycharmProjects/ngan/main.py", line 141, in <module>
    gan.fit(training_images, batch_size=opts["BATCH_SIZE"], epochs=opts["EPOCHS"],
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 571, in train_function
    outputs = self.distribute_strategy.run(
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 951, in run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2290, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2649, in _call_for_each_replica
    return fn(*args, **kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
    return func(*args, **kwargs)
  File "/home/PycharmProjects/ngan/ngan_test.py", line 352, in train_step
    self.generator_optimizer.apply_gradients(
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 472, in apply_gradients
    grads_and_vars = _filter_grads(grads_and_vars)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 1218, in _filter_grads
    raise ValueError("No gradients provided for any variable: %s." %
ValueError: No gradients provided for any variable: ['conv_layer_0/kernel:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0', 'conv2d/kernel:0', 'conv2d/bias:0', 'dense/kernel:0', 'layer_normalization/gamma:0', 'layer_normalization/beta:0'].
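Even though tf.gather_nd is not differentiable with respect to its indices, the TensorFlow developers made it pass gradients through to the gathered params, so I don't think it is the problem. I think the problem is the tf.cast to an integer type: as far as I know, TensorFlow hard-stops gradients at integer tensors. I even tried writing a tf.custom_gradient pass-through, as one would for tf.round, but that didn't work either. Can anyone suggest a way to index the LUT that keeps the gradient?
The code: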
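The gradient stop described above can be reproduced in isolation with tf.GradientTape. This is a minimal sketch on made-up toy tensors (the values of x and lut are illustrative assumptions, not taken from the question): the straight-through custom_round does pass a gradient while its result stays float32, but once that result is cast to tf.int32 and used as an index, tape.gradient returns None for x. The gathered params (lut) still receive a gradient.

```python
import tensorflow as tf


@tf.custom_gradient
def custom_round(x):
    """Round in the forward pass; pass the upstream gradient through unchanged."""
    def grad_fn(dy):
        return dy
    return tf.round(x), grad_fn


# Illustrative toy inputs: three "pixel" values and a 3-entry LUT with one weight per entry.
x = tf.constant([0.2, 1.7, 2.4])
lut = tf.constant([[10.0], [20.0], [30.0]])

with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, lut])
    rounded = custom_round(x)              # float32, straight-through gradient
    idx = tf.cast(rounded, tf.int32)       # float -> int: no gradient flows back through this
    loss = tf.reduce_sum(tf.gather(lut, idx))
    loss_float = tf.reduce_sum(rounded)    # same path, but without the integer cast

grad_x = tape.gradient(loss, x)               # None: x only reaches loss through integer indices
grad_x_float = tape.gradient(loss_float, x)   # ones: the pass-through works while values stay float
grad_lut = tape.gradient(loss, lut)           # IndexedSlices: gather is differentiable w.r.t. params

print(grad_x)
print(grad_x_float)
print(tf.convert_to_tensor(grad_lut))
```

So the custom gradient on tf.round is not the missing piece: whatever happens after the cast, the indices themselves carry no gradient into x.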
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.keras.backend as K


def test_function(x, lut, xmin, xmax, filter_bank):
    """
    Parameters:
        x: learned image (float32); shape=(batch_size, width, height, 1)
        lut: look-up table (float32); shape=(1, 194, 12)
        xmin: left limit (float32); shape=(1,)
        xmax: right limit (float32); shape=(1,)
        filter_bank: a filter bank of functions (float32); shape=(width*height, 12)
    Returns:
        proc_image: processed image (float32), flattened; shape=(batch_size, width*height)
    """
    @tf.custom_gradient
    def custom_round(x):
        """ Apply round operation and pass the gradient through unchanged """
        def grad_fn(dy):
            return dy
        return tf.round(x), grad_fn

    # Convert the incoming variable to a map of indices:
    index = (x - xmin) / xmax * lut.shape[1]
    index = K.cast(custom_round(index), dtype=tf.int32)
    index = layers.Flatten()(index)

    # Index the LUT:
    b = layers.Flatten()(K.cast(K.zeros_like(x), tf.int32))
    indices = K.stack([b, index], -1)
    indexed_weights = tf.gather_nd(lut, indices, batch_dims=0)

    # Calculate the weighted sum of the filter bank:
    proc_image = tf.math.reduce_sum(tf.math.multiply(indexed_weights, filter_bank), axis=2, keepdims=True)
    return layers.Flatten()(proc_image)
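One commonly used workaround, sketched here under stated assumptions and not as the asker's method, is to replace the hard integer lookup with linear interpolation between the two neighbouring LUT entries. The integer indices still carry no gradient, but the fractional interpolation weight is a float function of x, so the output is differentiable with respect to x almost everywhere. The function name interp_lut_filter is hypothetical, and the normalization (x - xmin) / (xmax - xmin) * (n - 1) is an assumption that differs from the formula in the question; it keeps every index inside [0, n - 1]. Shapes follow the docstring above.

```python
import tensorflow as tf


def interp_lut_filter(x, lut, xmin, xmax, filter_bank):
    """Hypothetical differentiable variant of test_function (linear interpolation into the LUT).

    x: (batch, W, H, 1); lut: (1, n, K); xmin, xmax: (1,); filter_bank: (W*H, K).
    Returns: (batch, W*H).
    """
    n = lut.shape[1]                      # number of LUT entries (194 in the question)
    table = lut[0]                        # (n, K)

    # Assumed mapping of x onto a continuous index in [0, n - 1].
    pos = (x - xmin) / (xmax - xmin) * (n - 1)
    pos = tf.clip_by_value(pos, 0.0, float(n - 1))
    pos = tf.reshape(pos, [tf.shape(x)[0], -1])   # (batch, W*H)

    lo = tf.floor(pos)
    frac = pos - lo                       # float weight: d(frac)/d(pos) = 1, so x gets a gradient here
    lo_i = tf.cast(lo, tf.int32)          # integer indices (no gradient, and none is needed)
    hi_i = tf.minimum(lo_i + 1, n - 1)    # upper neighbour, clamped at the last entry

    w_lo = tf.gather(table, lo_i)         # (batch, W*H, K)
    w_hi = tf.gather(table, hi_i)         # (batch, W*H, K)
    weights = w_lo + frac[..., None] * (w_hi - w_lo)

    # Weighted sum over the K filter-bank functions, as in the original function.
    return tf.reduce_sum(weights * filter_bank, axis=-1)
```

In this sketch the gradient with respect to x is proportional to the slope between adjacent LUT entries, and it is zero wherever x is clipped to the LUT range or the two neighbouring entries are equal. The LUT itself still receives gradients through both gathers.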
I'm running:
Version info:
tensorflow-gpu-v2.0.0
NVIDIA-SMI 430.64 / Driver Version: 430.64 / CUDA Version: 10.1
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04