python - 打印损失函数 tensorflow 2.0 的所有项
问题描述
我正在定义一个自定义损失函数。例如,让我们以loss function = L1 loss + L2 loss.
When I domodel.fit_generator()
为例,在每批之后打印整体损失函数。但我想查看 和 的各个L1 loss
值L2 loss
。我怎样才能做到这一点?我想知道个别术语的价值以了解它们的相对比例。
tf.print(l1_loss, output_stream=sys.stdout)
正在抛出异常说tensorflow.python.eager.core._FallbackException: This function does not handle the case of the path where all inputs are not already EagerTensors.
。甚至
tf.print('---')
只是---
在开始时打印,而不是每批都打印。tf.keras.backend.print_tensor(l1_loss)
没有打印任何东西
解决方案
没有看到你的代码,我只能猜测你没有用@tf.function
装饰器装饰你的自定义损失函数。
import numpy as np
import tensorflow as tf
@tf.function # <-- Be sure to use this decorator.
def custom_loss(y_true, y_pred):
loss = tf.reduce_mean(tf.math.abs(y_pred - y_true))
tf.print(loss) # <-- Use tf.print(), instead of print(). You can print not just 'loss', but any TF tensor in this function using this approach.
return loss
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_shape=[8]))
model.compile(loss=custom_loss, optimizer="sgd")
x_data = tf.data.Dataset.from_tensor_slices([np.ones(8)] * 100)
y_data = tf.data.Dataset.from_tensor_slices([np.ones(1)] * 100)
data = tf.data.Dataset.zip((x_data, y_data)).batch(2)
model.fit_generator(data, steps_per_epoch=10, epochs=2)
输出如下所示,它告诉您逐批的损失值。
Epoch 1/2
0.415590227 1/10 [==>...........................] - ETA: 0s - loss: 0.41560.325590253
0.235590339
0.145590425
0.0555904508
0.034409523
0.0555904508
0.034409523
0.0555904508
0.034409523 10/10 [==============================] - 0s 11ms/step - loss: 0.1392 Epoch 2/2
0.0555904508 1/10 [==>...........................] - ETA: 0s - loss: 0.05560.034409523
0.0555904508
0.034409523
0.0555904508
0.034409523
0.0555904508
0.034409523
0.0555904508
0.034409523 10/10 [==============================] - 0s 498us/step - loss: 0.0450
推荐阅读
- node.js - Autodesk Forge - 将文件作为块上传到 Node JS 中的 BIM 360 存储时出现 504 网关超时
- c++ - (重新)在 boost 1.7x boost_system 库中导出 boost::system::generic_category() ?
- javascript - 未捕获的 SyntaxError:输入 js 的意外结束
- sed - 如何替换以文件中某些内容开头的行的第二次出现
- typescript - 如何使用 storybook 和 lerna 在 NPM 上创建和发布多个 Vuejs 组件?
- angular - 模板驱动的表单不起作用..“错误:未找到名称 'ngForm' 的导出”
- kivy - 无法从 Kivy 中的 io.BytesIO 加载图像
- css - 如何在css中使用flex进行溢出自动
- postgresql - 字段的长度对于正则表达式查找是否重要?索引是否有助于加快正则表达式查询?
- ruby-on-rails - 无法在 bitbucket 上使用 ssh 进行连接