python - 使用 TensorFlow 的多类文本分类。输出问题
问题描述
我正在检查来自面向DataScience多类问题的 BBC 新闻代码,有 5 个标签和 6 个预测输出,我尝试给它提供娱乐新闻
txt = ["last star wars ..... the film is much more dark more emotional. it s much more of a tragedy."]
seq = tokenizer.texts_to_sequences(txt)
padded = pad_sequences(seq, maxlen=max_length)
pred = model.predict(padded)
labels = ['sport', 'business', 'politics', 'tech', 'entertainment']
print(pred, labels[np.argmax(pred)])
它返回
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-20-05bd09f68629> in <module>
4 pred = model.predict(padded)
5 labels = ['sport', 'business', 'politics', 'tech', 'entertainment']
----> 6 print(pred, labels[np.argmax(pred)])
IndexError: list index out of range
我只是在娱乐标签和输出之后添加了testlabel!
[[7.9564927e-07 7.9830606e-06 8.0213253e-04 3.8535093e-04 7.4164674e-02
9.2463899e-01]] test label
模型在这里
model = tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_dim,
input_length=max_length),
# specify the number of convolutions that you want to learn, their size, and their activation function.
# words will be grouped into the size of the filter in this case 5
tf.keras.layers.Conv1D(128, 5, activation='relu'),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(6, activation='softmax')
])
model.summary()
如果我将最后一个稠密更改为 5,它仍然会在训练中出现以下错误。
InvalidArgumentError: Received a label value of 5 which is outside the valid range of [0, 5). Label values: 3 2 1 5 4 4 2 5 5 3 1 1 3 1 4 1 5 3 2 5 3 3 4 3 5 2 3 2 4 5 3 5
[[node loss_1/dense_3_loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at <ipython-input-25-5196760ed1a6>:2) ]] [Op:__inference_keras_scratch_graph_6926]
Function call stack:
keras_scratch_graph
有人可以向我解释一下吗?这是为什么?,以及如何获得与标签数量相同的输出?
解决方案
推荐阅读
- python - Dockerized Python (Streamlit) 应用程序使用错误的 Python 库文件夹
- google-apps-script - Google Sheets Apps 脚本,从命名范围返回随机项目
- c# - Xamarin.Forms.Platform.iOS.FormsApplicationDelegate OpenUrl 中的访问被拒绝
- xamarin.forms - 使用自定义字体文件将文本转换为 xamarin 形式的图像
- python - 如何在 Python 中为导入的类动态添加属性?
- android - 单击与 Firebase 链接的 RecyclerView 中的记录时如何打开另一个活动
- python - Python JSON TypeError:列表索引必须是整数或切片,而不是 str
- python - 如何在python中设置日志记录级别?
- tensorflow - Keras - TensorFlow2 - TensorRT
- google-apps-script - 当范围有边框时,Google表格脚本onEdit错误范围