python - 多类别情况下最后一个密集输出层的单位
问题描述
我目前正在研究这个colab。任务是将句子分类到某个类别。所以我们有一个多类别问题,而不是二元问题,比如根据某些评论句子预测评论的情绪(正面/负面)。如果有多个类别,我认为最后一层中的单元/神经元数量必须与我想要预测的类别数量相匹配。所以当我有一个二元问题时,我使用一个神经元,表示 0 或 1。当我有 5 个类时,我需要 5 个单元。我也那么认为。
但是,在colab的代码中有以下内容:
model = tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=max_length),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(24, activation='relu'),
tf.keras.layers.Dense(6, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
model.summary()
当我在这个 colab 中运行 model.fit 代码部分时,它确实有效。但我不明白。当我检查
print(label_tokenizer.word_index)
print(label_tokenizer.word_docs)
print(label_tokenizer.word_counts)
这给
{'sport': 1, 'business': 2, 'politics': 3, 'tech': 4, 'entertainment': 5}
defaultdict(<class 'int'>, {'tech': 401, 'business': 510, 'sport': 511, 'entertainment': 386, 'politics': 417})
OrderedDict([('tech', 401), ('business', 510), ('sport', 511), ('entertainment', 386), ('politics', 417)])
很明显是5个班。但是,当我将模型调整为tf.keras.layers.Dense(5, activation='softmax')
并运行 model.fit 命令时,它不起作用。准确率始终为 0。
为什么这里是 6 而不是 5?
解决方案
是 6,因为编码目标在 [1,5] 中,但 keras sparse_cat 从 0 创建单热标签,因此它创建另一个无用标签 (0)。
要使用Dense(5, activation='softmax')
,您只需执行 y-1 即可在 [0,4] 中获取标签并从 0 开始获取它们
按照 colab 链接,您可以更改:
model = tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=max_length),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(24, activation='relu'),
tf.keras.layers.Dense(5, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit(train_padded, training_label_seq-1, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq-1), verbose=2)
推荐阅读
- vba - 如何计算访问vba中的页面加载时间
- javascript - 如何使用javascript在txt文件的新行中打印对象数组的每个对象?
- python - Python 代理如何使用 ElasticAPM 跟踪各种计数器/值随时间的演变?
- android - 为什么“C”!=“C”?
- java - 为什么我的 JavaFX 按钮有奇怪的符号而不是文本?
- python - 使用beautifulsoup4,Python在html标签内查找链接
- python - Flask-WTF 自定义验证 FileField 关于图片像素大小
- markdown - 如何更改 Hugo 网站中头像持有者的样式?
- angular - 角 | 是否有可能导致 *ngIf 只运行一次?
- javascript - 在javascript中访问同一对象中的对象值