首页 > 解决方案 > 张量元组

问题描述

我最近问了这个问题的一部分。我正在构建一个聊天机器人,并且有一个功能会产生问题。函数如下:

def variable_from_sentence(sentence):
  vec, length = indexes_from_sentence(sentence)
  inputs = [vec]
  lengths_inputs = [length]
  if hp.cuda:
    batch_inputs = Variable(torch.stack(torch.Tensor(inputs),1).cuda())
  else:
    batch_inputs = Variable(torch.stack(torch.Tensor(inputs),1))
  return batch_inputs, lengths_inputs

但是当我尝试运行聊天机器人代码时,它给了我这个错误:

stack():参数“张量”(位置 1)必须是张量的元组,而不是张量

出于这个原因,我修复了这样的功能:

def variable_from_sentence(sentence):
  vec, length = indexes_from_sentence(sentence)
  inputs = [vec]
  lengths_inputs = [length]
  if hp.cuda:
    batch_inputs = torch.stack(inputs, 1).cuda()
  else:
    batch_inputs = torch.stack(inputs, 1)
  return batch_inputs, lengths_inputs

但它仍然给我错误,错误是这样的:

TypeError:预期张量作为参数 0 中的元素 0,但得到列表

在这种情况下我现在该怎么办?

标签: pytorchtuplestensor

解决方案


由于vecandlength都是整数,所以可以torch.tensor直接使用:

def variable_from_sentence(sentence):
    vec, length = indexes_from_sentence(sentence)
    inputs = [vec]
    lengths_inputs = [length]
    if hp.cuda:
        batch_inputs = torch.tensor(inputs, device='cuda')
    else:
        batch_inputs = torch.tensor(inputs)
    return batch_inputs, lengths_inputs

推荐阅读