Cannot change the learning_rate of a TensorFlow optimizer during training

Problem description

Can someone explain why I can't change the learning_rate during training? With the old optimizer API I could change it with self.updates.append(K.update(self.learning_rate, new_learning_rate)), but that is no longer possible, and self._set_hyper("learning_rate", new_learning_rate) does not work either; it raises: TypeError: __array__() takes 1 positional argument but 2 were given

import tensorflow as tf
from tensorflow import keras

class SGD(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.01, name="SGD", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))  # handle lr=learning_rate

    @tf.function
    def _resource_apply_dense(self, grad, var):
        learning_rate = self._get_hyper("learning_rate")
        new_learning_rate = learning_rate * 0.001
        new_var_m = var - grad * new_learning_rate
        self._set_hyper("learning_rate", new_learning_rate)  # doesn't work
        var.assign(new_var_m)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
        }

Tags: tensorflow, tensorflow-2.0

Solution


As a workaround, you can access the attribute that _set_hyper creates directly. As the documentation notes:

Hyperparameters can be overwritten through user code

Since that attribute is a tf.Variable, you can then use assign to set it to a new value (a tf.Tensor):

import tensorflow as tf
from tensorflow import keras

class SGD(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.01, name="SGD", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))  # handle lr=learning_rate

    @tf.function
    def _resource_apply_dense(self, grad, var):
        learning_rate = self._get_hyper("learning_rate")
        new_learning_rate = learning_rate * 0.001
        new_var_m = var - grad * new_learning_rate
        self.learning_rate.assign(new_learning_rate)  # works: learning_rate is a tf.Variable
        var.assign(new_var_m)
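The same trick also works from outside a custom optimizer: since the learning rate is stored as a tf.Variable, you can call assign on it between training steps. A minimal sketch (not from the original post, using the built-in tf.keras.optimizers.SGD rather than the custom subclass above):

```python
import tensorflow as tf

# The built-in SGD also stores its learning rate as a tf.Variable,
# so it can be updated in place during training.
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

# ... inside a training loop, when you decide to decay the rate ...
opt.learning_rate.assign(opt.learning_rate * 0.5)

print(float(opt.learning_rate))  # now half the original value
```

Assigning to the variable changes the rate seen by all subsequent apply_gradients calls without rebuilding the optimizer, which is what makes this usable mid-training.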
