首页 > 解决方案 > 在python中将coo格式转换为边权重

问题描述

我想制作标签数组,以便使用 GNN 输出中的 BCE 计算损失。我有的

 import torch
 edge_index = [[0, 0, 0, 2, 2, 2, 3, 4, 5], [2, 3, 4, 3, 4, 5, 4, 5, 6]]
 edge_index_gt = [[0, 2, 3, 4, 5], [2, 3, 4, 5, 6]]
 out_put = torch.rand(edge_index.shape[1]

边缘索引指示一个节点是否连接到另一个节点。例如在这种情况下,节点 0 将连接到节点 2、3 和 4。edge_index 和 edge_index_gt 都具有形状 [2, number of edges]。我想让输出像这样

label = [1, 0, 0, 1, 0, 0, 1, 1, 1]

形状等于 edge_index 数组中的边数。

我试过的

label = torch.zeros_like(out_put)
for i, val in enumerate(zip(edge_index[0], edge_index[1])):
  if val in zip(edge_index_gt[0], edge_index_gt[1]):
     label[i] = 1

如您所见,此代码采用 O(N*M) 其中 N 是 edge_index 数组中的元素数, M 是 edge_index_gt 中的元素数

有什么办法可以让事情变得更好吗?谢谢

标签: pythontorch

解决方案


推荐阅读