python - 在 keras 中创建自定义损失函数,合并数据集中的特征
问题描述
我想为 Keras 深度学习回归模型创建一个自定义损失函数。对于自定义损失函数,我想使用数据集中的一个特征,但我没有使用该特定特征作为模型的输入。
我的数据如下所示:
X | Y | feature
---|-----|--------
x1 | y1 | f1
x2 | y2 | f2
模型的输入是 X,我想使用模型预测 Y。我想要类似以下的东西作为损失函数:
def custom_loss(feature):
def loss(y_true, y_pred):
root_mean__square(y_true - y_pred) + std(y_pred - feature)
return loss
我不能像上面那样使用包装函数,因为特征值取决于训练和测试批次,因此不能在模型编译时传递给自定义损失函数。如何使用数据集中的附加功能来创建自定义损失函数?
编辑:
我根据这个线程上的答案做了以下事情。当我使用此模型进行预测时,它是对“Y”还是 Y 和附加特征的组合进行预测?我想确定,因为 model.fit() 将 'Y' 和 'feature' 作为 y 来训练,但 model.predict() 只给出一个输出。如果预测是 Y 和附加特征的组合,我怎样才能只提取 Y?
def custom_loss(data, y_pred):
y_true = data[:, 0]
feature = data[:, 1]
return K.mean(K.square((y_pred - y_true) + K.std(y__pred - feature)))
def create_model():
# create model
model = Sequential()
model.add(Dense(5, input_dim=1, activation="relu"))
model.add(Dense(1, activation="linear"))
(train, test) = train_test_split(df, test_size=0.3, random_state=42)
model = models.create_model(train["X"].shape[1])
opt = Adam(learning_rate=1e-2, decay=1e-3/200)
model.compile(loss=custom_loss, optimizer=opt)
model.fit(train["X"], train[["Y", "feature"]], validation_data=(test["X"], test[["Y", "feature"]]), batch_size = 8, epochs=90)
predY = model.predict(test["X"]) # what does the model predict here?
解决方案
首先在 fit 函数中检查输入 Y 的数据结构,看看它是否与您关注的线程中的答案具有相同的结构,如果您做的事情完全正确,那么它应该可以解决您的问题。
当我使用此模型进行预测时,它是对“Y”还是 Y 和附加特征的组合进行预测?
该模型将具有与您定义的完全相同的输出形状,在您的情况下,因为模型输出是Dense(1, activation="linear")
,所以它具有输出形状y_pred.shape == (batchsize, 1)
,仅此而已,您可以确定这一点,将其打印出来tf.print(y_pred)
以供您自己查看
我也不知道是不是你的打字错误,你的 custom_loss 函数的最后一行应该是:
return K.mean(K.square((y_pred - y_true) + K.std(y_pred - feature)))
代替
return K.mean(K.square((y_pred - y_true) + K.std(y__pred - feature)))
推荐阅读
- jenkins-pipeline - 使用jenkins在kubernetes集群上部署helm图表
- javascript - Angular Datatable:如何在悬停时获取列数据的工具提示
- python - 如何在使用 Beautiful Soup 和 Requests 按下按钮后获取 HTML 更改
- android - 在kotlin中以有效的方式查找第一个出现的索引值
- c++ - 通过基类引用访问派生类成员
- haproxy - 如何在 haproxy 中捕获转换后的后端 uri?
- php - 如何在 Symfony FormType 中为字段显示不同的值
- macos - 在 Mac 上安装 maven
- debugging - 调试 Apache2 RewriteRule(带有标题)?
- javascript - Next-Auth 管理员权限