首页 > 解决方案 > 对连续的非零张量值求和

问题描述

我试图找到连续非零张量值的总和,如下所示

假设,我有一个张量A = [1.3, 0.0, 0.6, 0.7, 0.8]。我想

1) 将张量的连续非零值求和以输出[1.3, 0.0, 2.1],然后选择最大值为2.1

2) 找到用于对这些值求和的索引。在这种情况下,它将是2, 3, 4

标签: pythondeep-learningpytorch

解决方案


我们可以通过两个简单的步骤来解决这个问题:

首先,根据零分割主张量。所以,给定的张量A看起来像这样[ [1.3], [0], [0.6, 0.7, 0.8] ]。这可以使用以下函数来完成:

def split_list(lst, value=0):
    """
    Splits a given list based on a given value
    default is zero
    """
    groups = []
    sub_group = []
    for i in lst:
        if i == 0:
            groups.append(sub_group)
            sub_group = []
            groups.append([0])
        else:
            sub_group.append(i)
    if sub_group:
        groups.append(sub_group)
    return groups

其次,对每个子组求和。返回的索引会有点棘手。所以,让我们在代码中看到它:

def get_max_indices(groups):
    """
    This function takes a list of lists and 
    returns the indices of the maximum elements
    """
    maximum = 0
    max_length = 0
    total_elements = 0
    length_before = 0
    for idx, sub_group in enumerate(groups):
        summation = sum(sub_group)
        if summation > maximum:
            maximum = summation
            max_length = len(sub_group)
            length_before = total_elements
        total_elements += len(sub_group)
    return [_ for _ in range(length_before, length_before+max_length)]

现在,让我们同时尝试它们:

>>> lst = [1.3, 0, 0.6, 0.7, 0.8]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[2, 3, 4]

让我们试试另一个例子:

>>> lst = [1, 2, 3, 0, 6, 9, 0, 10]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[4, 5]

我希望这能解决你的问题。我知道这比你想象的要复杂一些,但这会让你开始。我认为这可以优化和清理,但我会把它留给你。


推荐阅读