首页 > 解决方案 > 自定义指标访问 X 输入数据

问题描述

我想为拼写校正模型编写一个自定义指标,该模型计算正确替换的以前不正确的字母。它应该被错误地计算为以前正确的替换字母。

这就是为什么我需要访问 x_input 数据。不幸的是,默认情况下只能访问 y_true 和 y_pred。是否有解决方法来获得匹配的 x_input?

是:

def custom_metric(y_true, y_pred):

通缉:

def custom_metric(x_input, y_true, y_pred):

标签: tensorflowkerasmetrics

解决方案


def custom_loss(x_input):
    def loss_fn(y_true, y_pred):
        # Use your x_input here directly
        return #Your loss value
    return loss_fn

model = # Define your model
model.compile(loss=custom_loss(x_input))   
# Values of y_true and y_pred will be passed implicitly by Keras

请记住,x_input在训练模型时,所有批次的输入都将具有相同的值。

编辑

由于您x_input需要每批的数据在损失函数期间进行估计,并且您拥有自己的自定义损失函数,为什么不传递as 标签。像这样的东西:x_input

model.fit(x=x_input, y=x_input)
model.compile(loss=custom_loss())

def custom_loss(y_true, y_pred):
  # y_true corresponds to x_input data

如果您需要 x_input 并且需要传递一些其他数据,您可以这样做:

model.fit(x=x_input, y=[x_input, other_data])
model.compile(loss=custom_loss())

你现在只需要解耦数据y_true


推荐阅读