python - 通过回调函数在 keras/tensorflow 中动态更新变量
问题描述
我想将参数从回调函数传输到自定义正则化函数。
以下回调函数将从混淆矩阵中计算正则化器值:
class call_evaluator(keras.callbacks.Callback):
def __init__(self):
self.regularizer = sym_regularizer()
def on_epoch_end(self, batch, logs=None):
y_pred = tf.cast(tf.math.argmax(model.predict(x_train), axis=1), tf.float32)
y_true = np.argmax(y_train, axis=1)
con_mat = tf.math.confusion_matrix(y_pred, y_true)
diag_sum = tf.linalg.trace(con_mat)
mat_sum = tf.reduce_sum(con_mat)
buffer = tf.math.sqrt(mat_sum / diag_sum)
buffer = buffer.numpy()
self.regularizer.set_penalty(buffer)
计算值用于以下正则化函数:
class sym_regularizer(regularizers.Regularizer):
def __init__(self, sym_penalty=10.0):
# with K.name_scope(self.__class__.__name__):
# self.sym_penalty = K.variable(sym_penalty, name='sym_penalty')
# self.val_sym_penalty = sym_penalty
self.sym_penalty = K.variable(sym_penalty, name='sym_penalty')
self.val_sym_penalty = sym_penalty
def set_penalty(self, sym_penalty):
K.set_value(self.sym_penalty, sym_penalty)
self.val_sym_penalty = sym_penalty
tf.print("self.val_sym_penalty = ", self.val_sym_penalty)
def __call__(self, weights):
regularization = 0
regularization += K.sum(1e-3 * K.square(weights)) + self.val_sym_penalty
return regularization
def get_config(self):
return {'sym_penalty': float(K.get_value(self.sym_penalty))}
我使用的模型如下所示:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30,
kernel_regularizer=sym_regularizer(),
activation=tf.nn.tanh,
input_shape=(x_train.shape[1],)))
model.add(tf.keras.layers.Dense(10,
kernel_regularizer=sym_regularizer(),
activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(4,
kernel_regularizer=sym_regularizer(),
activation=tf.nn.softmax))
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
model.compile(loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train,
y_train,
batch_size=100000,
epochs=100,
verbose=1,
callbacks=[call_evaluator()])
代码将成功访问正则化函数中的更新set_penalty(self, sym_penalty)并设置sym_penalty变量。不幸的是,该值不会在函数的 _ call _ 部分中更新。
该代码基于以下来源:
带有 activity_regularizer 的 Keras,每次迭代都会更新
我无法找到错误,并且无法完全理解代码本身。
解决方案
推荐阅读
- python - 在 python 中使用 pandas 时如何修复“属性错误”
- javascript - 如何在点击时动态添加到 DOM?
- javascript - Django 应用程序不与 JavaScript 代码块通信
- mysql - Mysql 更新相关表时出现错误 1442
- python - 如何根据另一个表的列填充列?
- java - 如何使用非对称加密在另一个应用程序中解密来自一个应用程序的文本?
- xamarin - Xamarin Forms ListView 取消 ItemSelected 事件
- firebase - Vuejs 和 Firebase 存储问题。未捕获的类型错误:存储不是函数
- php - 将可点击的电话号码添加到 php 邮件
- asp.net-core - 使用 VS2019 和 ASP.NET 核心通过 SSH 附加到进程不起作用