首页 > 解决方案 > 如何在 PyTorch 的图中返回具有相同度数的特定节点?

问题描述

我有一个函数,它需要一个图表并找到每个单独节点的度数,然后计算具有相同度数值的节点数:

import torch
from torch_geometric.utils import degree

def fun(self, graph):
        
        n = graph.num_nodes
        d = degree(graph.edge_index[1], n, dtype=torch.long)
        counts = torch.bincount(d)

        return counts

上述功能工作正常。但我希望它只找到度数小于 50 的节点,然后返回具有相同度数值的节点数(度数 < 50),所以我将代码更改为以下代码:

def fun(self, graph):
        
        n = graph.num_nodes
        deg = degree(graph.edge_index[1], n, dtype=torch.long)
        less_than_fifty = [i if i < 50 else 0 for i in deg]
        counts = torch.bincount(less_than_fifty)

        return counts

运行后,出现以下错误:

TypeError: bincount(): argument 'input' (position 1) must be Tensor, not list

因此,我使用了张量而不是列表,如下所示:

def fun(self, graph):
        
        n = graph.num_nodes
        d = degree(graph.edge_index[1], n, dtype=torch.long)
        less_than_fifty = torch.tensor([i if i < 50 else 0 for i in deg])
        counts = torch.bincount(less_than_fifty)

        return counts

但这一次又出现了一个问题。我在 Google Colab 上运行了代码,由于最后一次修改(将列表转换为张量),Colab 一直在崩溃。我在 Colab 上使用 GPU。我确信崩溃的原因是这条线less_than_fifty = torch.tensor([i if i < 50 else 0 for i in deg]),因为每当我删除它时,就再也没有崩溃了。我的问题是如何解决这些问题?

标签: pythonpython-3.xpytorch

解决方案


推荐阅读