首页 > 解决方案 > 自定义层中的 TensorFlow 自定义渐变

问题描述

我正在使用自定义渐变设置自定义图层。输入是形状为 (?, 2) 的单个二维张量。输出也是具有形状 (?, 2) 的单个二维张量。

我正在努力理解这些对象的行为方式。我从文档中收集到的是,对于给定的输入,渐变将具有与输出相同的形状,并且我需要为每个输入返回渐变列表。我一直假设由于我的输入看起来像 (?, 2) 而我的输出看起来像 (?, 2),那么 grad 函数应该返回一个长度为 2 的列表:[input_1_grad, input_2_grad],其中两个列表项都是具有输出形状的张量 (?, 2)。

这不起作用,这就是为什么我希望这里有人可以提供帮助。

这是我的错误(似乎发生在编译时):

ValueError: Num gradients 3 generated for op name: "custom_layer/IdentityN" op: "IdentityN" input: "custom_layer_2/concat" input: "custom_layer_1/concat" attr { key: "T" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_gradient_op_type" value { s: "CustomGradient-28729" } } 不匹配 num 个输入 2

另一个问题是自定义层的输入本身也是一个自定义层(尽管没有自定义渐变)。如果有帮助,我将提供这两个层的代码。

另外,请注意,如果我不尝试指定自定义渐变,网络将编译并运行。但是,由于我的函数需要帮助区分自己,我需要手动干预,所以有一个有效的自定义渐变是至关重要的。

第一个自定义图层(无自定义渐变):

class custom_layer_1(tensorflow.keras.layers.Layer):
    def __init__(self):
        super(custom_layer_1, self).__init__()
    
    def build(self, input_shape):
        self.term_1 = self.add_weight('term_1', trainable=True)
        self.term_2 = self.add_weight('term_2', trainable=True)
    
    def call(self, x):
        self.term_1 = formula in terms of x
        self.term_2 = another formula in terms of x
        
        return tf.concat([self.term_1, self.term_2], axis=1)

第二个自定义图层(使用自定义渐变):

class custom_layer_2(tensorflow.keras.layers.Layer):
    ### the inputs
    # x is the concatenation of term_1 and term_2
    def __init__(self):
        super(custom_layer_2, self).__init__()
    
    def build(self, input_shape):
        #self.weight_1 = self.add_weight('weight_1', trainable=True)
        #self.weight_2 = self.add_weight('weight_2', trainable=True)
    
    def call(self, x):
        return custom_function(x)

自定义函数:

@tf.custom_gradient
def custom_function(x):
    ### the inputs
    # x is a concatenation of term_1 and term_2
    
    weight_1 = function in terms of x
    weight_2 = another function in terms of x
    
    ### the gradient
    def grad(dy):
        # assuming dy has the output shape of (?, 2). could be wrong.
        d_weight_1 = K.reshape(dy[:, 0], shape=(K.shape(x)[0], 1))
        d_weight_1 = K.reshape(dy[:, 1], shape=(K.shape(x)[0], 1))
        
        term_1 = K.reshape(x[:, 0], shape=(K.shape(x)[0], 1))
        term_2 = K.reshape(x[:, 1], shape=(K.shape(x)[0], 1))
        
        d_weight_1_d_term_1 = tf.where(K.equal(term_1, K.zeros_like(term_1)), K.zeros_like(term_1), -term_2 / term_1) * d_weight_1
        d_weight_1_d_term_2 = tf.where(K.equal(term_1, K.zeros_like(term_1)), K.zeros_like(term_1), 1 / term_1) * d_weight_1
        
        d_weight_2_d_term_1 = tf.where(K.equal(term_2, K.zeros_like(term_2)), K.zeros_like(term_1), 2 * term_1 / term_2) * d_weight_2
        d_weight_2_d_term_2 = tf.where(K.equal(term_2, K.zeros_like(term_2)), K.zeros_like(term_1), -K.square(term_1 / term_2)) * d_weight_2
        
        return tf.concat([d_weight_1_d_term_1, d_weight_1_d_term_2], axis=1), tf.concat([d_weight_2_d_term_1, d_weight_2_d_term_2], axis=1)
  
  return tf.concat([weight_1, weight_2], axis=1), grad

任何帮助将非常感激!

标签: pythontensorflow

解决方案


推荐阅读