python - Tensorflow Keras 从回调中修改模型变量
问题描述
我正在尝试从每个时期开始时的回调中修改不可训练的模型变量。本质上,我希望有一个类似于学习率调度器的机制(它在 TF 中内置了基础设施),但适用于任意模型变量。下面的代码是显示该概念的最小示例。我正在尝试修改衰减变量,但它不起作用。显然,变量 (1.0) 的初始值被视为常量并被图形折叠,并且随着训练的进行,即使该变量似乎已被回调正确修改(到 0.5),也不会再查看。
dense1 = tf.keras.layers.Dense(10)
decay = tf.Variable(1.0, trainable=False)
dense2 = tf.keras.layers.Dense(10)
def epoch_callback(epoch):
nonlocal decay
tf.keras.backend.set_value(decay, 0.5)
#decay.assign(0.5)
print(tf.keras.backend.get_value(decay))
input = tf.keras.layers.Input((MAX_LENGTH,))
x = dense1(input)
with tf.control_dependencies([decay]):
x = x * decay
prediction = dense2(x)
model = tf.keras.Model(inputs=[input], outputs=[prediction])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin = lambda epoch, logs: epoch_callback(epoch))]
model.fit(train_ds, epochs=EPOCHS, verbose=1, callbacks=callbacks, validation_data=eval_ds)
@nbro:给你。下面的代码对我有用。我使用教师强制协议,并且每个时期的衰减变量用于随着训练的进行“降低教师的声音”。
class Teacher(tf.keras.layers.Layer):
def __init__(self, embedding, name='teacher', **kwargs):
super().__init__(name=name, **kwargs)
...
def build(self, input_shape):
...
def call(self, inputs, training=None):
x, y, decay = inputs
...
if training:
y = tf.multiply(y, decay)
else:
y = tf.multiply(y, tf.constant(0.0))
...
return x
def get_config(self):
return {}
class MyNet(tf.keras.Model):
def __init__(self, name='mynet', **kwargs):
super().__init__(name=name, **kwargs)
def build(self, input_shape):
...
self.teacher = Teacher()
self.decay = tf.Variable(1.0, trainable=False)
...
def set_decay(self, decay):
self.decay.assign(decay)
@tf.function
def call(self, example, training=None):
x, y = example
...
x = self.teacher((x, y, self.decay))
...
return x
def get_config(self):
return {}
def main():
train_ds = ...
eval_ds = ...
train_ds = train_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)
eval_ds = eval_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
the_net = MyNet()
inputs = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='inputs')
targets = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='targets')
prediction = the_net((inputs, targets))
model = tf.keras.Model(inputs=[inputs, targets], outputs=[prediction])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=CosineSimilarity(name='val_loss'))
def _callback_fun(epoch, start = 0, steps = 8):
the_net.set_decay(tf.clip_by_value((start+steps-epoch)/steps, clip_value_min=tf.constant(0.0), clip_value_max=tf.constant(1.0)))
callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: _callback_fun(epoch))]
model.fit(train_ds, epochs=EPOCHS, verbose=2, callbacks=callbacks, validation_data=eval_ds)
if __name__ == '__main__':
main()
解决方案
推荐阅读
- iis - 负载均衡器到 IIS 超时
- azure - 每个国家/地区架构的数据库
- android - 在 Android 项目中发现奇怪的类。如何修复它
- python - Python 语法错误:非 ASCII
- angular - 如何避免 Angular 5 中多个组件中的代码重复?
- android - Android : 为内部团队上传应用程序并不适合所有用户
- python - 将 Python 脚本转换为 CLI 并启用其分发
- java - TransformException:java.util.zip.ZipException:重复条目:
- c# - 依赖注入对象创建说明
- java - 关于 Jboss 4.0.2 的 JMS 主题