首页 > 解决方案 > 将状态整数转换为 onehotvector

问题描述

我有一个函数用于将两个状态转换为一个单热向量并将它们全部组合起来,它工作正常,但如果我有一个大L的 ,例如3000,这将需要时间。该功能与L=3.

def OH3(x,end=2,len=3):
  x = T.LongTensor([[x]])
  end = T.LongTensor([[end]])
  one_hot_x = T.FloatTensor(len,l)
  one_hot_end = T.FloatTensor(len,l)
  first=one_hot_x.zero_().scatter_(1,x,1)
  second=one_hot_end.zero_().scatter_(1,end,1)
  vector=T.cat((one_hot_x,one_hot_end),dim=1)
  return vector

输出:

OH3(1)
tensor([[0., 1., 0., 0., 0., 1.]])

正如我所提到的,该功能工作正常,但对于大量状态来说速度较慢,我一次发送多个状态,比如 500 个状态,需要 4 秒

有更快的方法吗?

标签: pythonpytorch

解决方案


如果上面评论中的假设是正确的,那么这应该尽可能快地完成,一个新数组的分配和 2 个赋值操作。如果我遗漏了什么,请告诉我。

def OH3(x,end=2,len=3):
    vector = torch.zeros(2*len,dtype = int)
    vector[[x,len+end] = 1
    return vector

推荐阅读