首页 > 解决方案 > Numpy:在每行中的某个索引之后替换一行的剩余列

问题描述

我在 3 维 ndarray 中捕获了一些数据,其维度为:时间刻度、样本数和每个样本的值数。

通过前面的操作,我知道在某个时间tick,每个样本号是否失效。在这种情况永远不会发生的情况下,该数字设置为 -1。在其他情况下,它表示样本失效的时间刻度。

我希望能够做的是将其余列清空,或者通过将包含无效数据的列和右侧的列设置为 nans,或者通过使用一些屏蔽或索引技术导致只保留左边的数据。

我已经阅读或发现了涉及花式索引、slice()布尔数组和掩码数组的类似问题的参考资料,但我没有找到实现目标的方法。

import numpy as np

# dimensions are timestep, sample, and values per sample.  To make it easy, let's 
# do 3 time steps, 4 samples, and 2 values per sample.
data = np.array( [
    [ # Timestep 0
        [ 1, 2 ],  # Sample 1
        [ 3, 4 ],  #        2
        [ 5, 6 ],  #        3
        [ 7, 8 ],  #        4
    ],
    [ # Timestep 1
        [ 1, 2 ], 
        [ 3, 4 ],
        [ 5, 6 ],
        [ 7, 8 ],
    ],
    [ # Timestep 2
        [ 1, 2 ], 
        [ 3, 4 ],
        [ 5, 6 ],
        [ 7, 8 ],
    ],
])

每个样本都可能在某个时间步变为无效。如果时间步长从不无效,则值为 -1。

invalid_at = [ 
        0, # becomes invalid at timestep 0
        2, #                             2
        1, #                             1 
        -1 # Never is invalid
        ]

例如,如果我们用 nan 替换无效值,那么结果数组应该如下所示

data = np.array( [
    [ # Timestep 0
        [ n, n ],  # Sample 1
        [ 3, 4 ],  #        2
        [ 5, 6 ],  #        3
        [ 7, 8 ],  #        4
    ],
    [ # Timestep 1
        [ n, n ], 
        [ 3, 4 ],
        [ n, n ],
        [ 7, 8 ],
    ],
    [ # Timestep 2
        [ n, n ], 
        [ n, n ],
        [ n, n ],
        [ 7, 8 ],
    ],
])

我遇到的主要困难是我有起始索引,但我找不到创建切片的方法(使用花哨的索引或其他方式)让我分配给它。

例如,以下内容不起作用:

data[ :, invalid_at:-1, : ] = np.nan

我希望会发生的是,无效的 at 数组将被评估并生成每行切片。

我可以使用 for 循环来做到这一点,但我更愿意将其矢量化以提高速度和以后的可扩展性。有任何想法吗?

标签: pythonnumpy

解决方案


有几种可能的方法可以做到这一点。主要问题是您尝试应用的索引参差不齐。

如果样本数量很少,您可以以相对较少的额外开销循环它们。这个选项非常简单,并且支持简单的切片索引,这通常是最快的一种索引,因为它不需要额外的数据副本或掩码:

for sample, step in enumerate(invalid_at):
    if step < 0:
        continue
    data[step:, sample, :] = np.nan

如果您确实需要一步完成,您可以构建一个遮罩并应用它。该数组具有维度(时间步长、样本、x)。掩码只需要前两个维度。您需要设置一个条件,例如“如果一个元素在时间步t并且样本大于或等于invalid_at[t],则将该元素设置为True。该条件可以应用于一对广播数组:一个用于时间步,一个用于样本:

trange = np.arange(data.shape[0]).reshape(-1, 1)
srange = np.array(invalid_at).reshape(1, -1)
srange[srange == -1] = data.shape[0]
mask = (trange >= srange)
data[mask, :] = np.nan

这仅在您明确设置dtype=np.float或类似 for时才有效data,因为当前定义的整数 to 不支持 NaN。


推荐阅读