首页 > 解决方案 > 在numpy ndarray的子​​类中覆盖.T(转置)

问题描述

我有一个三维数据集,其中第一个维度给出变量的类型,第二个和第三个维度是空间索引。我试图通过创建一个ndarray包含数据的子类来使这些数据对用户更加友好,但属性具有指向适当变量维度的合理名称。变量类型之一是温度,我想用属性来表示.T。我尝试这样设置:

self.T = self[8,:,:]

但是,这与用于转置数组的底层 numpy 属性相冲突。通常,覆盖类属性是微不足道的,但是在这种情况下,当我尝试重写属性时会出现异常。以下是同一问题的最小示例:

import numpy as np

class foo(np.ndarray):
    def __new__(cls, input_array):
        obj = np.asarray(input_array).view(cls)
        obj.T = 100.0
        return obj

foo([1,2,3,4])

结果是:

Traceback (most recent call last):
  File "tmp.py", line 9, in <module>
    foo([1,2,3,4])
  File "tmp.py", line 6, in __new__
    obj.T = 100.0
AttributeError: attribute 'T' of 'numpy.ndarray' objects is not writable

我曾尝试使用setattr(obj, 'T', 100.0)设置属性,但结果是一样的。

显然,我可以放弃并命名我的属性.temperature或其他名称。然而.T,对于将使用这些数据对象完成的后续数学表达式来说,这将更加雄辩。如何强制 python/numpy 覆盖此属性?

标签: pythonnumpysubclassing

解决方案


对于np.matrix子类,在 np.matrixlib.defmatrix 中定义:

@property
def T(self):
    """
    Returns the transpose of the matrix.
    ....
    """
    return self.transpose()

推荐阅读