首页 > 解决方案 > 通过回调 api 动态更新数据集过滤器中使用的 tensorflow 变量

问题描述

def myfilter(x, my_var):
    return tf.equal(x['vars'], my_var)


data = tf.data.TFRecordDataset(tf.io.match_filenames_once('part-*'))
my_var  = tf.Variable(1, trainable=False, name='my_var', dtype=tf.int64)

data = data.map(parsing_func, num_parallel_calls=multiprocessing.cpu_count() - 1)
data= data.filter(lambda x : myfilter(x, my_var) )
data = data.batch(batch_size=32)

在这里,使用 static my_var,我可以过滤数据。但是,我想继续从 [1, 2, .... n] 更新 var 值。关于在培训期间如何做到这一点的任何想法?

我正在尝试这样的事情:

class CustomVarScheduler(tf.keras.callbacks.Callback):


    def __init__(self, my_var):
        super(CustomPhaseScheduler, self).__init__()
        self.var = my_var

    def on_epoch_begin(self, epoch, logs=None):
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.my_var, tf.math.add(self.my_var, 1))
        print("\nEpoch %05d: my_var is %6.4f." % (epoch, self.my_var))

无法使其正常工作。有什么帮助吗?谢谢。

标签: tensorflowtensorflow2.0tensorflow-datasets

解决方案


推荐阅读