How do I convert out_logits to tokens?

Problem description

I have a forward function in AllenNLP, given by:

  def forward(self, input_tokens, output_tokens):
    '''
    This is the main process of the Model where the actual computation happens. 
    Each Instance is fed to the forward method. 
    It takes dicts of tensors as input, with same keys as the fields in your Instance (input_tokens, output_tokens)
    It outputs the results of predicted tokens and the evaluation metrics as a dictionary. 
    '''

    mask = get_text_field_mask(input_tokens)
    embeddings = self.embedder(input_tokens)
    rnn_hidden = self.rnn(embeddings, mask)
    out_logits = self.hidden2out(rnn_hidden)
    loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)

    return {'loss': loss}

The out_logits variable contains the token probabilities. How can I display the corresponding tokens? out_logits gives:

 array([[ 0.02416356,  0.0195566 , -0.03279119,  0.057118  ,  0.05091334,
    -0.01906729, -0.05311333,  0.04695245,  0.06872341,  0.05173637,
    -0.03523348, -0.00537474, -0.03946163, -0.05817827, -0.04316377,
    -0.06042208,  0.01190596,  0.00574979,  0.01183304,  0.02330608,
     0.04587644,  0.02319966,  0.0020873 ,  0.03781978, -0.03975108,
    -0.0131919 ,  0.00393738,  0.04785313,  0.00159995,  0.05751844,
     0.05420169, -0.01404533, -0.02716331, -0.03871592,  0.00949999,
    -0.02924301,  0.03504215,  0.00397302, -0.0305252 , -0.00228448,
     0.04034173,  0.01458408],
   [ 0.02050283,  0.0204745 , -0.03081856,  0.06295916,  0.04601778,
    -0.0167818 , -0.05653084,  0.05017883,  0.07212739,  0.06197165,
    -0.03590995, -0.01142827, -0.03807197, -0.05942211, -0.0375165 ,
    -0.06769539,  0.01200251,  0.01012686,  0.01514241,  0.01875677,
     0.04499928,  0.02748671,  0.0012517 ,  0.04062563, -0.04049949,
    -0.01986902,  0.00630998,  0.05092276,  0.00276728,  0.05341531,
     0.05047017, -0.01111878, -0.03038253, -0.04320357,  0.01768938,
    -0.03470382,  0.03567442,  0.00776757, -0.02703476, -0.00392571,
     0.04700187,  0.01671317]] dtype=float32)}

How can I convert this last array into tokens?

Tags: nlp, pytorch, allennlp

Solution


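In AllenNLP, the model has access to the vocabulary through the self.vocab attribute, which provides the get_token_from_index method for mapping an index back to its token.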

Usually, to select a token from the logits, you apply a softmax (so that all the probabilities sum to 1) and then pick the most probable entry.
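The two steps above (softmax, then picking the most probable index and looking it up) can be sketched in plain Python. This is only an illustration under stated assumptions: the logits and the four-entry `index_to_token` dictionary are made up, and the dictionary stands in for the model's vocabulary. Inside the model, the lookup would instead go through `self.vocab.get_token_from_index`, with whatever namespace the vocabulary was built with.

```python
import math


def softmax(logits):
    """Turn one row of logits into probabilities that sum to 1."""
    largest = max(logits)
    # Subtracting the largest logit keeps exp() numerically stable.
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_tokens(out_logits, index_to_token):
    """Pick the most probable token for each row of logits."""
    tokens = []
    for row in out_logits:
        probs = softmax(row)
        best_index = max(range(len(probs)), key=probs.__getitem__)
        tokens.append(index_to_token[best_index])
    return tokens


# Hypothetical vocabulary standing in for self.vocab (an assumption,
# not taken from the question).
index_to_token = {0: "@@PADDING@@", 1: "@@UNKNOWN@@", 2: "hello", 3: "world"}

# Made-up logits for two positions (not the values from the question).
example_logits = [
    [0.02, 0.01, 0.07, -0.03],
    [0.02, 0.02, -0.03, 0.06],
]

predicted = logits_to_tokens(example_logits, index_to_token)
print(predicted)  # ['hello', 'world']
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw logits selects the same token. The softmax is only needed if you also want to report the probabilities.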

If you want to decode sequences from the model, you may want to look at [BeamSearch](https://docs.allennlp.org/master/api/nn/beam_search/#beamsearch).

