python - 收到的标签值 17 超出 [0, 12) 的有效范围 - Keras Python
问题描述
Keras 对我来说很陌生。我正在尝试将一些编程付诸实践。
数据形状如下:
Train shape X: (249951, 5, 52) y (249951,)
Test shape X: (263343, 5, 52) y (263343,) # Do not confuse with the distribution, it is juts toy example
我的日期包含十二个标签。keras CNN架构如下:
def get_compiled_model():
# Make a simple 2-layer densely-connected neural network.
inputs = keras.Input(shape=(260,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(12)(x) # 12 classes
model = keras.Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return model
现在,我将 12 个神经元馈送到输出层,因为我的数据包含 12 个类。但是,会显示以下错误消息:
Use `tf.data.Iterator.get_next_as_optional()` instead.
2255/7811 [=======>......................] - ETA: 18s - loss: 0.1109 - sparse_categorical_accuracy: 0.99452021-04-19 16:32:33.591493: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at sparse_xent_op.cc:90 : Invalid argument: Received a label value of 17 which is outside the valid range of [0, 12). Label values: 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 17 17 17 17 17
Traceback (most recent call last):
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\contextlib.py", line 131, in __exit__
self.gen.throw(type, value, traceback)
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2804, in variable_creator_scope
yield
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 807, in _call
return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 2829, in __call__
return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 1843, in _filtered_call
return self._call_flat(
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 1923, in _call_flat
return self._build_call_outputs(self._inference_function.call(
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\function.py", line 545, in call
outputs = execute.execute(
File "C:\Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of 17 which is outside the valid range of [0, 12). Label values: 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 17 17 17 17 17
[[node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at Users\Nafees Ahmed\AppData\Local\Programs\Python\Python38\lib\threading.py:932) ]] [Op:__inference_train_function_846]
Function call stack:
train_function
主要错误: tensorflow.python.framework.errors_impl.InvalidArgumentError:收到的标签值 17 超出了 [0, 12) 的有效范围。标签值: 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 17 17 17 17 17
解决方案
如果您的标签(y)是数字,它将不起作用。您需要将其转换为二进制数据,因为它是多类二进制问题。也许你可以使用下面的来做到这一点。
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)
num_classes
是 12,在你的情况下。但是,对我来说,12 节课听起来太多了。这真的是多类而不是多标签吗?
推荐阅读
- javascript - 将音频文件保存到 mongodb
- php - 使用php计算折旧
- mongoose-schema - 如何将新的博客文章保存给使用猫鼬护照的用户?
- excel - 如何在Excel中的不同字符之间进行子串?
- java - jsp中如何获取特定属性的下一个索引值
- wordpress - Woocommerce 变量不再可见
- r - 为什么我不能阅读使用 rvest 进行网页抓取的可点击链接?
- node.js - Shop pay 与 Vue 前端的集成
- api - 无法获取数据 API Fluter
- java - 日志库以抽象 systemd 日志、Windows 事件日志、macOS 日志(控制台)?