tensorflow - 在训练期间无法更改 tensorflow 优化器中的 learning_rate
问题描述
有人可以解释一下为什么我不能learningrate
在训练期间改变它,在旧的优化器中我可以改变它self.updates.append(K.update(self.learning_rate, new_learning_rate))
但不能再这样做并且self._set_hyper("learning_rate", new_learning_rate)
不起作用,它告诉我:TypeError:__array__() takes 1 positional argument but 2 were given
class SGD(keras.optimizers.Optimizer):
def __init__(self, learning_rate=0.01, name="SGD", **kwargs):
"""Call super().__init__() and use _set_hyper() to store hyperparameters"""
super().__init__(name, **kwargs)
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
@tf.function
def _resource_apply_dense(self, grad, var):
learning_rate = self._get_hyper("learning_rate")
new_learning_rate = learning_rate * 0.001
new_var_m = var - grad * new_learning_rate
self._set_hyper("learning_rate", new_learning_rate) #dont work
var.assign(new_var)
def _resource_apply_sparse(self, grad, var):
raise NotImplementedError
def get_config(self):
base_config = super().get_config()
return {
**base_config,
"learning_rate": self._serialize_hyperparameter("learning_rate"),
}
解决方案
作为一种解决方法,您可以_set_hyper
直接访问属性集,如文档中所述:
超参数可以通过用户代码覆盖
因为它是一个tf.Variable
,然后您可以使用assign
您的设置一个新值tf.Tensor
:
class SGD(keras.optimizers.Optimizer):
def __init__(self, learning_rate=0.01, name="SGD", **kwargs):
"""Call super().__init__() and use _set_hyper() to store hyperparameters"""
super().__init__(name, **kwargs)
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
@tf.function
def _resource_apply_dense(self, grad, var):
learning_rate = self._get_hyper("learning_rate")
new_learning_rate = learning_rate * 0.001
new_var_m = var - grad * new_learning_rate
self.learning_rate.assign(new_learning_rate)
var.assign(new_var_m)
推荐阅读
- asp.net-mvc - 在 Umbraco 中上传会员图片时出现问题
- angular - @NgModule 的构造函数中是否应该有代码,如果有,有什么原因?
- google-chrome - 阻止 Chrome 在开发者工具中自动调整我的缩放或宽度 - Windows 7
- python - 神经计算棒 2:我已经完成了使用 NCS2 的所有处理,但它太慢了
- c# - unity c# xmlDocument 和 Resources.Load/Unload
- java - JAXB/XSD:数字而不是元素名称
- wordpress - 使用带有倒计时 JS 的 ACF 日期时间选择器。WordPress
- ruby - Net::HTTP 和 Nokogiri - nil:NilClass (NoMethodError) 的未定义方法“body”
- google-cloud-firestore - Firestore:为地图值创建索引
- sql - 使用 SQL 进行多重连接