graph - Pytorch 几何图分类:AttributeError:'Batch' 对象没有属性'local_var'
问题描述
我目前正在使用深度学习,特别是 pytorch 几何环境对 IMDB-Binary 数据集进行图形分类。
我已将数据拆分为测试/训练样本,这些样本是包含图形及其标签的元组列表。我必须做的一件事是使用 torch_geometric.data.Batch 将不同的图视为“Batch”,一个大的断开连接的图。首先,我使用具有以下整理功能的数据加载器
def collate(samples) :
graphs,labels = map(list,zip(*samples))
datalist = make_datalist(graphs)
datalist = Batch.from_data_list(datalist)
return datalist, torch.tensor(labels)
我的分类器如下:
class Classifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes):
super(Classifier, self).__init__()
self.conv1 = GraphConv(in_dim, hidden_dim)
self.conv2 = GraphConv(hidden_dim, hidden_dim)
self.classify = nn.Linear(hidden_dim, n_classes)
def forward(self, g):
# Use node degree as the initial node feature. For undirected graphs, the in-degree
# is the same as the out_degree.
h = g.in_degrees
# Perform graph convolution and activation function.
h = F.relu(self.conv1(g, h))
h = F.relu(self.conv2(g, h))
g.ndata['h'] = h
# Calculate graph representation by averaging all the node representations.
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
它只是平均每个图的节点表示,并将其提供给 MLP
我想出的问题是,在我们批次的预测过程中,我有错误 AttributeError: 'Batch' object has no attribute 'local_var' 我找不到它可能来自哪里,有人知道吗?
感谢您花时间阅读!
解决方案
我还在试验 Pytorch 几何及其数据集功能。
也许以下信息将来会对某人有所帮助:
当我忘记为我的数据集类属性设置 @property 带注释的 getter/setter 时,我正面临 AttributeErrors。请参阅https://docs.python.org/3.7/library/functions.html#property
我认为要回答您的问题,我们需要有关您的make_datalist
功能的更多信息。
但是,这里是批处理类的链接:
- https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/batch.html#Batch
事实上,没有什么比local_var
变量更重要的了。
推荐阅读
- excel - 去掉 MIN 函数中的 0 值
- html - CSS/Sass grid gaps not responsive
- javascript - 节点“保存”主密码 5 分钟
- httpresponse - 即使客户端不关心响应,是否也需要响应和状态码?
- css - 用 CSS 覆盖 Childs 属性
- python - 如何在 matplotlib 中添加额外的 y 轴标签
- r - 有没有办法在 Rmarkdown 中指定全局选项块函数,以便通过切换命令仅显示/不显示某些图形?
- javascript - 如何在任何孩子上使用相同的谷歌地图组件实例?
- ruby-on-rails - 为包含查询数据库的回调的 Rails 模型创建 FactoryGirl 工厂
- laravel - Laravel/React .HTACCESS 和 .HTPASSWD 不起作用