pytorch - Pytorch 中的随时间截断反向传播 (BPTT)
问题描述
在 pytorch 中,我通过以下方式启动反向传播(通过时间)来训练 RNN/GRU/LSTM 网络:
loss.backward()
当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常时间反向传播。
但是我在 Pytorch API 中找不到任何参数或函数来设置截断的 BPTT。我错过了吗?我应该在 Pytorch 中自己编写代码吗?
解决方案
这是一个例子:
for t in range(T):
y = lstm(y)
if T-t == k:
out.detach()
out.backward()
所以在这个例子中,k
是你用来控制你想要展开的时间步长的参数。
推荐阅读
- sql - Oracle SQL 逗号格式?
- tensorflow - external/local_config_mlir/include/mlir/IR/Attributes.h:783:20:内部编译器错误:在assign_temp中,在function.c:968
- django - 递归遍历模型
- angular - Angular没有将表单数据推送到数组
- javascript - 为 JS 变量提取 JSON 键
- c# - .net 慢 SqlDataReader
- python - Django REST Framework - 如何禁用非员工用户的可浏览 API (is_staff=False)
- excel - 为什么在powershell中调用quit()后excel.exe没有关闭?
- wordpress - 使用 ACF 日期 > 今天查询自定义帖子类型
- css - 使用fr时如何让CSS网格列不缩小或扩大