首页 > 解决方案 > NumPy dtype 强制结构化数组扩展类型

问题描述

我将 specialnp.dtypes用于具有如下结构的程序:

POINT = np.dtype([('vertices', '<f4', 2), ('center', '<f4', 2), ('bbox', '<f4', 4)])

我需要指定另一个np.dtype只使用上面最后一个字段的字段,如下所示:

MBR = np.dtype([('bbox', '<f4', 4)])

这样我以后可以像这样访问两个数组的该字段:

def intersection(s, t):

    sxmin, sxmax, symin, sxmax = s['bbox']
    txmin, txmax, tymin, tymax = t['bbox']

    # do stuff

但是,当我创建以下数组时,它正在被扩展,我不确定为什么:

box = np.array([1, 2, 3, 4], dtype=MBR)
# expected output...
array([1., 2., 3., 4.], dtype=[('bbox', '<f4', 4)])
# actual output...
array([([1., 1., 1., 1.],), ..., ([4., 4., 4., 4.],)], dtype=[('bbox', '<f4', 4)])

快速测试返回了我的预期......

np.empty([], dtype=MBR)
array(([nan, nan, inf, nan],), dtype=[('bbox', '<f4', 4)])

编辑:

执行以下操作会返回预期结果:

box = np.array(([1, 2, 3, 4],), dtype=MBR)

所以现在的问题是:为什么我必须将它包装在一个元组中以符合 dtype?

标签: pythonnumpy

解决方案


简短的回答是,带有嵌套列表和元组的输入格式必须与显示格式相匹配:

In [59]: MBR = np.dtype([('bbox', '<f4', 4)])                                                    
In [60]: arr = np.zeros(3, dtype=MBR)                                                            
In [61]: arr                                                                                     
Out[61]: 
array([([0., 0., 0., 0.],), ([0., 0., 0., 0.],), ([0., 0., 0., 0.],)],
      dtype=[('bbox', '<f4', (4,))])
In [62]: arr[0]                                                                                  
Out[62]: ([0., 0., 0., 0.],)
In [63]: arr[0]=[1,2,3,4]                                                                        
In [64]: arr[1]=[10,11,12,13]                                                                    
In [65]: arr                                                                                     
Out[65]: 
array([([ 1.,  2.,  3.,  4.],), ([10., 11., 12., 13.],),
       ([ 0.,  0.,  0.,  0.],)], dtype=[('bbox', '<f4', (4,))])
In [66]: np.array([([1,2,3,4],)],MBR)                                                            
Out[66]: array([([1., 2., 3., 4.],)], dtype=[('bbox', '<f4', (4,))])

因此,对于典型的复合 dtype,我们说输入应该是一个元组列表,数组的每个“记录”一个元组。在元组中,每个字段一个元素。

您在字段中增加了大小 (4,) 维度的复杂性。

请注意,从数组中提取的字段形状结合了外部数组形状和内部字段形状:

In [67]: arr['bbox']                                                                             
Out[67]: 
array([[ 1.,  2.,  3.,  4.],
       [10., 11., 12., 13.],
       [ 0.,  0.,  0.,  0.]], dtype=float32)

通常,按字段而不是按记录为结构化数组分配值更容易:

In [68]: arr['bbox']=np.arange(12).reshape(3,4)                                                  
In [69]: arr                                                                                     
Out[69]: 
array([([ 0.,  1.,  2.,  3.],), ([ 4.,  5.,  6.,  7.],),
       ([ 8.,  9., 10., 11.],)], dtype=[('bbox', '<f4', (4,))])

推荐阅读