首页 > 解决方案 > 删除张量以优化 python 中的 for 循环

问题描述

我正在处理我正在尝试优化的大型代码。您在下面看到的代码部分是一个以张量返回编码的 for 循环。如何在不使用张量的情况下将这些数字输出到常规列表中?

def _make_batches(self, lines):
        tokens = [self._tokenize(line) for line in lines]
        lengths = np.array([t.numel() for t in tokens])
        indices = np.argsort(-lengths, kind=self.sort_kind)  # pylint: disable=invalid-unary-operand-type

        def batch(tokens, lengths, indices):
            toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]),
                                      self.pad_index)
            for i in range(len(tokens)):
                toks[i, -tokens[i].shape[0]:] = tokens[i]
            return Batch(srcs=None,
                         tokens=toks,
                         lengths=torch.LongTensor(lengths)), indices

        batch_tokens, batch_lengths, batch_indices = [], [], []
        ntokens = nsentences = 0
        for i in indices:
            if nsentences > 0 and ((self.max_tokens is not None
                                    and ntokens + lengths[i] > self.max_tokens)
                                   or (self.max_sentences is not None
                                       and nsentences == self.max_sentences)):
                yield batch(batch_tokens, batch_lengths, batch_indices)
                ntokens = nsentences = 0
                batch_tokens, batch_lengths, batch_indices = [], [], []
            batch_tokens.append(tokens[i])
            batch_lengths.append(lengths[i])
            batch_indices.append(i)
            ntokens += tokens[i].shape[0]
            nsentences += 1
        if nsentences > 0:
            yield batch(batch_tokens, batch_lengths, batch_indices)

这就是我调用这个函数的方式:

if __name__ == '__main__':
    s = SentenceEncoder("data/model.pt")
    input = [args.string_enc]
    make_batches = s._make_batches
    print([batch[1] for batch, indexes in make_batches(input)])

输出是:

[tensor([[29733, 20720,     2]])]

所需的输出是:

[29733, 20720,     2]

标签: pythonfor-looppytorchtensor

解决方案


你是这个意思?

a=[torch.tensor([[29733, 20720,     2]])]
b=a[0].squeeze(0).tolist()
print(b)

推荐阅读