首页 > 解决方案 > 折叠多维 NumPy 数组的简单方法

问题描述

我有一个 NumPy 二维数组:

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])

该数组的每一列(即:[1,3, 5, 7][2, 4, 6, 8])需要转换为给定大小的矩阵M1xM2(在整形时使用 order='F' )。在这种情况下,M1 = M2 = 2。所以,我想要的输出是:

b = np.array([[[1, 5], [3, 7]], [[2, 6], [4, 8]]]).

我可以通过遍历列轻松实现这一点。但是,列数可以是任意的,初始二维数组最多可以是 8 维。如何轻松扩展此解决方案以用于更多维度?

我怀疑这是一个常见的过程,并且有一个内置函数可以解决它,但一直找不到。

标签: pythonarraysnumpy

解决方案


它就像重塑一样简单,然后转置:

a.reshape(2, 2, -1).T
# array([[[1, 5],
#         [3, 7]],
# 
#        [[2, 6],
#         [4, 8]]])

推荐阅读