首页 > 解决方案 > 将属性添加到数据集的对象

问题描述

我对 pytorch 和 pytorch-geometric 非常陌生。我需要加载一个数据集,然后将一个属性映射到集合中稍后将在脚本中使用的每个对象。但是我不知道该怎么做。

我开始加载

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='PROTEINS')

然后我添加属性。我试过(值 3 仅作为示例,它将是一个 db 查询)

for data in dataset:
    data.keys.append('szemeredi_id')
    data.szemeredi_id = 3

或者

for data in dataset:
    data['szemeredi_id'] = 3

或者

for i, s in enumerate(dataset):
    dataset[i]['szemeredi_id'] = 3

或者

for data in dataset:
    setattr(data, 'szemeredi_id', 3)

但该属性始终为空。我什至尝试为 Data 类编写一个装饰器类

class SzeData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, normal=None, face=None, **kwargs):
        super(SzeData, self).__init__(x, edge_index, edge_attr, y, pos, normal, face)
        self.szemeredi_id = None

但是如果我尝试替换 Data 对象,它会引发错误,或者如果我使用这个解决方案TypeError: 'TUDataset' object does not support item assignment它什么也不做。

任何建议都非常感谢。谢谢你。

标签: pythonpython-3.xpytorchdatasetpytorch-geometric

解决方案


推荐阅读