首页 > 解决方案 > Tensorflow MeanSquaredError 不适用于一个数字

问题描述

我正在尝试使用张量流均方误差计算网络的损失,但由于某种原因,如果输入张量只有一个数字,它就不起作用。应该怎么做呢。

这是一些代码:

import tensorflow as tf

loss = tf.keras.losses.MeanSquaredError()

a = loss(y_true=tf.constant([1.0, 2.0, 3.0]), y_pred=tf.constant([2.0, 2.0, 4.0]))
print(a)

a = loss(y_true=tf.constant(1.0, dtype=tf.float32), y_pred=tf.constant(2.0, dtype=tf.float32)) #this is where the error occurs.
print(a)

错误

tensorflow.python.framework.errors_impl.InvalidArgumentError:无效的缩减维度(-1 用于 0 维度的输入 [Op:Mean]

标签: pythontensorflowloss-functionmean-square-error

解决方案


推荐阅读