python - 标签值超出有效范围
问题描述
X = tf.placeholder(shape=(1, 5, 7), name='inputs', dtype=tf.float32)
X_flat = tf.layers.flatten(X)
y = tf.placeholder(shape=(1), name='outputs', dtype=tf.int32)
hidden1 = tf.layers.dense(X_flat, 150, kernel_initializer=he_init)
hidden2 = tf.layers.dense(hidden1, 50, kernel_initializer=he_init)
logits = tf.layers.dense(hidden2, 1, kernel_initializer=he_init)
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
loss = tf.reduce_mean(xentropy, name="loss")
所以我收到以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of 1 which is outside the valid range of [0, 1). Label values: 1
我的标签的整数范围从 0 到 4。我很好奇为什么这不起作用。当我在示例代码中使用 MNIST 时,我认为 y 训练集不在 0 到 1 的范围内,但显然这就是这里发生的情况。
如何使交叉熵函数起作用?是否有任何形式的规范化可以使代码正常工作?
另外,为什么 MNIST 可以使用整数,但这个不能用于标签?
解决方案
编辑
可以肯定的是,这就是我变成的。
X = tf.placeholder(shape=(1, 5, 7), name='inputs', dtype=tf.float32)
X_flat = tf.layers.flatten(X)
y = tf.placeholder(shape=(1), name='outputs', dtype=tf.int32)
hidden1 = tf.layers.dense(X_flat, 150, kernel_initializer=he_init)
hidden2 = tf.layers.dense(hidden1, 50, kernel_initializer=he_init)
logits = tf.layers.dense(hidden2, 5, kernel_initializer=he_init)
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
loss = tf.reduce_mean(xentropy, name="loss")
这对我来说没有问题。
原来的
好的。因此,如果您的y
变量具有 range (0,4)
,那么您的 logit 需要具有形状(batch_size, 5)
(在您的情况下为(1,5)
),因为每个值都是您的模型对特定标签的置信度。
这个:
logits = tf.layers.dense(hidden2, 1, kernel_initializer=he_init)
需要是这样的:
logits = tf.layers.dense(hidden2, 5, kernel_initializer=he_init)
为了做得更好,您可能应该定义这些变量。
num_classes = 5
# ...
logits = tf.layers.dense(hidden2, num_classes, kernel_initializer=he_init)
# ...
推荐阅读
- wpf - 如何正确实现 XAML INotifyPropertyChanged 以防止 GDI 泄漏
- gitlab - Gitlab shell runner 无法上传大于 63KB 的工件:“501 未实现”
- python - 应用 minmaxscaler 后,列中的 Pandas 科学形式值给出错误的输出
- r - 如何通过第二个 col 中的最大值找到第一个 col 值(在矩阵中)
- python - 将文本字符串拆分为 int 和 text 变量
- sql-server - 在 Azure Functions 上使用 Microsoft 报表查看器
- javascript - 如何在 React 中将 connect() 和 withStyles() 用于类组件?
- php - 如何从 Codeigniter 中的返回结果运行多个查询
- scala - 如果我修改指向缓存 rdd 的变量会发生什么?
- laravel - Laravel 项目在共享主机上使用 bitbucket 管道