首页 > 解决方案 > 如何根据元组内容拆分 numpy 数组?

问题描述

假设我有一个数组[0, 1, 2, 3, 4, 5, 6, 7]和一个元组:(3, 3, 2)

我正在寻找一种方法来3根据我的元组数据将我的数组拆分为数组:

[0, 1, 2]
[3, 4, 5]
[6, 7]

我可以编写一个像这样的简单代码来获得我想要的东西,但是我正在寻找一种正确且 Pythonic 的方法来做到这一点:

为了简单起见,我使用了列表。

a = [0, 1, 2, 3, 4, 5, 6, 7]
b = (3, 3, 2)

pointer = 0
for i in b:
        lst = []
        for j in range(i):
                lst.append(a[pointer])
                pointer += 1
        print(lst)

或者这个:

a = [0, 1, 2, 3, 4, 5, 6, 7]
b = (3, 3, 2)
pointer = 0
for i in b:
        lst = a[pointer:pointer+i]
        pointer += i
        print(lst)

结果:

[0, 1, 2]
[3, 4, 5]
[6, 7]

标签: pythonnumpy

解决方案


you can use the split method of numpy

import numpy as np

a = [0, 1, 2, 3, 4, 5, 6, 7]
b = (3, 3, 2)

c = np.split(a, np.cumsum(b)[:-1])

for r in c:
    print(r)

np.split(a, b) splits a by the indices in b along a given axis(0 by default).


推荐阅读