python - 在 narray 子类中使用 super().take 包装 numpy 切片
问题描述
我正在尝试对一个 numpy 数组进行子类化,以创建具有周期性边界条件的数组。我只需要它来处理一维数组。
我想要什么:假设我有一个对象a=Periodic_Grid([0,1,2,3,4,5,6,7,8,9])
。我想要以下输出
>>> a=Periodic_Grid(a)
>>> a
Periodic_Grid([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> a[0]
0
>>> a[7]
7
>>> a[14]
4
>>> a[-5]
5
>>> a[-14]
6
>>> a[1:5]
array([1., 2., 3., 4.])
>>> a[8:13]
array([8., 9., 0., 1., 2.])
>>> a[-5:2]
array([5., 6., 7., 8., 9., 0., 1.])
换句话说,索引只是环绕。切片也应该适当地环绕。否则,它会正常工作。
最小工作代码
class Periodic_Grid(np.ndarray):
def __new__(cls, input_array):
obj = np.asarray(input_array).view(cls)
obj.grid_size = input_array.size
return obj
def __getitem__(self, index):
if isinstance(index, slice):
indicies=list(range(index.start, index.stop))
size=len(indicies)
a=np.empty(size)
print("Size: {0} indicies: {1}".format(size, indicies))
for i, ind in enumerate(indicies):
idx=self.wrapped_index(ind)
a[i]=self[idx]
return a
elif isinstance(index, tuple):
idx=index[0]
idx=self.wrapped_index(idx)
return super().__getitem__(idx)
elif isinstance(index, int):
index = self.wrapped_index(index)
return super().__getitem__(index)
else:
print("RAISING EXCEPTION")
raise ValueError
def __setitem__(self, index, item):
index = self.wrapped_index(index)
return super().__setitem__(index, item)
def __array_finalize__(self, obj):
if obj is None: return
self.grid_size = getattr(obj, 'grid_size', obj.size)
pass
def wrapped_index(self, index):
idx=None
#if hasattr(index, '__iter__'): idx=
if type(index) is tuple:
idx=index[0] #only support for 1D!
elif type(index) is int:
idx=index
else:
raise ValueError('Unexpected index: {}'.format(index))
我的问题
上面的代码实际上运行良好,所以从这个意义上说,我的问题解决了。但是,我有一个关于如何在__getitem__
. 在上面的代码中,我使用了一个循环并手动循环需要的索引并填充适当的值。这发生在 Python 循环中,而我想尽可能让numpy
处理这个循环。我认为这应该可以使用 numpy.take 像这样(仅查看__getitem__
处理切片对象的块:
def __getitem__(self, index):
if isinstance(index, slice):
start=index.start
stop=index.stop
indices=list(range(start, stop))
return super().take(self, indices, mode='wrap')
但是,现在我遇到了一个错误:
>>> b[1:5]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "path/to/class/file", line 130, in __getitem__
return super().take(self, indices, mode='wrap')
TypeError: 'list' object cannot be interpreted as an integer
然而,据我所知,通过文档take
中的第一个示例判断,允许传递索引列表。我怀疑我用错了,但我不确定。任何帮助在这里表示赞赏。super()
解决方案
推荐阅读
- node.js - 当 xml 模式验证失败时,Nodejs libxmljs 使 docker 容器崩溃
- ruby - 为什么我不能将类错误与链一起进一步发送?
- javascript - 从 javascript 对象列表中查找最多前两次出现的元素
- powershell - 使用 Powershell 过滤文件扩展名
- angular - Angular:“Observable”类型上不存在属性“then”
- jenkins - java.io.IOException:找不到 allure 命令行
- java - 在TestNG中参数化@BeforeMethod方法
- firebase - Firebase 动态链接的 AppLink 配置
- c# - 在C#中检索firebase子根的所有项目
- python - 与特殊情况 Sqlalchamy 的一对一关系