首页 > 解决方案 > 如何根据第二个数组中的重复元素将pytorch数组元素相乘

问题描述

我有一个看似简单的问题,但我无法在任何地方找到答案。如果我有两个数组,我想根据另一个数组中的元素是连续的还是重复的,来相乘/组合其中一个的元素。例如,

array_with_repeated_elements = tensor([1, 2, 0, 0, 2, 2, 2, 1, 0, 0])
# could just as well be [a, b, c, c, d, d, d, e, f, f]
array_to_be_multiplied = tensor([1., 3., 5., 2., 2., 7., 2., 4., 3., 4.])

desired_output = tensor([1, 3, 10, 28, 4, 12])

在 numpy 中,这可以很容易地完成:

first_index_of_each_sequence = np.hstack([0,np.where(array_with_repeated_elements[1:] != array_with_repeated_elements[0:-1])[0]+1])
# this creates array([0, 1, 2, 4, 7, 8])
desired_output = 1-np.multiply.reduceat(array_to_be_multiplied, first_index_of_each_sequence)

我似乎无法在 pytorch 中执行此操作。我最好的猜测是这个怪物:

first_index_of_each_sequence = torch.cat([torch.LongTensor((0,)), torch.where(array_with_repeated_elementst[1:] != array_with_repeated_elementst[0:-1])[0]+1, torch.LongTensor((len(array_with_repeated_elements),))])
# makes tensor([0, 1, 2, 4, 7, 8, 10])

size_of_each_sequence = first_index_of_each_sequence[1:] - first_index_of_each_sequence[0:-1]
# makes tensor([1, 1, 2, 3, 1, 2])

full_length_array_of_ascending_index_elements = torch.arange(len(size_of_each_sequence)).repeat_interleave(size_of_each_sequence)
desired_output_base = torch.zeros(len(size_of_each_sequence))
# makes tensor([0, 1, 2, 2, 3, 3, 3, 4, 5, 5])

desired_output_base.index_add_(0, full_length_array_of_ascending_index_elements, torch.log(array_to_be_multipliedt))
# does what I want in log space, but ew if I ever have a zero

desired_output = torch.exp(desired_output_base)
# duh

有没有人对如何很好地做到这一点有任何想法?简单的 numpy 实现表明我在 pytorch 中遗漏了一些东西......

标签: pythonnumpypytorch

解决方案


推荐阅读