TensorFlow: How to pass gradients through an integer tensor?

Problem description

I am having trouble maintaining gradient flow through a custom operation in a W-GAN built with TensorFlow and Keras. The W-GAN learns to generate an image "x" (for simplicity, assume batch_size=1). That image is then passed to the function below, which uses the learned image to index into a filter-bank look-up table (LUT). I get the following error:

Traceback (most recent call last):
  File "/home/PycharmProjects/ngan/main.py", line 141, in <module>
    gan.fit(training_images, batch_size=opts["BATCH_SIZE"], epochs=opts["EPOCHS"],
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 571, in train_function
    outputs = self.distribute_strategy.run(
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 951, in run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2290, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2649, in _call_for_each_replica
    return fn(*args, **kwargs)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
    return func(*args, **kwargs)
  File "/home/PycharmProjects/ngan/ngan_test.py", line 352, in train_step
    self.generator_optimizer.apply_gradients(
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 472, in apply_gradients
    grads_and_vars = _filter_grads(grads_and_vars)
  File "/home/anaconda3/envs/ngan/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 1218, in _filter_grads
    raise ValueError("No gradients provided for any variable: %s." %
ValueError: No gradients provided for any variable: ['conv_layer_0/kernel:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0', 'conv2d/kernel:0', 'conv2d/bias:0', 'dense/kernel:0', 'layer_normalization/gamma:0', 'layer_normalization/beta:0'].

Even though tf.gather_nd is a non-differentiable operation, the TensorFlow developers saw fit to at least let gradients pass through it, so I don't think that is the problem. I believe the problem is the tf.cast-to-integer operation: as far as I can tell, TensorFlow puts a hard stop on gradients at integer tensors. I even tried creating a tf.custom_gradient pass-through, as done for the tf.round operation below, but that did not work either. Can anyone suggest a way to index the LUT?
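
The hard stop is easy to demonstrate in isolation: once a tensor is cast to an integer dtype, tf.GradientTape reports no gradient path back to the source variable. A minimal sketch (the values are arbitrary):

import tensorflow as tf

x = tf.Variable([1.7, 2.3])
with tf.GradientTape() as tape:
    y = tf.cast(x, tf.int32)          # gradient flow is severed at the integer cast
    z = tf.cast(y, tf.float32) * 2.0  # back to float32, but the link to x is gone
print(tape.gradient(z, x))            # -> None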

The code is below:

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.keras.backend as K

def test_function(x, lut, xmin, xmax, filter_bank):
    """ 
    Parameters:
        x: learned image (float32); shape=(batch_size, width, height, 1)
        lut: look-up table (float32); shape=(1, 194, 12)
        xmin: left limit (float32); shape=(1,)
        xmax: right limit (float32); shape=(1,)
        flter_bank: a filter bank of functions (float32); shape=(width*heigh, 12)

    Returns:
        proc_image: processed image (float32); shape=(batch_size, width, height, 1)

    """

    @tf.custom_gradient
    def custom_round(x):
        """ Apply Floor operation and pass through gradient unchanged """
        def grad_fn(dy):
            return dy
        return tf.round(x), grad_fn

    # Convert incoming variable to a map of indices:
    index = (x - xmin) / xmax * lut.shape[1]
    index = K.cast(custom_round(index), dtype=tf.int32)
    index = layers.Flatten()(index)

    # Index the LUT:
    b = layers.Flatten()(K.cast(K.zeros_like(x), tf.int32))
    indices = K.stack([b, index], -1)
    indexed_weights = tf.gather_nd(lut, indices, batch_dims=0)

    # Calculate the weighted sum over the filter bank:
    proc_image = tf.math.reduce_sum(tf.math.multiply(indexed_weights, filter_bank), axis=2, keepdims=True)

    return layers.Flatten()(proc_image)
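
For reference, the failure can be reproduced outside the GAN with a plain gradient tape; the shapes below are hypothetical stand-ins for the real image size:

x = tf.Variable(tf.random.uniform((1, 8, 8, 1), 0.0, 0.9))  # hypothetical 8x8 image
lut = tf.random.uniform((1, 194, 12))
filter_bank = tf.random.uniform((64, 12))                    # 8*8 = 64 rows

with tf.GradientTape() as tape:
    out = test_function(x, lut, 0.0, 1.0, filter_bank)       # scalar stand-ins for xmin/xmax
    loss = tf.reduce_mean(out)
print(tape.gradient(loss, x))  # -> None: the cast to int32 severs the path back to x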

I am running:

Version info:

tensorflow-gpu-v2.0.0
NVIDIA-SMI 430.64 / Driver Version: 430.64 / CUDA Version: 10.1
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04

Tags: tensorflow, gradient, backpropagation

Solution
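
A minimal sketch of one possible workaround, assuming it is acceptable to interpolate linearly between adjacent LUT entries: keep the integer cast on the gather path only, and route the gradient back to x through the float fractional index instead. Unlike the question's code, this sketch normalizes by (xmax - xmin) and scales by lut.shape[1] - 1 so indices stay in range; the helper name lut_lookup_interp is made up for illustration:

import tensorflow as tf

def lut_lookup_interp(x, lut, xmin, xmax):
    """ Differentiable LUT lookup via linear interpolation between adjacent entries """
    n = tf.shape(lut)[1]                                   # number of LUT entries
    # Continuous (float) index into the LUT, clipped to the valid range:
    idx = (x - xmin) / (xmax - xmin) * tf.cast(n - 1, tf.float32)
    idx = tf.clip_by_value(idx, 0.0, tf.cast(n - 1, tf.float32))
    idx = tf.reshape(idx, (tf.shape(x)[0], -1))            # (batch, width*height)

    lo = tf.floor(idx)                                     # lower neighbour (float)
    hi = tf.minimum(lo + 1.0, tf.cast(n - 1, tf.float32))  # upper neighbour (float)
    frac = idx - lo                                        # differentiable w.r.t. x

    # Integer casts live only on the gather path; the gradient reaches x through frac:
    w_lo = tf.gather(lut[0], tf.cast(lo, tf.int32))        # (batch, width*height, 12)
    w_hi = tf.gather(lut[0], tf.cast(hi, tf.int32))
    return w_lo + frac[..., None] * (w_hi - w_lo)

The weighted sum over the filter bank then proceeds exactly as in the question, e.g. tf.math.reduce_sum(lut_lookup_interp(x, lut, xmin, xmax) * filter_bank, axis=2). Because the forward pass gathers the two neighbouring LUT entries and blends them with the float frac, both the LUT values (via tf.gather) and the generated image x (via frac) receive gradients.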

