首页 > 解决方案 > 在 tf.case 上计算梯度时使用 AutoGraph(@tf.function)

问题描述

我正在尝试使用 AutoGraph(即 @tf.function)计算 tf.case 上的梯度。

例如,假设我有一个 case 函数,它需要一批输入并根据输入的符号计算输出:

def case_fn(x):
    """Double the positive rows of ``x`` and subtract 2 from the other rows.

    Dispatches through ``tf.case`` to one of three branches depending on
    whether every row is positive, every row is non-positive, or the batch
    is mixed.

    Args:
        x: Float tensor expected to have shape ``[N, 1]``; the mixed branch
            scatters its results into a tensor of exactly that shape.

    Returns:
        A tensor shaped like ``x`` holding ``x * 2.`` where ``x > 0`` and
        ``x - 2.`` elsewhere.
    """
    N = tf.shape(x)[0]

    # Squeeze only axis 1. A bare tf.squeeze() drops *every* size-1 dimension,
    # so a batch of one row, or a single matching index, would collapse to a
    # scalar and break tf.shape(idx)[0] and tf.expand_dims(idx, 1) below.
    positive_mask = tf.squeeze(tf.math.greater(x, 0.), axis=1)
    negative_mask = tf.squeeze(tf.math.less_equal(x, 0.), axis=1)

    # tf.where on a rank-1 mask yields [K, 1] int64 coordinates; keep them as
    # a rank-1 int32 index vector of length K (K may be 0 or 1).
    positive_idx = tf.cast(tf.squeeze(tf.where(positive_mask), axis=1), tf.int32)
    negative_idx = tf.cast(tf.squeeze(tf.where(negative_mask), axis=1), tf.int32)

    def all_positive_case():
        y_positive = x * 2.

        return y_positive

    def all_negative_case():
        y_negative = x - 2.

        return y_negative

    def some_positive_some_negative_case():
        x_positive = tf.gather(x, positive_idx)
        x_negative = tf.gather(x, negative_idx)

        y_positive = x_positive * 2.
        y_negative = x_negative - 2.

        # Scatter each partial result back to its original row; rows that
        # belong to the other group are zero, so the sum merges both.
        y_positive = tf.scatter_nd(tf.expand_dims(positive_idx, 1), y_positive, tf.stack([N, 1]))
        y_negative = tf.scatter_nd(tf.expand_dims(negative_idx, 1), y_negative, tf.stack([N, 1]))

        return y_positive + y_negative

    all_positive = tf.math.equal(tf.shape(negative_idx)[0], 0)
    all_negative = tf.math.equal(tf.shape(positive_idx)[0], 0)
    return tf.case([(all_positive, all_positive_case), (all_negative, all_negative_case)], default=some_positive_some_negative_case)

然后,我使用以下代码计算梯度:

# Column vector with both positive and negative entries, so case_fn falls
# through to its default (mixed-sign) branch.
trainable_variable = tf.Variable([[1.], [-1.], [2.], [-2.]])
@tf.function
def compute_grad():
    """Return the gradient of ``case_fn(trainable_variable)`` w.r.t. the variable."""
    with tf.GradientTape() as recorder:
        # The forward pass must run inside the tape context to be recorded.
        outputs = case_fn(trainable_variable)
    return recorder.gradient(outputs, trainable_variable)

# Trigger tracing/execution of the decorated function and show the gradient.
print(compute_grad())   

如果我不使用 @tf.function 装饰器,它会返回正确的值,即 IndexedSlices(indices=tf.Tensor([0, 2, 1, 3], shape=(4,), dtype=int32), values=tf.Tensor([[2.],[2.],[1.],[1.]], shape=(4, 1), dtype=float32), dense_shape=tf.Tensor([4 1], shape=(2,), dtype=int32))。但是,如果我使用 @tf.function 装饰器,它会抛出如下 ValueError:

Traceback (most recent call last):
  File "examples/case_gradient.py", line 102, in <module>
    print(compute_grad())
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    examples/case_gradient.py:99 compute_grad  *
        grad = tape.gradient(y, trainable_variable)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:1029 gradient
        unconnected_gradients=unconnected_gradients)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/imperative_grad.py:77 imperative_grad
        compat.as_str(unconnected_gradients.value))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:141 _gradient_function
        return grad_fn(mock_op, *out_grads)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:121 _IfGrad
        false_graph, grads, util.unique_grad_fn_name(false_graph.name))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:381 _create_grad_func
        func_graph=_CondGradFuncGraph(name, func_graph))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:380 <lambda>
        lambda: _grad_fn(func_graph, grads), [], {},
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:371 _grad_fn
        src_graph=func_graph)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 _GradientsHelper
        lambda: grad_fn(op, *out_grads))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:336 _MaybeCompile
        return grad_fn()  # Exit early
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 <lambda>
        lambda: grad_fn(op, *out_grads))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:183 _IfGrad
        building_gradient=True,
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:219 _build_cond
        _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:652 _make_indexed_slices_indices_types_match
        (current_index, len(branch_graphs[0].outputs)))

    ValueError: Insufficient elements in branch_graphs[0].outputs.
    Expected: 6
    Actual: 3

我在这里遗漏了什么?

标签: tensorflow, tensorflow2.0

解决方案


我用最新版本 2.2.0-rc3 检查过,没有出现这个问题。该问题可能已在新版本中修复。


推荐阅读