Test/train dataset for a graph network



import networkx as nx
from sklearn.model_selection import train_test_split

g = nx.read_edgelist(data, create_using=nx.Graph())

train, test = train_test_split(g, test_size=0.2)


Tags: python, graph, networkx


Depending on your task, you can do this with StellarGraph's EdgeSplitter class (docs) and scikit-learn's train_test_split function (docs).


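If your task is node classification, the "Node classification with Graph Convolutional Network (GCN)" demo is a good example of how to load the data and do a train-test split. It uses the Cora dataset. The most important steps are: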

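import stellargraph as sg
from sklearn import model_selection, preprocessing
from stellargraph.mapper import FullBatchNodeGenerator

dataset = sg.datasets.Cora()
G, node_subjects = dataset.load()
# Stratified split: 140 training nodes, 500 validation nodes, the rest for testing
train_subjects, test_subjects = model_selection.train_test_split(
    node_subjects, train_size=140, test_size=None, stratify=node_subjects
)
val_subjects, test_subjects = model_selection.train_test_split(
    test_subjects, train_size=500, test_size=None, stratify=test_subjects
)

target_encoding = preprocessing.LabelBinarizer()  # one-hot encode the subjects
train_targets = target_encoding.fit_transform(train_subjects)
val_targets = target_encoding.transform(val_subjects)
test_targets = target_encoding.transform(test_subjects)
generator = FullBatchNodeGenerator(G, method="gcn")  # one generator, three flows
train_gen = generator.flow(train_subjects.index, train_targets)
val_gen = generator.flow(val_subjects.index, val_targets)
test_gen = generator.flow(test_subjects.index, test_targets)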

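Basically, this is the same as the train-test split for an ordinary classification task, except that what gets split here are the nodes (their IDs and subject labels), not the graph object itself.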
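The node split described above can be sketched without StellarGraph or scikit-learn, which makes it clear that only node IDs and labels are being partitioned. This is a minimal pure-Python sketch assuming the labels are a dict mapping node ID to class; `stratified_split` is a hypothetical helper, not part of either library. It imitates the effect of scikit-learn's `stratify=` by giving each class a share of `train_size` proportional to its size (largest-remainder rounding), but it is not scikit-learn's implementation.

```python
import random
from collections import defaultdict


def stratified_split(labels, train_size, seed=0):
    """Hypothetical helper: split the keys of `labels` (node id -> class) into
    (train, rest), with exactly `train_size` train nodes spread over the classes
    in proportion to their size."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for node, label in labels.items():
        by_class[label].append(node)

    total = len(labels)
    classes = sorted(by_class)
    # Integer largest-remainder allocation so the quotas sum exactly to train_size.
    quota = {c: train_size * len(by_class[c]) // total for c in classes}
    remainder = {c: train_size * len(by_class[c]) % total for c in classes}
    leftover = train_size - sum(quota.values())
    for c in sorted(classes, key=lambda k: remainder[k], reverse=True)[:leftover]:
        quota[c] += 1

    train = []
    for c in classes:
        nodes = list(by_class[c])
        rng.shuffle(nodes)
        train.extend(nodes[: quota[c]])
    train_set = set(train)
    rest = [n for n in labels if n not in train_set]
    return train, rest


# Toy labels (assumption): 60 nodes of class "a", 40 of class "b".
labels = {i: "a" if i < 60 else "b" for i in range(100)}
train, rest = stratified_split(labels, train_size=10)
# Split the remainder again, as the Cora example does for validation/test.
val, test = stratified_split({n: labels[n] for n in rest}, train_size=45)
print(len(train), len(val), len(test))  # 10 45 45
```

As in the Cora example, the second call splits what was left over from the first, so train, validation and test are disjoint and together cover every node.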


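If your task is edge classification (link prediction), have a look at this link prediction example: GCN on the Cora citation dataset. The code most relevant to the train-test split is: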

from stellargraph.data import EdgeSplitter
from stellargraph.mapper import FullBatchLinkGenerator

# G is the original StellarGraph graph, e.g. G, _ = sg.datasets.Cora().load()
# Define an edge splitter on the original graph G:
edge_splitter_test = EdgeSplitter(G)

# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G,
# and obtain the reduced graph G_test with the sampled links removed:
G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(
    p=0.1, method="global", keep_connected=True
)

# Define an edge splitter on the reduced graph G_test:
edge_splitter_train = EdgeSplitter(G_test)

# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G_test,
# and obtain the reduced graph G_train with the sampled links removed:
G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(
    p=0.1, method="global", keep_connected=True
)

# For training we create a generator on the G_train graph, and make an
# iterator over the training links using the generator's flow() method:
train_gen = FullBatchLinkGenerator(G_train, method="gcn")
train_flow = train_gen.flow(edge_ids_train, edge_labels_train)

# For testing, use the generator built on G_test (the test links are absent from G_test):
test_gen = FullBatchLinkGenerator(G_test, method="gcn")
test_flow = test_gen.flow(edge_ids_test, edge_labels_test)
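The two-stage scheme in the code above (hold out test links from `G`, then hold out training links from the already reduced `G_test`) can be illustrated on a toy edge set in plain Python. This is a hedged sketch: `split_edges` is a hypothetical helper, not StellarGraph's API; it samples positive links uniformly at random and ignores `keep_connected`. As a design choice of this sketch only, negatives are drawn from node pairs that are not edges of the original graph, so a held-out test link can never show up as a training negative.

```python
import itertools
import random


def split_edges(nodes, edges, p, rng, known_edges=None):
    """Hypothetical helper: remove int(p * |edges|) random edges (positives) and
    sample the same number of non-edges (negatives).

    Edges are (u, v) tuples with u < v. Negatives avoid every pair in
    `known_edges` (defaults to `edges`). Returns (reduced_edges, pairs, labels).
    """
    edges = set(edges)
    known = set(known_edges) if known_edges is not None else edges
    n_pos = int(p * len(edges))
    positives = rng.sample(sorted(edges), n_pos)
    non_edges = [pair for pair in itertools.combinations(sorted(nodes), 2)
                 if pair not in known]
    negatives = rng.sample(non_edges, n_pos)
    reduced = edges - set(positives)
    return reduced, positives + negatives, [1] * n_pos + [0] * n_pos


rng = random.Random(42)
nodes = range(20)
# Toy graph (assumption): a 20-node ring plus two chords, 22 edges in total.
G = {(i, i + 1) for i in range(19)} | {(0, 19), (0, 10), (5, 15)}

# Stage 1: hold out test links from G.
G_test, edge_ids_test, edge_labels_test = split_edges(nodes, G, 0.1, rng)
# Stage 2: hold out training links from the already reduced G_test.
G_train, edge_ids_train, edge_labels_train = split_edges(
    nodes, G_test, 0.1, rng, known_edges=G
)
print(len(G), len(G_test), len(G_train))  # 22 20 18
```

The nesting `G_train ⊂ G_test ⊂ G` is what lets the model be trained on `G_train` and evaluated on links it never saw in its input graph.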

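The splitting algorithm behind the EdgeSplitter class (docs) used here is more involved, because it has to preserve the structure of the graph while splitting, for example by keeping the graph connected. For more details, see the source code of EdgeSplitter.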
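To make the connectivity constraint concrete, here is one simple way to remove links without disconnecting the graph: build a spanning forest and only sample from the edges outside it, since the forest alone already keeps every component connected. This is a pure-Python sketch with hypothetical helper names (`spanning_forest`, `is_connected`, `remove_edges_keep_connected`); it is not claimed to be what EdgeSplitter does internally.

```python
import random
from collections import defaultdict, deque


def spanning_forest(nodes, edges):
    """Edges of a BFS spanning forest, each as (u, v) with u < v."""
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen, forest = set(), set()
    for root in nodes:
        if root in seen:
            continue
        seen.add(root)
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    forest.add((min(u, v), max(u, v)))
                    queue.append(v)
    return forest


def is_connected(nodes, edges):
    """A spanning forest has len(nodes) - 1 edges exactly when there is one component."""
    nodes = list(nodes)
    return len(spanning_forest(nodes, edges)) == len(nodes) - 1


def remove_edges_keep_connected(nodes, edges, p, rng):
    """Hypothetical helper: remove int(p * |edges|) edges chosen only among
    edges outside a spanning forest, so no component gets split."""
    edges = set(edges)
    removable = sorted(edges - spanning_forest(nodes, edges))
    n_remove = int(p * len(edges))
    if n_remove > len(removable):
        raise ValueError("cannot remove that many edges without disconnecting the graph")
    removed = rng.sample(removable, n_remove)
    return edges - set(removed), removed


nodes = list(range(20))
# Toy graph (assumption): a 20-node ring plus two chords, 22 edges in total.
G = {(i, i + 1) for i in range(19)} | {(0, 19), (0, 10), (5, 15)}
G_reduced, removed = remove_edges_keep_connected(nodes, G, 0.1, random.Random(0))
print(len(G_reduced), is_connected(nodes, G_reduced))  # 20 True
```

Fixing a spanning forest up front means connectivity never has to be re-checked after each removal. The trade-off is that forest edges are never sampled, so the held-out links are not a uniform sample of all edges, and a sparse graph (few edges beyond a tree) limits how large `p` can be.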
