Passing arguments to model.predict in tf.keras.Model

Problem description

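I have a model that needs custom inference, so I overrode the predict_step method of my tf.keras.Model subclass. I would like the inference to change depending on some parameters. Is there a simple way to make the predict method accept arguments and pass them on to predict_step?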

Something like:

import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.threshold = None

    def call(self, inputs, training=None, mask=None):
        return inputs

    def predict(self, x, threshold=0.5, *args, **kwargs):
        self.threshold = threshold
        return super().predict(x, *args, **kwargs)

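    def predict_step(self, data):
        # Keras traces this into a graph on the first predict() call, so the
        # Python float in self.threshold is frozen at that value.
        return tf.greater(self(data, training=False), self.threshold)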


if __name__ == "__main__":
    x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
    model = SimpleModel()
    model.predict(x, threshold=0.5)
    model.predict(x, threshold=0.75)

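The problem with this approach is that the threshold never changes after the first call: predict_step has already been traced into a graph (the cached predict function), and the first threshold value was baked into it as a constant.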
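The same behavior can be reproduced without Keras. The minimal sketch below assumes TensorFlow 2.x; `Config` and `step` are hypothetical names standing in for the model and its predict_step. A tf.function reads a plain Python attribute: the value is captured when the function is first traced, and later changes are ignored because calls with the same input signature reuse the cached graph.

```python
import tensorflow as tf


class Config:
    # Hypothetical holder standing in for the model's `self.threshold`.
    threshold = 0.5


config = Config()


@tf.function
def step(x):
    # `config.threshold` is a plain Python float, so it is read once while
    # tracing and frozen into the graph as a constant.
    return tf.greater(x, config.threshold)


x = tf.constant([0.0, 0.55, 0.85, 0.9])
first = step(x).numpy()   # traces the graph with threshold 0.5
config.threshold = 0.75
second = step(x).numpy()  # same input signature: cached graph, still 0.5
print(first, second, sep="\n")
```

Both rows come out identical, which is the symptom described in the question.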

Update 1:

This seems to work, but I'm not sure whether it is the best approach:

import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.threshold = None

    def call(self, inputs, training=None, mask=None):
        return inputs

    def predict(self, x, threshold=0.5, *args, **kwargs):
        self.threshold = threshold
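        # Discard the cached predict function so Keras retraces predict_step
        # with the new threshold (one retrace per call).
        self.predict_function = None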
        return super().predict(x, *args, **kwargs)

    def predict_step(self, data):
        return tf.greater(self(data, training=False), self.threshold)


if __name__ == "__main__":
    x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
    model = SimpleModel()
    pred = model(x)
    pred_1 = model.predict(x, threshold=0.5)
    pred_2 = model.predict(x, threshold=0.75)
    print(pred, pred_1, pred_2, sep="\n")

Update 2: Following up on the question I posted here about predict_step running in graph mode, it seems another way to solve the problem is to set the model's self.run_eagerly = True.

import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Run predict_step eagerly so self.threshold is re-read on every call.
        self.run_eagerly = True
        self.threshold = None

    def call(self, inputs, training=None, mask=None):
        return inputs

    def predict(self, x, threshold=0.5, *args, **kwargs):
        self.threshold = threshold
        return super().predict(x, *args, **kwargs)

    def predict_step(self, data):
        return tf.greater(self(data, training=False), self.threshold)


if __name__ == "__main__":
    x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
    model = SimpleModel()
    pred_1 = model.predict(x, threshold=0.5)
    pred_2 = model.predict(x, threshold=0.75)
    print(pred_1, pred_2, sep="\n")

It now works without a tf.Variable (although it will probably run slower because of eager mode).
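For comparison, here is a sketch of the tf.Variable route, assuming TF 2.x with Keras 2 (tf.keras before Keras 3). The threshold lives in a non-trainable tf.Variable; the traced graph holds a reference to the variable rather than a copy of its value, so the `assign()` in predict is picked up on the next call without retracing and without eager mode. The `tf.keras.utils.unpack_x_y_sample_weight` call is an addition not in the question, used so predict_step also copes with data arriving as an `(x,)` tuple.

```python
import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Non-trainable variable: the traced graph keeps a reference to it and
        # reads its current value on every call.
        self.threshold = tf.Variable(0.5, trainable=False, dtype=tf.float32)

    def call(self, inputs, training=None, mask=None):
        return inputs

    def predict(self, x, threshold=0.5, *args, **kwargs):
        # Update the value in place; no retracing and no eager mode needed.
        self.threshold.assign(threshold)
        return super().predict(x, *args, **kwargs)

    def predict_step(self, data):
        # Added here: accept either a bare tensor or an (x,) tuple.
        x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        return tf.greater(self(x, training=False), self.threshold)


x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
model = SimpleModel()
pred_1 = model.predict(x, threshold=0.5)
pred_2 = model.predict(x, threshold=0.75)
print(pred_1, pred_2, sep="\n")
```

The graph is traced once and stays in graph mode, so this should avoid both the per-call retrace of Update 1 and the eager-mode slowdown of Update 2.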

Tags: tensorflow, machine-learning, keras, tensorflow2.0, tf.keras

Solution


I think I have a better idea of what you are after now. Take a look at this toy example and see whether it's what you want.

import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None, mask=None):
        return inputs

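    def custom_predict(func):
        # Plain function used as a decorator while the class body executes.
        def threshold_handler(self, x, threshold=None, *args, **kwargs):
            vals = func(self, x, *args, **kwargs)
            if threshold is None:
                return vals
            # Applied in Python after predict() returns, so it never gets
            # baked into the traced predict_step.
            return list(filter(lambda v: v > threshold, vals))
        return threshold_handler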
    
    # Equivalent to predict = custom_predict(predict): the name `predict`
    # now refers to threshold_handler, which wraps the original method.
    @custom_predict
    def predict(self, x, *args, **kwargs):
        return super().predict(x, *args, **kwargs)

x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
model = SimpleModel()
pred = model(x)
pred_0 = model.predict(x, steps=1)
pred_1 = model.predict(x, threshold=0.5, steps=1)
pred_2 = model.predict(x, threshold=0.75, steps=1)
print(pred, pred_0, pred_1, pred_2, sep="\n")

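Of course, the decorator is total overkill when you could simply handle the logic in your own predict function, but perhaps the higher-level idea will get you thinking about how you want to handle this. It sidesteps the original problem because the threshold is applied in Python after super().predict() returns, so it never becomes part of the traced predict_step. Another option for customizability is to use callbacks (see, for example, fastai or PyTorch Lightning).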
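A sketch of the non-decorator version, handling the logic directly in an overridden predict (assuming TF 2.x). Unlike the decorator example, it returns a boolean mask like the question's predict_step rather than a filtered list; that is a choice made here for comparability, not something from the answer.

```python
import numpy as np
import tensorflow as tf


class SimpleModel(tf.keras.Model):
    def call(self, inputs, training=None, mask=None):
        return inputs

    def predict(self, x, threshold=None, *args, **kwargs):
        # Run the stock (traced) Keras prediction loop; `threshold` never
        # enters the graph, so nothing stale can be cached.
        vals = np.asarray(super().predict(x, *args, **kwargs))
        if threshold is None:
            return vals
        # Post-process in NumPy: a boolean mask, as in the question.
        return vals > threshold


x = tf.convert_to_tensor([0.0, 0.55, 0.85, 0.9])
model = SimpleModel()
pred_1 = model.predict(x, threshold=0.5)
pred_2 = model.predict(x, threshold=0.75)
print(pred_1, pred_2, sep="\n")
```

The trade-off is that the thresholding happens on the host after all batches are collected, which is fine for a cheap comparison but would not suit post-processing that needs to run inside the graph.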

