python - 对连续的非零张量值求和
问题描述
我试图找到连续非零张量值的总和,如下所示
假设,我有一个张量A = [1.3, 0.0, 0.6, 0.7, 0.8]
。我想
1) 将张量的连续非零值求和以输出[1.3, 0.0, 2.1]
,然后选择最大值为2.1
。
2) 找到用于对这些值求和的索引。在这种情况下,它将是2, 3, 4
。
解决方案
我们可以通过两个简单的步骤来解决这个问题:
首先,根据零分割主张量。所以,给定的张量A
看起来像这样[ [1.3], [0], [0.6, 0.7, 0.8] ]
。这可以使用以下函数来完成:
def split_list(lst, value=0):
"""
Splits a given list based on a given value
default is zero
"""
groups = []
sub_group = []
for i in lst:
if i == 0:
groups.append(sub_group)
sub_group = []
groups.append([0])
else:
sub_group.append(i)
if sub_group:
groups.append(sub_group)
return groups
其次,对每个子组求和。返回的索引会有点棘手。所以,让我们在代码中看到它:
def get_max_indices(groups):
"""
This function takes a list of lists and
returns the indices of the maximum elements
"""
maximum = 0
max_length = 0
total_elements = 0
length_before = 0
for idx, sub_group in enumerate(groups):
summation = sum(sub_group)
if summation > maximum:
maximum = summation
max_length = len(sub_group)
length_before = total_elements
total_elements += len(sub_group)
return [_ for _ in range(length_before, length_before+max_length)]
现在,让我们同时尝试它们:
>>> lst = [1.3, 0, 0.6, 0.7, 0.8]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[2, 3, 4]
让我们试试另一个例子:
>>> lst = [1, 2, 3, 0, 6, 9, 0, 10]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[4, 5]
我希望这能解决你的问题。我知道这比你想象的要复杂一些,但这会让你开始。我认为这可以优化和清理,但我会把它留给你。