首页 > 解决方案 > 用 nditer 替换 pytorch 张量的 numpy 数组

问题描述

我正在寻找与 numpy 数组的 nditer 类似的 pytorch 张量功能,请参阅此链接和一个小示例。 https://discuss.pytorch.org/t/replacement-of-np-nditer-for-torch/64024?u=songqsh

标签: pythonpytorch

解决方案


我同意这样的评论,即逐个元素地迭代对于与数组或张量进行交互是一个好主意。nditer有很多功能,这里只是对张量中所有元素的迭代,它返回元素的坐标以及元素本身:

def deep_iter(data, ix=tuple()):
    try:
        for i, element in enumerate(data):
            yield from deep_iter(element, ix + (i,))
    except:
        yield ix, data

因此,例如在 pytorch 论坛上,它可以按如下方式使用:

new_values = {}
for i, value in deep_iter(a):
    if all(map(lambda x: 0 < x < (a.shape[1] - 1), i)):
        new_values[i] = calc_average(i, a) #write func to calc average

for i, new_value in new_values.items():
    a[i] = new_value

推荐阅读