首页 > 解决方案 > 将 softmax 层添加到 LSTM 网络“冻结”输出

问题描述

我一直在尝试通过 PyTorch 上的个人项目自学 RNN 的基础知识。我想生成一个能够预测序列中下一个字符的简单网络(想法主要来自这篇文章http://karpathy.github.io/2015/05/21/rnn-effectiveness/但我最想做的我自己的东西)。

我的想法是这样的:我采用一批大小为n的B输入序列(np整数数组),对它们进行热编码,然后将它们通过由几个 LSTM 层、一个全连接层和一个 softmax 单元组成的网络。然后我将输出与目标序列进行比较,这些目标序列是前移一步的输入序列。

我的问题是,当我包含 softmax 层时,每个批次的每个时期的输出都是相同的。当我不包括它时,网络似乎可以适当地学习。我不知道出了什么问题。

我的实现如下:

class Model(nn.Module):
def __init__(self, one_hot_length, dropout_prob, num_units, num_layers):
    
    super().__init__()

    self.LSTM = nn.LSTM(one_hot_length, num_units, num_layers, batch_first = True, dropout = dropout_prob)
    
    self.dropout = nn.Dropout(dropout_prob)
    self.fully_connected = nn.Linear(num_units, one_hot_length)
    self.softmax = nn.Softmax(dim = 1)
    # dim = 1 as the tensor is of shape (batch_size*seq_length, one_hot_length) when entering the softmax unit

def forward_pass(self, input_seq, hc_states):

    output, hc_states = self.LSTM (input_seq, hc_states)
    output = output.view(-1, self.num_units)
    output = self.fully_connected(output) 
    # I simply comment out the next line when I run the network without the softmax layer
    output = self.softmax(output)
    return output, hc_states

one_hot_length是我的字符字典的大小(~200,也是一个热编码向量的大小) num_units是 LSTM 单元中隐藏单元的数量,num_layers是网络中 LSTM 层的数量。

训练循环的内部(简化)如下:

input, target = next_batches(data, batch_pointer)

input = nn.functional.one_hot(input_seq, num_classes = one_hot_length).float().

for state in hc_states:
    state.detach_()
            
optimizer.zero_grad()
            
output, states = net.forward_pass(input, hc_states)

loss = nn.CrossEntropyLoss(output, target)
loss.backward()
    
nn.utils.clip_grad_norm_(net.parameters(), MaxGradNorm)

optimizer.step()

使用hc_states的元组具有隐藏状态张量和单元状态张量input是大小为 ( B , n , one_hot_length ) 的张量,目标是 ( B , n )。

我正在对一个非常小的数据集(约 400Ko 的 .txt 中的句子)进行训练,只是为了调整我的代码,并使用不同的参数进行了 4 次不同的运行,每次结果都相同:网络根本不学习当它具有 softmax 层时,并且在没有的情况下进行适当的训练。我认为张量形状不是问题,因为我几乎可以肯定我检查了所有内容。

我对我的问题的理解是我正在尝试进行分类,通常是在最后放置一个 softmax 单元以获得每个字符出现的“概率”,但显然这是不对的。

有什么想法可以帮助我吗?我对 Pytorch 和 RNN 也很陌生,所以如果我的架构/实现对知识渊博的人来说是某种怪物,我提​​前道歉。请随时纠正我并提前感谢。

标签: pythondeep-learningpytorchlstmsoftmax

解决方案


推荐阅读