tensorflow - Keras 模型 - 在自定义损失函数中获取输入
问题描述
我在使用 Keras 自定义损失函数时遇到问题。我希望能够以numpy 数组的形式访问真相。因为它是一个回调函数,我想我不是在急切执行中,这意味着我无法使用 backend.get_value() 函数访问它。我也尝试了不同的方法,但总是回到这个“张量”对象不存在的事实。
我需要在自定义损失函数中创建一个会话吗?
我正在使用最新的 Tensorflow 2.2。
def custom_loss(y_true, y_pred):
# 4D array that has the label (0) and a multiplier input dependant
truth = backend.get_value(y_true)
loss = backend.square((y_pred - truth[:,:,0]) * truth[:,:,1])
loss = backend.mean(loss, axis=-1)
return loss
model.compile(loss=custom_loss, optimizer='Adam')
model.fit(X, np.stack(labels, X[:, 0], axis=3), batch_size = 16)
我希望能够接触到真相。它有两个组件(标签、乘数,每个项目都不同。我看到了一个依赖于输入的解决方案,但我不确定如何访问该值。基于输入数据的 Keras 中的自定义损失函数
解决方案
我认为您可以通过启用来做到这一点,run_eagerly=True
如下model.compile
所示。
model.compile(loss=custom_loss(weight_building, weight_space),optimizer=keras.optimizers.Adam(), metrics=['accuracy'],run_eagerly=True)
我认为您还需要更新custom_loss
,如下所示。
def custom_loss(weight_building, weight_space):
def loss(y_true, y_pred):
truth = backend.get_value(y_true)
error = backend.square((y_pred - y_true))
mse_error = backend.mean(error, axis=-1)
return mse_error
return loss
我用一个简单的mnist
数据来展示这个想法。请看这里的代码。
推荐阅读
- python - Python 导入 pyLDAvis 模块 - 弃用警告
- asp.net - ASP.NET Web 表单应用程序文件夹未部署到 AWS Elastic Beanstalk 应用程序
- c++ - 带有 g++ 的 SDL:stdint.h:没有这样的文件或目录
- python - Python 导入名称冲突
- cesium - Cesium:是否有可能使拖尾路径更细?
- gridview - DotVVM:在 GridViewTemplateColumn ContentTemplate 中使用自定义绑定
- java - 无法解析符号“之前” - Intellij jUnit
- python - 我怎样才能分别绘制高斯与 python 拟合?
- reactjs - React Redux,如何处理表单项的创建和重定向?
- sql - SQL Server:检查约束中的用户定义函数