首页 > 解决方案 > 迭代pytorch中的张量

问题描述

我有两个一维张量。一个是预测向量,第二个是标签向量。我正在尝试编写一个循环来检查向量之间的元素差异。如果发现这样的差异,我想做另一个操作,为简单起见,假设我想打印(“发现差异”)。到目前为止,我想出了这个但我得到了一个错误:标量类型字节的预期对象但参数#2“其他”的标量类型浮点数。我会很感激这里的帮助。也许有一些更有效的方法可以做到这一点,没有循环。

for i in enumerate(t1):
    if t1[i] != t2[i]:
        print("Diff spotted")

标签: pythonpytorchtensor

解决方案


您可以使用eq()pytorch 中的函数来检查张量是否在元素方面是相同的。对于与标签元素相同的元素的每个索引,您将获得True

for label in predictions.round().eq(labels):
    for element in label:
        if element == False:
            print("Diff spotted!")

推荐阅读