首页 > 解决方案 > Pytorch 中的随时间截断反向传播 (BPTT)

问题描述

在 pytorch 中,我通过以下方式启动反向传播(通过时间)来训练 RNN/GRU/LSTM 网络:

loss.backward()

当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常时间反向传播。

但是我在 Pytorch API 中找不到任何参数或函数来设置截断的 BPTT。我错过了吗?我应该在 Pytorch 中自己编写代码吗?

标签: pytorchbackpropagationtruncated

解决方案


这是一个例子:

for t in range(T):
   y = lstm(y)
   if T-t == k:
      out.detach()
out.backward()

所以在这个例子中,k是你用来控制你想要展开的时间步长的参数。


推荐阅读