首页 > 解决方案 > Numpy在窗口数组中找到最大元组

问题描述

我从元组列表开始(每个元组都是一个(X,Y))。我的最终结果是我想使用 numpy 在长度为 4 的每个窗口/bin 中找到最大 Y 值。

# List of tuples
[(0.05807200929152149, 9.9720125), (0.34843205574912894, 1.1142874), (0.6387921022067363, 2.0234027), (0.9291521486643438, 1.4435122), (1.207897793263647, 2.3677678), (1.4982578397212543, 1.9457655), (1.7886178861788617, 2.8343441), (2.078977932636469, 5.7816567), ...]

# Convert to numpy array
dt = np.dtype('float,float')
arr = np.asarray(listTuples, dt)
# [(0.05807201, 9.97201252) (0.34843206, 1.11428738)
#  (0.6387921 , 2.02340269) (0.92915215, 1.4435122 )
#  (1.20789779, 2.36776781) (1.49825784, 1.9457655 )
#  (1.78861789, 2.83434415) (2.07897793, 5.78165674)
#  (2.36933798, 3.14842606) ...]

#Create windows/blocks of 4 elements
arr = arr.reshape(-1,4)
# [[(0.05807201, 9.97201252) (0.34843206, 1.11428738)
#   (0.6387921 , 2.02340269) (0.92915215, 1.4435122 )]
#  [(1.20789779, 2.36776781) (1.49825784, 1.9457655 )
#   (1.78861789, 2.83434415) (2.07897793, 5.78165674)]
#  [(2.36933798, 3.14842606) (2.95005807, 2.10357308)
#   (3.24041812, 1.15985966) (3.51916376, 2.03056955)]...]

print(arr.max(axis=1)) <-- ERROR HERE
print(max(arr,key=lambda x:x[1])) <-- ERROR, tried this too but doesn't work

我希望使用最大 y 值从每个窗口/块获得的预期输出如下。或者格式可以是元组的常规列表,不需要是 numpy 数组:

[[(0.05807201, 9.97201252)]
[(2.07897793, 5.78165674)]
[(2.36933798, 3.14842606)]...]
OR other format:
[(0.05807201, 9.97201252),(2.07897793, 5.78165674),(2.36933798, 3.14842606)]...]

标签: pythonnumpy

解决方案


这应该可以解决您的问题。

输入:元组列表

输出:元组列表,取每个4个元素块中y值最大的元组 import numpy as np

# List of tuples
listTuples = [(1,1),(120,1000),(12,90),(1,1),(0.05807200929152149, 9.9720125), 
(0.34843205574912894, 1.1142874), (0.6387921022067363, 2.0234027), 
(0.9291521486643438, 1.4435122), (1.207897793263647, 2.3677678), 
(1.4982578397212543, 1.9457655), (1.7886178861788617, 2.8343441), 
(2.078977932636469, 5.7816567)]


def extractMaxY(li):
    result = []
    index = 0
    for i in range(0,len(li), 4):
        max = -100000
#find the max Y in blocks of 4
        for j in range(4):
            if li[i+j][1] > max:
                max = li[i+j][1]
                index = i+j
        result.append(li[index])
    return result


print(extractMaxY(listTuples))

那么输出是

[(120, 1000), (0.05807200929152149, 9.9720125), (2.078977932636469, 
5.7816567)]

它应该是,对吧?


推荐阅读