machine-learning - 了解 LSTM 架构中的密集层(标签和 logits)
问题描述
我正在研究这个笔记本——https: //github.com/aamini/introtodeeplearning/blob/master/lab1/solutions/Part2_Music_Generation_Solution.ipynb——我们在其中使用嵌入层、LSTM 和最终密集层 w/ softmax 生成音乐。
但是,我对我们如何计算损失有点困惑。据我了解,在这个笔记本中(在 compute_loss() 中),在任何给定的批次中,我们都将预期标签(即注释本身)与 logits(即来自密集层的预测)进行比较。但是,这些预测不应该是概率分布吗?我们什么时候真正选择我们预测的标签?
进一步澄清我的问题:如果我们的标签的形状是(batch_size,# of time step),而我们的 logits 的形状是(batch_size,# of time step,vocab_size),那么在compute_loss()中的哪个点函数我们实际上是在为每个时间步选择一个标签吗?
解决方案
简短的回答是 Keras 损失函数可以满足sparse_categorical_crossentropy()
您的所有需求。
在 LSTM 模型的每个时间步长,该损失函数中的顶层密集层和 softmax 函数共同生成模型词汇表的概率分布,在本例中是音符。假设词汇表包含音符 A、B、C、D。然后生成的一个可能的概率分布是:[0.01, 0.70, 0.28, 0.01]
,这意味着模型将大量概率放在音符 B(索引 1)上,如下所示:
Label: A B C D
---- ---- ---- ---- ----
Index: 0 1 2 3
---- ---- ---- ---- ----
Prob: 0.01 0.70 0.28 0.01
假设真正的音符应该是 C,它由数字 2 表示,因为它位于分布数组中的索引 2 处(索引从 0 开始)。要测量预测分布和真实值分布之间的差异,请使用该sparse_categorical_crossentropy()
函数生成表示损失的浮点数。
更多信息可以在这个 TensorFlow 文档页面上找到。在该页面上,他们有示例:
y_true = [1, 2]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
您可以在该示例中看到有一批两个实例。对于第一个实例,真实标签是1
,预测分布是[0.05, 0.95, 0]
,对于第二个实例,真实标签是2
,而预测分布是[0.1, 0.8, 0.1]
。
此函数在 2.5 节的 Jupyter Notebook 中使用:
为了在这个分类任务上训练我们的模型,我们可以使用一种形式的交叉熵损失(负对数似然损失)。具体来说,我们将使用 sparse_categorical_crossentropy 损失,因为它利用整数目标进行分类分类任务。我们将希望使用真实目标(标签)和预测目标(logits)来计算损失。
所以直接回答你的问题:
据我了解,在这个笔记本中(在 compute_loss() 中),在任何给定的批次中,我们都将预期标签(即注释本身)与 logits(即来自密集层的预测)进行比较。
是的,你的理解是正确的。
但是,这些预测不应该是概率分布吗?
是的,他们是。
我们什么时候真正选择我们预测的标签?
它在sparse_categorical_crossentropy()
函数内部完成。如果您的分布是[0.05, 0.95, 0]
,那么这隐含地意味着该函数预测索引 0 的概率为 0.05,索引 1 的概率为 0.95,索引 3 的概率为 0.0。
进一步澄清我的问题:如果我们的标签的形状是(batch_size,# of time step),而我们的 logits 的形状是(batch_size,# of time step,vocab_size),那么在compute_loss()中的哪个点函数我们实际上是在为每个时间步选择一个标签吗?
它在那个函数里面。
推荐阅读
- amazon-web-services - aws_shield_protection Terraform
- python - 删除N行文件
- mysql - 将字符串布尔值隐式转换为整数失败
- haskell - 使用模式匹配重载函数?
- mysql - 如何返回 GROUP_CONCAT 输出的 mysql JSON 数组?
- android - 是状态
仅针对 Android 中的 @Composable 设计? - sql - 如何在 SQL Server 中执行 SELECT * 以同时获取具有表名和字段的别名
- bash - 将设备挂载到存储在变量中的挂载点
- javascript - 向对象添加属性数组
- passport.js - 使用环回 4 的身份验证 SAML