python-3.x - PyTorch 0.4 LSTM:为什么每个 epoch 都会变慢?
问题描述
我在 GPU 上有一个 PyTorch 0.4 LSTM 的玩具模型。玩具问题的总体思路是,我将单个 3 向量定义为输入,并定义一个旋转矩阵 R。然后,地面实况目标是一系列向量:在 T0,输入向量;在 T1 处,输入向量经过 R 旋转;在 T2 处,输入旋转了 R 两次,等等。(输入在 T1 之后用零输入填充输出长度)
损失是地面实况和输出之间的平均 L2 差异。旋转矩阵、输入/输出数据的构造和损失函数可能并不重要,此处未显示。
没关系,结果非常糟糕:为什么随着每个经过的时期,它会变得越来越慢?!
我在下面展示了 GPU 上的信息,但这也发生在 CPU 上(只有更大的时间。)这个愚蠢的小东西执行十个 epoch 的时间迅速增长。只看数字滚动就很明显了。
epoch: 0, loss: 0.1753, time previous: 33:28.616360 time now: 33:28.622033 time delta: 0:00:00.005673
epoch: 10, loss: 0.2568, time previous: 33:28.622033 time now: 33:28.830665 time delta: 0:00:00.208632
epoch: 20, loss: 0.2092, time previous: 33:28.830665 time now: 33:29.324966 time delta: 0:00:00.494301
epoch: 30, loss: 0.2663, time previous: 33:29.324966 time now: 33:30.109241 time delta: 0:00:00.784275
epoch: 40, loss: 0.1965, time previous: 33:30.109241 time now: 33:31.184024 time delta: 0:00:01.074783
epoch: 50, loss: 0.2232, time previous: 33:31.184024 time now: 33:32.556106 time delta: 0:00:01.372082
epoch: 60, loss: 0.1258, time previous: 33:32.556106 time now: 33:34.215477 time delta: 0:00:01.659371
epoch: 70, loss: 0.2237, time previous: 33:34.215477 time now: 33:36.173928 time delta: 0:00:01.958451
epoch: 80, loss: 0.1076, time previous: 33:36.173928 time now: 33:38.436041 time delta: 0:00:02.262113
epoch: 90, loss: 0.1194, time previous: 33:38.436041 time now: 33:40.978748 time delta: 0:00:02.542707
epoch: 100, loss: 0.2099, time previous: 33:40.978748 time now: 33:43.844310 time delta: 0:00:02.865562
该模型:
class Sequence(torch.nn.Module):
def __init__ (self):
super(Sequence, self).__init__()
self.lstm1 = nn.LSTM(3,30)
self.lstm2 = nn.LSTM(30,300)
self.lstm3 = nn.LSTM(300,30)
self.lstm4 = nn.LSTM(30,3)
self.hidden1 = self.init_hidden(dim=30)
self.hidden2 = self.init_hidden(dim=300)
self.hidden3 = self.init_hidden(dim=30)
self.hidden4 = self.init_hidden(dim=3)
self.dense = torch.nn.Linear(30, 3)
self.relu = nn.LeakyReLU()
def init_hidden(self, dim):
return (torch.zeros(1, 1, dim).to(device) ,torch.zeros(1, 1, dim).to(device) )
def forward(self, inputs):
out1, self.hidden1 = self.lstm1(inputs, self.hidden1)
out2, self.hidden2 = self.lstm2(out1, self.hidden2)
out3, self.hidden3 = self.lstm3(out2, self.hidden3)
#out4, self.hidden4 = self.lstm4(out3, self.hidden4)
# This is intended to act as a dense layer on the output of the LSTM
out4 = self.relu(self.dense(out3))
return out4
训练循环:
sequence = Sequence().to(device)
criterion = L2_Loss()
optimizer = torch.optim.Adam(sequence.parameters())
_, _, _, R = getRotation(np.pi/27, np.pi/26, np.pi/25)
losses = []
date1 = datetime.datetime.now()
for epoch in range(1001):
# Define input as a Variable-- each row of 3 is a vector, a distinct input
# Define target directly from input by applicatin of rotation vector
# Define predictions by running input through model
inputs = getInput(25)
targets = getOutput(inputs, R)
inputs = torch.cat(inputs).view(len(inputs), 1, -1).to(device)
targets = torch.cat(targets).view(len(targets), 1, -1).to(device)
target_preds = sequence(inputs)
target_preds = target_preds.view(len(target_preds), 1, -1)
loss = criterion(targets, target_preds).to(device)
losses.append(loss.data[0])
if (epoch % 10 == 0):
date2 = datetime.datetime.now()
print("epoch: %3d, \tloss: %6.4f, \ttime previous: %s\ttime now: %s\ttime delta: %s" % (epoch, loss.data[0], date1.strftime("%M:%S.%f"), date2.strftime("%M:%S.%f"), date2 - date1))
date1 = date2
# Zero out the grads, run the loss backward, and optimize on the grads
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
解决方案
简短的回答:因为我们没有分离隐藏层,因此系统会随着时间不断地反向传播更远和父亲,占用更多内存并需要更多时间。
长答案:这个答案是为了在没有老师强迫的情况下工作。“教师强迫”是指所有时间步长的所有输入都是“基本事实”输入值。相比之下,在没有教师强制的情况下,每个时间步的输入都是前一个时间步的输出,无论该数据可能处于训练状态的多早(因此,多么不稳定)。
这在 PyTorch 中有点手动操作,要求我们不仅要跟踪输出,还要跟踪每一步网络的隐藏状态,以便我们可以将其提供给下一步。分离必须发生,不是在每个时间步,而是在每个序列的开始。一种似乎可行的方法是将“分离”方法定义为序列模型的一部分(它将手动分离所有隐藏层),并在优化器.step()之后显式调用它。
这防止了隐藏状态的逐渐积累,防止了逐渐减速,并且似乎仍在训练网络。
我不能真正保证它,因为我只在玩具模型上使用过它,而不是真正的问题。
注意 1:可能有更好的方法来考虑网络的初始化并使用它来代替手动分离。
注意 2:该loss.backward(retain_graph=True)
语句保留了图表,因为错误消息提示了它。一旦实施分离,该警告就会消失。
我不接受这个答案,希望有知识的人能增加他们的专业知识。
推荐阅读
- c# - 如何正确使用 .NET Core 5.x 和 Blazor 以允许用户执行搜索,然后高效、适当且类似于 Blazor 的显示结果?
- r - 从非表格文件中提取信息并存储为 tibble
- reactjs - 如何处理选择组件的不同类型的值?
- c++ - 如何从一个类访问另一个类的函数数据成员
- python - 如果一个值是字典中的键,则将元组值添加到字典
- python - odoo 14 多公司 - 在 account.move 上更改 company_id
- azure - Azure 服务总线主题上的非持久 JMS 订阅
- python - 是否有充分的理由将 python 代码放在 txt 文件中?
- angular - 动态地将 Angular 组件从一个容器移动到另一个容器
- python - Python中的并行化