首页 > 解决方案 > 在对 numpy 数组进行迭代时,我无法调用存储在数组中的对象的方法

问题描述

StackOverflow 中提出的第一个问题,因此非常欢迎有关如何更好地“提问”的提示。

这部分代码的基本目标:多个球(no_balls)在随机方向移动。

我正在尝试从 python 列表转移到 numpy 数组以获得更好的性能。这是简化的代码。

基本问题:我的迭代器给了我类型为 ndarray 而不是 vpy.sphere 的对象,因此在我迭代的对象上调用 sphere.pos 失败。或者这是不可能的,因为 Numpy 是为数字而构建的?性能的替代品?

import vpython as vpy
import numpy as np

#Create and Fill numpy array with random size balls
balls = np.empty([no_ball], dtype=vpy.sphere)

with np.nditer(balls, flags=['refs_ok'], op_flags=['readwrite']) as b_it:
    debug_msg(len(b_it))
    for b in b_it:
        b[...] = (vpy.sphere( radius=random_in_range(ball_min_r,ball_max_r), 
                              opacity=0.8, 
                              color=random_RGB(), 
                              pos=vpy.vector(0,0,0),))
    debug_msg('populated balls list')

#Main Loop
debug_msg('Starting Main Loop')
while True:
    vpy.rate(30)
            
            
with np.nditer(balls, flags=['refs_ok'], op_flags=['readwrite']) as b_it:
    #Main Loop
    debug_msg('Starting Main Loop')
    while True:
        vpy.rate(30)
            
#The actual loop manipulates the position but the problem is that I can't access the   position of the sphere objects. Type returns nd.array for b
        for b in b_it:
           debug_msg(type(b[...]))
           debug_msg(b[...].pos)
#Above outputs
<class 'numpy.ndarray'>
Traceback (most recent call last):
  File "path", line 93, in <module>
    debug_msg(b[...].pos)
AttributeError: 'numpy.ndarray' object has no attribute 'pos'

如何调用数组中对象的方法和成员。在旁注中,为什么我需要调用 b[...] 而不是 b,似乎已经过时了。

标签: pythonarraysnumpyiteratorvpython

解决方案


一个简单的类:

In [149]: class Foo():
     ...:     def __init__(self,i):
     ...:         self.i = i
     ...:     def __repr__(self):
     ...:         return f'<FOO {self.i}>'
     ...: 
In [150]: Foo(323)
Out[150]: <FOO 323>

此类对象的列表:

In [151]: alist = [Foo(i) for i in range(10)]

等效的对象 dtype 数组:

In [152]: arr = np.array(alist)
In [153]: arr.dtype
Out[153]: dtype('O')
In [154]: arr
Out[154]: 
array([<FOO 0>, <FOO 1>, <FOO 2>, <FOO 3>, <FOO 4>, <FOO 5>, <FOO 6>,
       <FOO 7>, <FOO 8>, <FOO 9>], dtype=object)

从列表中获取属性:

In [155]: [f.i for f in alist]
Out[155]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
In [156]: timeit [f.i for f in alist]
826 ns ± 8.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

并从数组(较慢):

In [157]: timeit [f.i for f in arr]
1.66 µs ± 15.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

使用nditer- 您研究了足够多的文档以正确设置标志,但没有掌握这b是一个数组,而不是Foo

In [158]: with np.nditer(arr, flags=['refs_ok'], op_flags=['readwrite']) as b_it:
     ...:     for b in b_it:
     ...:         print(b, b.dtype, b.shape, b.item())
     ...: 
<FOO 0> object () <FOO 0>
<FOO 1> object () <FOO 1>
<FOO 2> object () <FOO 2>
<FOO 3> object () <FOO 3>
<FOO 4> object () <FOO 4>
<FOO 5> object () <FOO 5>
<FOO 6> object () <FOO 6>
<FOO 7> object () <FOO 7>
<FOO 8> object () <FOO 8>
<FOO 9> object () <FOO 9>

获取属性列表:

In [159]: res = []
     ...: with np.nditer(arr, flags=['refs_ok'], op_flags=['readwrite']) as b_it:
     ...:     for b in b_it:
     ...:         res.append(b.item().i)
     ...: 
     ...: 
In [160]: res
Out[160]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

而且时机不好:

In [161]: %%timeit
     ...: res = []
     ...: with np.nditer(arr, flags=['refs_ok'], op_flags=['readwrite']) as b_it:
     ...:     for b in b_it:
     ...:         res.append(b.item().i)
     ...: 

7.25 µs ± 60.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对对象数组的元素执行操作的一种更简洁的方法是frompyfunc

In [162]: f = np.frompyfunc(lambda b:b.i,1,1)
In [163]: f(arr)
Out[163]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=object)
In [164]: timeit f(arr)
2.1 µs ± 8.58 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

仍然比迭代慢,但如果我们想要一个数组而不仅仅是一个列表,它比:

In [165]: timeit np.array([f.i for f in arr])
5.79 µs ± 21.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

nditer文档需要更强大的性能免责声明 。nditer当在ccython代码中使用时有用且快速,但当通过 Python 代码访问时,它不如更明显的替代方案。在某些情况下,额外的花里胡哨可能很有用,但大多数情况下,我认为它是正确编译代码的桥梁,而不是其本身的结束。

性能问题的核心Foo是 Python 类。因此访问i属性必须使用完整的 Python 引用系统。它不能使用任何快速编译的numpy数值方法。


推荐阅读