python - 迭代pytorch中的张量
问题描述
我有两个一维张量。一个是预测向量,第二个是标签向量。我正在尝试编写一个循环来检查向量之间的元素差异。如果发现这样的差异,我想做另一个操作,为简单起见,假设我想打印(“发现差异”)。到目前为止,我想出了这个但我得到了一个错误:标量类型字节的预期对象但参数#2“其他”的标量类型浮点数。我会很感激这里的帮助。也许有一些更有效的方法可以做到这一点,没有循环。
for i in enumerate(t1):
if t1[i] != t2[i]:
print("Diff spotted")
解决方案
您可以使用eq()
pytorch 中的函数来检查张量是否在元素方面是相同的。对于与标签元素相同的元素的每个索引,您将获得True
:
for label in predictions.round().eq(labels):
for element in label:
if element == False:
print("Diff spotted!")