python - 如何在 numpy.append 中短路展平
问题描述
我有一个从重载的第三方类继承的数据结构__getitem__
,返回一个元组。
现在,我在其他地方有代码,涉及将这些对象的集合附加到 NumPy 数组:
class ThirdPartyThing:
def __init__(self, size):
self.size = size
def __len__(self):
return self.size
def __getitem__(self, key):
return (self, key)
def __iter__(self):
return zip([self] * self.size, range(self.size))
class MyThing(ThirdPartyThing):
pass
x = numpy.array([], dtype = MyThing, ndmin = 1)
temp = [MyThing(1) for _ in range(5)]
x = numpy.append(x, temp)
当我这样做时,我期望的是一个具有五个类型对象的 Numpy 数组,MyThing
但我得到的是一个像这样的一维数组:
[MyThing(), 0, MyThing(), 0, MyThing(), 0, MyThing(), 0, MyThing(), 0]
它的长度为 10,其中每个其他元素都是整数。
根据文档,如果未定义但定义轴对我的情况没有影响,append
则尝试展平数组。axis
有没有办法避免这个陷阱?
更新
仔细检查后,我意识到基类重载__len__
。我认为这就是造成这里问题的原因。
解决方案
您帖子的准确副本:
In [1]: class MyThing:
...: pass
...:
...: x = numpy.array([], dtype = MyThing, ndmin = 1)
...: temp = [MyThing() for _ in range(5)]
...: x = numpy.append(x, temp)
In [2]: x
Out[2]:
array([<__main__.MyThing object at 0x7f21b45cd2e8>,
<__main__.MyThing object at 0x7f21b45cd278>,
<__main__.MyThing object at 0x7f21b45cd240>,
<__main__.MyThing object at 0x7f21b45cd320>,
<__main__.MyThing object at 0x7f21b45cd390>], dtype=object)
至于np.append
,其代码为:
def append(arr, values, axis=None):
arr = asanyarray(arr)
if axis is None:
if arr.ndim != 1:
arr = arr.ravel()
values = ravel(values)
axis = arr.ndim-1
return concatenate((arr, values), axis=axis)
所以有了轴,它就是concatenate
. 没有它确保两个参数都是 1d。
你x
是 (0,) 形状,你temp
是一个 5 元素列表,其中 asarray 变成 (5,) 形状,结果是 (5,)
In [14]: x=numpy.array([], dtype = MyThing, ndmin = 1)
In [15]: x.shape
Out[15]: (0,)
In [16]: np.array(temp).shape
Out[16]: (5,)
In [17]: np.concatenate((x,temp)).shape
Out[17]: (5,)
我没有看到这个问题。中的“扁平化”np.append
不会影响代码。但正如我评论的那样,我不喜欢np.append
. 它使太多新用户感到困惑,并且不需要。直接使用concatenate
。
您还包括ThirdPartyThing
类的代码,但不要使用它。
给MyThing
一个代表:
In [21]: MyThing.__repr__= lambda self: "MYTHING"
并定义一个不同的temp
:
In [28]: temp1 = np.array([(MyThing(),0) for _ in range(3)])
现在我们看到了append
ravels 的效果:
In [30]: np.append(x,temp1)
Out[30]: array([MYTHING, 0, MYTHING, 0, MYTHING, 0], dtype=object)
(3,2)在与 (0,) 连接之前temp1
变为。(6m,)
x
添加axis=0
不起作用,因为维数不同。
使用您编辑的代码:
In [64]: temp = np.array([MyThing(1) for _ in range(3)])
In [65]: temp
Out[65]:
array([[[<__main__.MyThing object at 0x7f21adbc5048>, 0]],
[[<__main__.MyThing object at 0x7f21adbc5a58>, 0]],
[[<__main__.MyThing object at 0x7f21adbc5470>, 0]]], dtype=object)
In [66]: temp.shape
Out[66]: (3, 1, 2)
或与我的代表:
In [67]: MyThing.__repr__= lambda self: "MYTHING"
In [68]: temp
Out[68]:
array([[[MYTHING, 0]],
[[MYTHING, 0]],
[[MYTHING, 0]]], dtype=object)
In [70]: np.append(x,temp)
Out[70]: array([MYTHING, 0, MYTHING, 0, MYTHING, 0], dtype=object)
并且添加axis=0
仍然给出
ValueError: all the input arrays must have same number of dimensions
无论您如何构造它,尝试将 (0,) 形状数组与 (3,1,2) 形状连接起来都需要一些调整。
但是为什么要加入这两个阵列呢?(0,) 形状数组最初是从哪里来的?
您构建列表的方式是问题的根源:
In [87]: [MyThing(1) for _ in range(3)]
Out[87]: [MYTHING, MYTHING, MYTHING]
In [88]: np.array(_)
Out[88]:
array([[[MYTHING, 0]],
[[MYTHING, 0]],
[[MYTHING, 0]]], dtype=object)
In [89]: [MyThing(i) for i in range(3)] # different MyThing parameter each time
Out[89]: [MYTHING, MYTHING, MYTHING]
In [90]: np.array(_)
Out[90]: array([MYTHING, MYTHING, MYTHING], dtype=object)
但np.array([MyThing(2),MyThing(3)])
会导致某种无限循环。
但回到append
. 通常在迭代构建数组时,我们建议在列表中收集值(list append
非常快),并在最后进行一个数组构建(使用np.array
,np.stack
和/或np.concatenate
)。
不建议迭代地进行连接。它速度较慢,并且在创建有效的起始“空”数组时存在问题。你x
看起来像一个这样的空头。 np.append
给人一种错误的感觉,即这种迭代数组构造与列表追加方法一样好。它不是。这也是我不喜欢的部分原因np.append
。concatenate
您至少必须直接解决数组维度的差异。并concatenate
接受一个列表,而不仅仅是两个参数。所以它在循环之外工作。
和len
, iter
(ThirdPartyThing
和继承MyThing
)是一个可迭代的。 np.array
当从这些事物的列表中构造一个数组时,也尝试对它们进行迭代(与列表列表相同)。
MyThing
我可以创建一个空对象数组,然后单独填充它,而不是从 s 列表中创建数组。现在我得到了这些对象的“干净”数组:
In [93]: temp = np.empty(5, object)
In [94]: temp
Out[94]: array([None, None, None, None, None], dtype=object)
In [95]: for i in range(3):
...: temp[i] = MyThing(1)
...:
In [96]: temp
Out[96]: array([MYTHING, MYTHING, MYTHING, None, None], dtype=object)
甚至
In [100]: temp[:] = [MyThing(1) for _ in range(5)]
In [101]: temp
Out[101]: array([MYTHING, MYTHING, MYTHING, MYTHING, MYTHING], dtype=object)
只是不要给名单np.array
!
这temp
可以通过多种方式连接:
In [102]: np.concatenate([temp,temp,temp])
Out[102]:
array([MYTHING, MYTHING, MYTHING, MYTHING, MYTHING, MYTHING, MYTHING,
MYTHING, MYTHING, MYTHING, MYTHING, MYTHING, MYTHING, MYTHING,
MYTHING], dtype=object)
In [103]: np.vstack([temp,temp,temp])
Out[103]:
array([[MYTHING, MYTHING, MYTHING, MYTHING, MYTHING],
[MYTHING, MYTHING, MYTHING, MYTHING, MYTHING],
[MYTHING, MYTHING, MYTHING, MYTHING, MYTHING]], dtype=object)
In [105]: np.append(x,temp)
Out[105]: array([MYTHING, MYTHING, MYTHING, MYTHING, MYTHING], dtype=object)
推荐阅读
- python - 如何摆脱“IndexError:字符串索引超出范围”
- visual-studio-code - 有没有办法在 Visual Sudio Code 1.56.2 中将默认 shell 永久设置为命令提示符?
- discord.js - 我需要帮助修复错误,我尝试了多种不同的方法来修复它,但我没有成功
- bash - 如何在 Bash 中的命令的先前 xargs 的管道的 grep 输出上运行类似 xargs 的命令
- delay - VB2019 睡眠/延迟,GUI 仍在更新
- c++ - 将简单十进制转换为十六进制的最佳方法以及 32 位和 64 位的重要性?十进制字符串?转换?
- postgresql - 如何在postgresql中查找用户创建日期以及所有权限
- linux - ansible 模块如何触发特定的 linux 命令?
- javascript - 字体真棒在 CSS 伪代码中不起作用
- ios - 无法在 ios 上选择初始时间值