首页 > 解决方案 > Pytorch:我应该何时以及为什么要使用缓冲区?

问题描述

我正在使用缓冲区来传递 LSTM 网络的隐藏状态。

def __init__(self, model, hidden_state1=None, ...somethine else...):
    self.register_buffer('hidden_state1', hidden_state1)
    self.hidden_state1 = hidden_state1
    ....#other codes

为了避免错误:

RuntimeError: Trying to backward through the graph a second time, 
but the buffers have already been freed. 
Specify retain_graph=True when calling backward the first time.

.clone().detach()用来分离缓冲区。

由于无论如何我都需要手动分离它们,我还需要在 Pytorch 中使用缓冲区而不是普通参数吗?

带有“requires_grad=False”的普通参数是否足以替代缓冲区的使用?

(其实不知道这样传递隐藏状态是不是一个好办法)

标签: pytorchlstmrecurrent-neural-network

解决方案


推荐阅读