首页 > 解决方案 > 通过回调函数在 keras/tensorflow 中动态更新变量

问题描述

我想将参数从回调函数传输到自定义正则化函数。

以下回调函数将从混淆矩阵中计算正则化器值:

class call_evaluator(keras.callbacks.Callback):
    def __init__(self):
        self.regularizer = sym_regularizer()
    def on_epoch_end(self, batch, logs=None):
        y_pred = tf.cast(tf.math.argmax(model.predict(x_train), axis=1), tf.float32)
        y_true = np.argmax(y_train, axis=1)
        con_mat = tf.math.confusion_matrix(y_pred, y_true)
        diag_sum = tf.linalg.trace(con_mat)
        mat_sum = tf.reduce_sum(con_mat)
        buffer = tf.math.sqrt(mat_sum / diag_sum)
        buffer = buffer.numpy()
        self.regularizer.set_penalty(buffer)

计算值用于以下正则化函数:

class sym_regularizer(regularizers.Regularizer):
    def __init__(self, sym_penalty=10.0):
        # with K.name_scope(self.__class__.__name__):
        #     self.sym_penalty = K.variable(sym_penalty, name='sym_penalty')
        #     self.val_sym_penalty = sym_penalty
        self.sym_penalty = K.variable(sym_penalty, name='sym_penalty')
        self.val_sym_penalty = sym_penalty

    def set_penalty(self, sym_penalty):
        K.set_value(self.sym_penalty, sym_penalty)
        self.val_sym_penalty = sym_penalty
        tf.print("self.val_sym_penalty = ", self.val_sym_penalty)

    def __call__(self, weights):
        regularization = 0
        regularization += K.sum(1e-3 * K.square(weights)) + self.val_sym_penalty
        return regularization

    def get_config(self):
        return {'sym_penalty': float(K.get_value(self.sym_penalty))} 

我使用的模型如下所示:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30,
                                kernel_regularizer=sym_regularizer(),
                                activation=tf.nn.tanh,
                                input_shape=(x_train.shape[1],)))
model.add(tf.keras.layers.Dense(10,
                                kernel_regularizer=sym_regularizer(),
                                activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(4,
                                kernel_regularizer=sym_regularizer(),
                                activation=tf.nn.softmax))

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

model.compile(loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(x_train,
          y_train,
          batch_size=100000,
          epochs=100,
          verbose=1,
          callbacks=[call_evaluator()])

代码将成功访问正则化函数中的更新set_penalty(self, sym_penalty)并设置sym_penalty变量。不幸的是,该值不会在函数的 _ call _ 部分中更新。

该代码基于以下来源:

带有 activity_regularizer 的 Keras,每次迭代都会更新

在训练/动态正则化期间更改 keras 正则化器

我无法找到错误,并且无法完全理解代码本身。

标签: pythontensorflowkerascallbackdynamic-variables

解决方案


推荐阅读