首页 > 解决方案 > Pytorch 几何图分类:AttributeError:'Batch' 对象没有属性'local_var'

问题描述

我目前正在使用深度学习,特别是 pytorch 几何环境对 IMDB-Binary 数据集进行图形分类。

我已将数据拆分为测试/训练样本,这些样本是包含图形及其标签的元组列表。我必须做的一件事是使用 torch_geometric.data.Batch 将不同的图视为“Batch”,一个大的断开连接的图。首先,我使用具有以下整理功能的数据加载器

def collate(samples) :
  graphs,labels = map(list,zip(*samples))
  datalist = make_datalist(graphs)
  datalist = Batch.from_data_list(datalist)
  return datalist, torch.tensor(labels)

我的分类器如下:

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.in_degrees
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

它只是平均每个图的节点表示,并将其提供给 MLP

我想出的问题是,在我们批次的预测过程中,我有错误 AttributeError: 'Batch' object has no attribute 'local_var' 我找不到它可能来自哪里,有人知道吗?

感谢您花时间阅读!

标签: graphpytorch

解决方案


我还在试验 Pytorch 几何及其数据集功能。

也许以下信息将来会对某人有所帮助:

当我忘记为我的数据集类属性设置 @property 带注释的 getter/setter 时,我正面临 AttributeErrors。请参阅https://docs.python.org/3.7/library/functions.html#property

我认为要回答您的问题,我们需要有关您的make_datalist功能的更多信息。

但是,这里是批处理类的链接:

事实上,没有什么比local_var变量更重要的了。


推荐阅读