python - 理解RNN的softmax输出层
问题描述
这是 Keras 中的一个简单的 LSTM 模型:
input = Input(shape=(max_len,))
model = Embedding(input_dim=input_dim, output_dim=embed_dim, input_length=max_len)(input)
model = Dropout(0.1)(model)
model = Bidirectional(LSTM(units=blstm_dim, return_sequences=True, recurrent_dropout=0.1))(model)
out =Dense(label_dim, activation="softmax")(model)
这是我将其转换为 Pytorch 模型的尝试:
class RNN(nn.Module):
def __init__(self, input_dim, embed_dim, blstm_dim, label_dim):
super(RNN, self).__init__()
self.embed = nn.Embedding(input_dim, embed_dim)
self.blstm = nn.LSTM(embed_dim, blstm_dim, bidirectional=True, batch_first=True)
self.fc = nn.Linear(2*blstm_dim, label_dim)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), blstm_dim).to(device)
c0 = torch.zeros(2, x.size(0), blstm_dim).to(device)
x = self.embed(x)
x = F.dropout(x, p=0.1, training=self.training)
x,_ = self.blstm(x, (h0, c0))
x = self.fc(x)
return F.softmax(x, dim=1)
# return x
现在运行 Keras 模型会得到以下结果:
Epoch 5/5
38846/38846 [==============================] - 87s 2ms/step - loss: 0.0374 - acc: 0.9889 - val_loss: 0.0473 - val_acc: 0.9859
但是运行 PyTorch 模型会给出:
Train Epoch: 10/10 [6400/34532 (19%)] Loss: 2.788933
Train Epoch: 10/10 [12800/34532 (37%)] Loss: 2.788880
Train Epoch: 10/10 [19200/34532 (56%)] Loss: 2.785547
Train Epoch: 10/10 [25600/34532 (74%)] Loss: 2.796180
Train Epoch: 10/10 [32000/34532 (93%)] Loss: 2.790446
Validation: Average loss: 0.0437, Accuracy: 308281/431600 (71%)
我确保损失和优化器是相同的(交叉熵和 RMSprop)。现在有趣的是,如果我从 PyTorch 模型中删除 softmax(即在代码中使用散列输出,我会得到似乎正确的结果:
Train Epoch: 10/10 [32000/34532 (93%)] Loss: 0.022118
Validation: Average loss: 0.0009, Accuracy: 424974/431600 (98%)
所以这是我的问题:
1)我在上面打印的两个模型是否等效(让我们忽略recurrent_dropout,因为我还没有弄清楚如何在PyTorch中做到这一点)?
2)我在 PyTorch 中的 softmax 输出层做错了什么?
非常感谢!
解决方案
- 我在上面打印的两个模型是否等效(让我们忽略recurrent_dropout,因为我还没有弄清楚如何在PyTorch 中做到这一点)?
除了辍学,我看不出有什么区别。所以它们在结构上应该是完全等价的。
注意:如果您以这种方式使用它,则不必初始化状态(如果您不重用状态)。您可以只转发 LSTM x,_ = self.blstm(x)
- 它会自动用零初始化状态。
- 我在 PyTorch 中的 softmax 输出层做错了什么?
PyTorchtorch.nn.CrossEntropyLoss
已经包含了 softmax:
该标准将
nn.LogSoftmax()
和结合nn.NLLLoss()
在一个类中。
所以它实际上是一个带有logits的CE。我想这使它更有效率。所以你可以在最后省略 softmax 激活。
推荐阅读
- java - dispatchKeyEvent 不会在 Unity 中触发 - 本机 Android 插件
- javascript - 尝试使用handlebars-express将字符串变量传递给html内的javascript范围
- amazon-ec2 - 如何防止 EC2 实例在重启时更改公网 IP 地址?
- typescript - 一般为多种 SVG 元素类型键入 d3 选择
- sql - SQL 查询未按预期执行:Visual Studio
- javascript - 在标题中包含 Bootstrap 4 时,鼠标事件不起作用
- database - 确保应用程序的早期版本在更改数据库设计时不会崩溃的最佳方法是什么?
- javascript - javascript中文件的校验和
- sql - 如何将多行插入特定的 postmeta
- alloy - 排序谓词不可满足