Partially filling rows of a NumPy array downward by index

Problem description

Suppose I have the following (made-up) NumPy array:

import numpy as np

arr = np.array(
    [[1,   2,  3,  4],
     [5,   6,  7,  8],
     [9,  10, 11, 12],
     [13, 14, 15, 16],
     [17, 18, 19, 20],
     [21, 22, 23, 24],
     [25, 26, 27, 28],
     [29, 30, 31, 32],
     [33, 34, 35, 36],
     [37, 38, 39, 40]
    ]
)

For the row indices idx = [0, 2, 3, 5, 8, 9], I want to repeat the values of each indexed row downward until the next row index is reached:

np.array(
    [[1,   2,  3,  4],
     [1,   2,  3,  4],
     [9,  10, 11, 12],
     [13, 14, 15, 16],
     [13, 14, 15, 16],
     [21, 22, 23, 24],
     [21, 22, 23, 24],
     [21, 22, 23, 24],
     [33, 34, 35, 36],
     [37, 38, 39, 40]
    ]
)

Note that idx will always be sorted and will contain no duplicate values. I can accomplish this with the following:

for start, stop in zip(idx[:-1], idx[1:]):
    for i in range(start, stop):
        arr[i] = arr[start]

# Handle last index in `idx`
start, stop = idx[-1], arr.shape[0]
for i in range(start, stop):
    arr[i] = arr[start]

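Unfortunately, I have many, many such arrays, and this can become slow as the arrays grow (in both rows and columns) and as idx gets longer. The end goal is to plot these as heatmaps in matplotlib, which I already know how to do. Another approach I tried uses np.tile: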

for start, stop in zip(idx[:-1], idx[1:]):
    reps = max(0, stop - start)
    arr[start:stop] = np.tile(arr[start], (reps, 1))

# Handle last index in `idx`
start, stop = idx[-1], arr.shape[0]
reps = stop - start  # recompute; do not reuse `reps` from the loop above
arr[start:stop] = np.tile(arr[start], (reps, 1))

But I was hoping there is a way to get rid of the slow for loop.

Tags: python, arrays, numpy

Solution


Try np.diff to find how many times each row should be repeated, then np.repeat:

# this assumes `idx` is a standard list as in the question
np.repeat(arr[idx], np.diff(idx+[len(arr)]), axis=0)

Output:

array([[ 1,  2,  3,  4],
       [ 1,  2,  3,  4],
       [ 9, 10, 11, 12],
       [13, 14, 15, 16],
       [13, 14, 15, 16],
       [21, 22, 23, 24],
       [21, 22, 23, 24],
       [21, 22, 23, 24],
       [33, 34, 35, 36],
       [37, 38, 39, 40]])
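
The one-liner above can be wrapped in a small helper and checked against the loop from the question. This is only a minimal sketch, not the answerer's code: the function name fill_down is hypothetical, np.append is swapped in for the list concatenation so that idx may also be a NumPy array, and it assumes idx[0] == 0 (otherwise the rows before the first index would be missing from the result).

```python
import numpy as np


def fill_down(arr, idx):
    """Return a new array in which each row listed in `idx` is repeated
    downward until the next listed row (or the end of the array).

    Hypothetical helper. Assumes `idx` is sorted, has no duplicates,
    and starts at 0.
    """
    idx = np.asarray(idx)
    # Rows covered by each indexed row: the gaps between consecutive
    # indices, with len(arr) closing the final gap.
    counts = np.diff(np.append(idx, len(arr)))
    return np.repeat(arr[idx], counts, axis=0)


# Same 10x4 array as in the question (values 1..40, row by row).
arr = np.arange(1, 41).reshape(10, 4)
idx = [0, 2, 3, 5, 8, 9]

# Reference result from an explicit loop; assigning one row to a slice
# relies on NumPy broadcasting, so np.tile is not needed here.
expected = arr.copy()
for start, stop in zip(idx, idx[1:] + [len(arr)]):
    expected[start:stop] = expected[start]

result = fill_down(arr, idx)
print(np.array_equal(result, expected))  # True
```

Note that np.repeat returns a new array rather than modifying arr. If the original array must be overwritten, something like arr[:] = fill_down(arr, idx) should work, since the result has the same shape under the assumptions stated above.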
