What is the best way to represent spatiotemporal image data in a vanilla LSTM?

Problem description

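I am working on video frame segmentation prediction and want to start with a vanilla LSTM as a baseline, even though I know it will not give good results.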

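In my current approach, the raw input frame is in the first channel and the segmentation frame is in the second channel, similar to this data representation, and the input is then flattened into a 1D array. How should I represent the spatiotemporal data so that it works with a vanilla LSTM?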
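
As a minimal sketch of the representation described above: each time step stacks the raw frame and its segmentation mask as two channels, then flattens them into one feature vector, giving a tensor of shape (seq_len, batch, 2 * H * W) that `nn.LSTM` accepts when `batch_first=False`. The sizes and the random `frames` and `masks` tensors are hypothetical placeholders, not values from the question.

```python
import torch

# Hypothetical sizes for illustration only (not taken from the question)
seq_len, batch_size, height, width = 5, 2, 8, 8

# Placeholder data: grayscale frames and binary segmentation masks
frames = torch.rand(seq_len, batch_size, height, width)
masks = (torch.rand(seq_len, batch_size, height, width) > 0.5).float()

# Channel 0 = raw frame, channel 1 = segmentation mask
clips = torch.stack([frames, masks], dim=2)   # (seq_len, batch, 2, H, W)

# Flatten the channels and pixels of each time step into one feature vector
lstm_input = clips.flatten(start_dim=2)       # (seq_len, batch, 2 * H * W)

print(lstm_input.shape)  # torch.Size([5, 2, 128])
```

With this layout, the LSTM's `input_size` would have to equal 2 * H * W, so the time axis stays separate and only the spatial and channel axes are merged.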

Here is a snippet of the vanilla LSTM I am using in PyTorch:

import torch
import torch.nn as nn


class ImageLSTM(nn.Module):
    def __init__(self, n_inputs: int = 49,
                 n_outputs: int = 4096,
                 n_hidden: int = 256,
                 n_layers: int = 1,
                 bidirectional: bool = False):
        """
        Takes a sequence of flattened (1D) images.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_outputs = n_outputs
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size=self.n_inputs,
                            hidden_size=self.n_hidden,
                            batch_first=False,
                            num_layers=self.n_layers,
                            bidirectional=self.bidirectional)

        # A bidirectional LSTM concatenates both directions, doubling the feature size
        n_directions = 2 if self.bidirectional else 1
        self.FC = nn.Sequential(
            nn.Linear(self.n_hidden * n_directions, self.n_outputs),
            nn.Dropout(p=0.5),
            nn.Sigmoid()
        )

    def init_hidden(self, x, device=None):  # x: (seq_length, batch_size, n_inputs)
        # Zero initial hidden and cell states, each of shape
        # (num_layers * num_directions, batch_size, n_hidden)
        n_directions = 2 if self.bidirectional else 1
        shape = (n_directions * self.n_layers, x.size(1), self.n_hidden)
        h0 = torch.zeros(shape, device=device)
        c0 = torch.zeros(shape, device=device)
        self.hidden = (h0, c0)

    def forward(self, X):  # X: (seq_length, batch_size, n_inputs), because batch_first=False
        # Call init_hidden(X) before each forward pass so the state matches the batch size
        lstm_out, self.hidden = self.lstm(X, self.hidden)  # lstm_out: (seq_length, batch_size, num_directions * n_hidden)
        out = self.FC(lstm_out[-1])  # last time step of every sequence: (batch_size, n_outputs)
        return out
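
With `batch_first=False`, `nn.LSTM` expects and returns tensors whose first dimension is the sequence, not the batch. The sketch below checks which index selects the last time step; the sizes are illustrative assumptions (only the 49 input features mirror the default in the snippet above).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes; 49 features matches the default n_inputs in the snippet
seq_len, batch_size, n_features, n_hidden = 5, 3, 49, 16
lstm = nn.LSTM(input_size=n_features, hidden_size=n_hidden, batch_first=False)

x = torch.rand(seq_len, batch_size, n_features)
output, (h_n, c_n) = lstm(x)  # initial states default to zeros when omitted

print(output.shape)  # torch.Size([5, 3, 16])
print(h_n.shape)     # torch.Size([1, 3, 16])

last_step = output[-1]       # (batch, hidden): last time step of every sequence
last_sample = output[:, -1]  # (seq_len, hidden): all time steps of the last batch element

# For a unidirectional LSTM, the last output step equals the final hidden state
same = torch.allclose(last_step, h_n[-1])
print(same)  # True
```

Indexing `output[:, -1]` would only give the last time step if the LSTM were built with `batch_first=True`.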

Tags: python, computer-vision, pytorch, torchvision

Solution

