首页 > 解决方案 > 构造一个允许通过numpy数组调用方法的类

问题描述

我定义了一个类Tomato并创建了一个包含该类的多个对象的数组:

import numpy as np
class Tomato:
    color = None
    radius = None
    def __init__(self):
        color = np.random.choice(['red', 'green'])
        radius = np.random.rand()

arr1 = np.array([Tomato() for i in range(6)]).reshape(3, 2)
arr1

产量

array([[<__main__.Tomato object at 0x000002479E7D97F0>,
        <__main__.Tomato object at 0x000002479E7D9710>],
       [<__main__.Tomato object at 0x000002479E7D94E0>,
        <__main__.Tomato object at 0x000002479E7D9DA0>],
       [<__main__.Tomato object at 0x000002479E710630>,
        <__main__.Tomato object at 0x000002479E7D9C18>]], dtype=object)

我希望能够打电话

arr1.radius

并获得一个仅包含radius每个番茄的 3x2 数组。我知道我可以使用np.vectorize()lambda 表达式,正如在提问者正在处理来自外部导入类的对象的问题中所建议的那样。cftime

但我相信我应该有更多的选择,因为我Tomato自己定义了这个类。

例如,complex128数据类型有方法.real.imag,复数浮点数组也有。

arr2 = np.random.normal(size=(3, 2)) + 1j * np.random.normal(size=(3, 2))
arr2.imag

为您提供每个条目的虚部:

array([[-0.23054982,  0.04599812],
       [-0.07459619, -0.11282513],
       [-0.32441139,  0.8920348 ]])

有没有办法修改Tomato的类定义以允许用户通过 numpy 数组访问其属性?

如果不是,上面的例子是如何arr2工作的?numpy数组类的代码中是否手动指定了.realand方法?.imag

标签: pythonnumpy

解决方案


您可以为您提供的用例(即切片和初始化)定义自己的类:

import numpy as np
class TomatoArr:
    def __init__(self, col, r):
        self.col = col if isinstance(col, np.ndarray) else np.array(col, dtype='<U5')
        self.r = r if isinstance(r, np.ndarray) else np.array(r, dtype=float)
    
    def __getitem__(self, idx):
        return TomatoArr(self.col[idx], self.r[idx])
   
    @classmethod
    def from_list(cls, tlist):
        n = len(tlist)
        col = np.array([a.col for a in tlist], dtype='<U5')
        r = np.array([a.r for a in tlist], dtype=float)
        return cls(col, r)

利用:

In [14]: tom = TomatoArr([['red', 'green'], ['green', 'red']], [[1.0, 1.2], [0.9, 1.1]])

In [15]: tom.col
Out[15]: 
array([['red', 'green'],
       ['green', 'red']], dtype='<U5')

In [16]: tom[:, 1].r
Out[16]: array([1.2, 1.1])

In [17]: tom[:, 0].r += 100

In [18]: tom.r
Out[18]: 
array([[101. ,   1.2],
       [100.9,   1.1]])

In [19]: tom2 = TomatoArr.from_list([TomatoArr('red', 1.3), TomatoArr('red', 1.4)])

In [20]: tom2.r
Out[20]: array([1.3, 1.4])

In [21]: tom2.col
Out[21]: array(['red', 'red'], dtype='<U5')

当然,其他 numpy 操作不起作用并且无论如何都会有意义 - 和的总和是'red'多少'green'

注意:不起作用的是布尔索引分配:

# no effect
tom[tom.col=='red'].r = 10

您可以添加一个方法__setitem__,但这仅适用于类似

tom[tom.col=='red'] = TomatoArr('red', 10)

推荐阅读