python - 损失函数提供 NAN 梯度
问题描述
我使用直方图损失作为模型的损失函数,但它提供了 NAN 梯度。代码片段(损失函数):
def histogram_loss(y_true, y_pred):
h_true = tf.histogram_fixed_width( y_true, value_range=(-1., 1.), nbins=20)
h_pred = tf.histogram_fixed_width( y_pred, value_range=(-1., 1.), nbins=20)
h_true = tf.cast(h_true, dtype=tf.dtypes.float32)
h_pred = tf.cast(h_pred, dtype=tf.dtypes.float32)
return K.mean(K.square(h_true - h_pred))
错误信息:
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
为什么我得到值错误(NAN 梯度)?
解决方案
tf.histogram 的梯度是 None...不是微分函数
x = tf.Variable(np.random.uniform(0,10, 100), dtype=tf.float32)
with tf.GradientTape() as tape:
hist = tf.histogram_fixed_width(x, value_range=(-1., 1.), nbins=20)
hist = tf.cast(hist, dtype=tf.dtypes.float32)
grads = tape.gradient(hist, x)
grads
推荐阅读
- javascript - 如何允许多个业务用户通过 API 为 Angular 应用程序自动生成 Paypal Express 客户端 ID
- javascript - 将数组转换为 Observable 的类型
- php - 使用具有相同列名的 3 个表排序
- authentication - SPA-PWA 身份验证的最佳实践是什么?
- java - 有没有办法使用 Spring WebFlux 和 MongoDB 从手动参考中获取文档?
- reactjs - 处理功能组件中的过期状态
- java - 使用扫描仪从控制台读取数字的最佳方法是什么?
- flutter - 如何在颤动中显示工具栏操作批量计数?
- data-structures - 为什么有一长串没有任何底片的 PolygonVertexIndex(在一个 Fbx 文件中)?
- amazon-web-services - 在 Amazon S3 中启用版本控制的成本