python - 使用 sparse_categorical_crossentropy 时如何修复形状不匹配错误
问题描述
我想知道为什么在输入两个具有相同形状和相同内容的张量时,我在调用 sparse_categorical_crossentropy 损失函数时遇到了错误。我有以下代码:
labels = np.array([0,1,2])
logits = np.array([0,1,2])
tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
我收到以下错误:
/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, axis)
976 def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
977 return K.sparse_categorical_crossentropy(
--> 978 y_true, y_pred, from_logits=from_logits, axis=axis)
979
980
/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/backend.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
4574 else:
4575 res = nn.sparse_softmax_cross_entropy_with_logits_v2(
-> 4576 labels=target, logits=output)
4577
4578 if update_shape and output_rank >= 3:
/tensorflow-2.1.0/python3.6/tensorflow_core/python/ops/nn_ops.py in sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name)
3535 """
3536 return sparse_softmax_cross_entropy_with_logits(
-> 3537 labels=labels, logits=logits, name=name)
3538
3539
/tensorflow-2.1.0/python3.6/tensorflow_core/python/ops/nn_ops.py in sparse_softmax_cross_entropy_with_logits(_sentinel, labels, logits, name)
3451 "should equal the shape of logits except for the last "
3452 "dimension (received %s)." % (labels_static_shape,
-> 3453 logits.get_shape()))
3454 # Check if no reshapes are required.
3455 if logits.get_shape().ndims == 2:
ValueError: Shape mismatch: The shape of labels (received (3,)) should equal the shape of logits except for the last dimension (received (1, 3)).
解决方案
labels = tf.convert_to_tensor(np.asarray([0,1,2]))
logits = tf.convert_to_tensor(np.asarray([[1,0,0],[0,1,0],[0,0,1]]), dtype=tf.float32)
tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([2.3841855e-07, 2.3841855e-07, 2.3841855e-07], dtype=float32)>
这sparse
意味着您提供的整数labels
实际上代表类标签,这意味着sparse_categorical_crossentropy_loss
会将它们转换为一种热编码(如 logits)。不知道from_logits
代表什么,对不起。
注意:您必须这样做tf.convert_to_tensor
是为了避免tensorflow
.
推荐阅读
- java - 我可以在 JPA 实体构造函数中使用 use 'System.IdentityHashCode()' 吗?
- spring-boot - Livereload 不会在 chrome 上为 spring 重新加载静态文件
- kubernetes - ClusterIP 是否能够负载平衡多个 pod 的流量?
- php - 内部服务器从 ajax 发送导致的错误
- amazon-web-services - Lamba 执行角色 - 调用 putObject 时拒绝访问
- vb.net - Clipboard.GetText() 函数不适用于组框内的文本框控件
- mysql - MySQL 在 Raspberry Pi 3 的清单条目中没有与 linux/arm/v7 匹配的清单
- python - 如何使用python从驱动器的文件夹中检索文件?
- java - java.util.Calendar getMaximum 最后一天返回错误
- node.js - 安装 npm 包时的 UNABLE_TO_VERIFY_LEAF_SIGNATURE