首页 > 解决方案 > 重塑数组错误:*** IndexError:元组索引超出范围

问题描述

l 有一个或多个名为 my_data 的数组,这样: 类型的维度(my_data)=[96,18,36,36,3]<type 'numpy.ndarray'>

但是,my_data.shape返回(96,)但我希望得到 (96,18,36,36,3)。之后我尝试my_data[0].shape了返回(18,)而不是(18,36,36,3),最后 my_data[0][0].shape返回(36,36,3)。

我想知道为什么我没有得到 (96,18,36,36,3) 的形状

为什么我需要那个?

我想将 my_data 重塑为(96,18,36*36*3)

我试过什么?

my_data=my_data.reshape(my_data.shape[0],my_data.shape[1],my_data.shape[2]*my_data.shape[3]*my_data.shape[4])

我得到以下错误:

 *** IndexError: tuple index out of range

标签: arrayspython-2.7numpyreshapeflatten

解决方案


我什至不确定您是如何到达np.array其中np.array的 contains 的。但是,如果数据对齐,简单的转换应该为您解决它,例如:

import numpy as np

x = [np.random.randint(1, 100, (2, 3, 4)) for _ in range(2 * 3)]
x = np.array(x)
x.shape
# (6, 2, 3, 4)

如果数据没有对齐,你应该先对齐它:

import numpy as np
x = [np.random.randint(1, 100, (2, 3, np.random.randint(1, 5))) for _ in range(2 * 3)]
print([y.shape for y in x])
# [(2, 3, 2), (2, 3, 1), (2, 3, 1), (2, 3, 4), (2, 3, 1), (2, 3, 1)]
# x = np.array(x) will cause ValueError...


def zero_padding(arr, shape):
    result = np.zeros(shape, dtype=arr.dtype)
    mask = tuple(slice(None, d) for d in arr.shape)
    result[mask] = arr
    return result


x = [zero_padding(y, (2, 3, 4)) for y in x]
x = np.array(x)
x.shape
# (6, 2, 3, 4)

注意:这个更复杂的版本,zero_padding()由 NumPypad()FlyingCircus提供num.reframe()

免责声明:我是 FlyingCircus 的主要作者。


推荐阅读