python - LSTM 训练期间的持续损失 - PyTorch
问题描述
我正在尝试实现一个 LSTM 网络来预测句子中的下一个单词。这是我第一次构建神经网络,我对我在互联网上找到的所有信息感到困惑。
我正在尝试使用以下架构:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class WordLSTM(nn.Module):
def __init__(self, vocabulary_size, embedding_dim, hidden_dim):
super().__init__()
# Word embeddings
self.encoder = nn.Embedding(vocabulary_size, embedding_dim)
# LSTM input dim is embedding_dim, output dim is hidden_dim
self.lstm = nn.LSTM(embedding_dim, hidden_dim)
# Linear layer to map hidden states to vocabulary space
self.decoder = nn.Linear(hidden_dim, vocabulary_size)
def forward(self, sentence):
encoded = self.encoder(sentence)
output, _ = self.lstm(
encoded.view(len(sentence), 1, -1))
decoded = self.decoder(output)
word_scores = F.softmax(decoded, dim=1)
return word_scores[-1].view(1, -1)
我用我的数据集中的所有句子创建了一个字典,每个单词都用字典中的索引进行编码。它们后面是编码的下一个单词(目标向量)。这是我正在尝试使用的一堆训练示例:
[tensor([39]), tensor([13698])],
[tensor([ 39, 13698]), tensor([11907])],
[tensor([ 39, 13698, 11907]), tensor([70])]
我在训练期间一次通过一个句子,所以我的批量大小始终为 1。
NUM_EPOCHS = 100
LEARNING_RATE = 0.0005
rnn = WordLSTM(vocab_size, 64, 32)
optimizer = optim.SGD(rnn.parameters(), lr=LEARNING_RATE)
for epoch in range(NUM_EPOCHS):
training_example = generate_random_training_example(training_ds)
optimizer.zero_grad()
for sentence, next_word in training_example:
output = rnn(sentence)
loss = F.cross_entropy(output, next_word)
loss.backward()
optimizer.step()
print(f"Epoch: {epoch}/{NUM_EPOCHS} Loss: {loss:.4f}")
但是,当我开始训练时,损失不会随时间变化:
Epoch: 0/100 Loss: 10.3929
Epoch: 1/100 Loss: 10.3929
Epoch: 2/100 Loss: 10.3929
Epoch: 3/100 Loss: 10.3929
Epoch: 4/100 Loss: 10.3929
Epoch: 5/100 Loss: 10.3929
Epoch: 6/100 Loss: 10.3929
我已经尝试将optimizer.zero_grad()
andoptimizer.step()
放在不同的地方,但也没有帮助。
在这种情况下可能是什么问题?我是以错误的方式计算损失,还是以错误的格式传递张量?
解决方案
删除F.softmax
。你做 log_softmax(softmax(x))。
该标准将 nn.LogSoftmax() 和 nn.NLLLoss() 组合在一个类中。
import torch as t
class Net(t.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.emb = t.nn.Embedding(100, 8)
self.lstm = t.nn.LSTM(8, 16, batch_first=True)
self.linear = t.nn.Linear(16, 100)
def forward(self, x):
x = self.emb(x)
x, _ = self.lstm(x)
x = self.linear(x[:, -1])
#x = t.nn.Softmax(dim=1)(x)
return x
t.manual_seed(0)
net = Net()
batch_size = 1
X = t.LongTensor(batch_size, 5).random_(0, 100)
Y = t.LongTensor(batch_size).random_(0, 100)
optimizer = t.optim.Adam(net.parameters())
criterion = t.nn.CrossEntropyLoss()
for epoch in range(10):
optimizer.zero_grad()
output = net(X)
loss = criterion(output, Y)
loss.backward()
optimizer.step()
print(loss.item())
4.401515960693359 4.389760494232178 4.377873420715332 4.365848541259766 4.353675365447998 4.341339588165283 4.328824520111084 4.316114902496338 4.303196430206299 4.2900567054748535
未注释t.nn.Softmax
:
4.602912902832031 4.6027679443359375 4.602619171142578 4.6024675369262695 4.602311611175537 4.602152347564697 4.601987361907959 4.601818084716797 4.6016435623168945 4.601463794708252
在评估期间使用 softmax:
net.eval()
t.nn.Softmax(dim=1)(net(X[0].view(1,-1)))
张量([[0.0088, 0.0121, 0.0098, 0.0072, 0.0085, 0.0083, 0.0083, 0.0108, 0.0127, 0.0090、0.0094、0.0082、0.0099、0.0115、0.0094、0.0107、0.0081、0.0096、 0.0087、0.0131、0.0129、0.0127、0.0118、0.0107、0.0087、0.0073、0.0114、 0.0076、0.0103、0.0112、0.0104、0.0077、0.0116、0.0091、0.0091、0.0104、 0.0106、0.0094、0.0116、0.0091、0.0117、0.0118、0.0106、0.0113、0.0083、 0.0091, 0.0076, 0.0089, 0.0076, 0.0120, 0.0107, 0.0139, 0.0097, 0.0124, 0.0096、0.0097、0.0104、0.0128、0.0084、0.0119、0.0096、0.0100、0.0073、 0.0099, 0.0086, 0.0090, 0.0089, 0.0098, 0.0102, 0.0086, 0.0115, 0.0110, 0.0078、0.0097、0.0115、0.0102、0.0103、0.0107、0.0095、0.0083、0.0090、 0.0120, 0.0085, 0.0113, 0.0128, 0.0074, 0.0096, 0.0123, 0.0106, 0.0105, 0.0101、0.0112、0.0086、0.0105、0.0121、0.0103、0.0075、0.0098、0.0082、 0.0093]],grad_fn=)
推荐阅读
- sql - 列出在超过 2 名员工的团队中从平均收入中获得平均值(准确率高达 30%)的员工(姓名、base_salary)
- go - 为什么我不能通过 *interface{} 参数传递结构指针?
- botframework - 什么是 botframework 安全模型?
- amazon-web-services - 子域的 AWS 证书通配符
- elasticsearch - 使用 JestClient 用 java 编写的 lambda 函数对 AWS Elasticsearch 的第一次查询的响应非常慢
- javascript - document.addEventListener("click") 无法正常工作
- nginx - EC2 上的 NGINX 未启用加载站点
- javascript - 自定义 Web 可访问性错误消息,当在提交按钮上使用键盘按下 Enter 并且输入字段为空白时
- python - 在熊猫数据框中按条件分组
- windows - 使用用户输入创建文件夹结构的批处理脚本