Test/training dataset for a graph network

Problem description

I have a graph network that I create as follows:

import networkx as nx
g = nx.read_edgelist(data, create_using=nx.Graph())

I am trying to create a training set and a test set for the data. I tried the following:

train, test = train_test_split(g, test_size=0.2) 

But this did not work. How should I create a training set and a test set when my data is a graph network?

Tags: python, graph, networkx

Solution


Depending on your task, you can do this with StellarGraph's EdgeSplitter class and scikit-learn's train_test_split function.
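The call in the question fails because train_test_split expects a sequence of samples (a list, array or DataFrame), not a networkx Graph object. A minimal sketch, assuming your samples are the graph's edges and using networkx's built-in karate_club_graph only as a stand-in for your own g, splits the edge list and rebuilds a training graph from the training edges. Unlike EdgeSplitter, this neither keeps the graph connected nor generates negative (non-existent) edges.

```python
import networkx as nx
from sklearn.model_selection import train_test_split

# Stand-in for the graph built with nx.read_edgelist(...) in the question
g = nx.karate_club_graph()

# train_test_split needs a sequence of samples, so pass the edge list, not the Graph
edges = list(g.edges())
train_edges, test_edges = train_test_split(edges, test_size=0.2, random_state=42)

# Training graph: keep every node, but only the training edges
g_train = nx.Graph()
g_train.add_nodes_from(g.nodes())
g_train.add_edges_from(train_edges)

print(len(edges), len(train_edges), len(test_edges))
```

Keeping all nodes in g_train means test edges can still be scored between nodes the model has seen.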

Node classification

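If your task is node classification, the StellarGraph example "Node classification with Graph Convolutional Network (GCN)" is a good illustration of how to load the data and do a train-test split. It uses the Cora dataset. The most important steps are: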

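import stellargraph as sg
from stellargraph.mapper import FullBatchNodeGenerator
from sklearn import model_selection, preprocessing
from IPython.display import display, HTML

dataset = sg.datasets.Cora()
display(HTML(dataset.description))
G, node_subjects = dataset.load()

# 140 training nodes, then 500 validation nodes; the rest form the test set
train_subjects, test_subjects = model_selection.train_test_split(
    node_subjects, train_size=140, test_size=None, stratify=node_subjects
)
val_subjects, test_subjects = model_selection.train_test_split(
    test_subjects, train_size=500, test_size=None, stratify=test_subjects
)

# One-hot encode the subjects (class labels) to use as targets
target_encoding = preprocessing.LabelBinarizer()
train_targets = target_encoding.fit_transform(train_subjects)
val_targets = target_encoding.transform(val_subjects)
test_targets = target_encoding.transform(test_subjects)

# Full-batch GCN generator over the whole graph; each flow computes outputs only for the given nodes
generator = FullBatchNodeGenerator(G, method="gcn")
train_gen = generator.flow(train_subjects.index, train_targets)
val_gen = generator.flow(val_subjects.index, val_targets)
test_gen = generator.flow(test_subjects.index, test_targets)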

Basically, this is the same as the train-test split for an ordinary classification task, except that here we split the nodes.
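The same idea carries over to a plain networkx graph with node labels. A minimal sketch, assuming the labels are stored as a node attribute (here the "club" attribute of networkx's built-in karate club graph, used only as a stand-in), splits the node IDs with a stratified train_test_split. The whole graph remains available to the model; only the set of nodes whose labels are used for training or evaluation differs.

```python
import networkx as nx
from sklearn.model_selection import train_test_split

g = nx.karate_club_graph()

# Node IDs and their class labels ("club" is specific to this example graph)
nodes = list(g.nodes())
labels = [g.nodes[n]["club"] for n in nodes]

# Stratified split of the nodes, analogous to splitting node_subjects for Cora
train_nodes, test_nodes, train_labels, test_labels = train_test_split(
    nodes, labels, test_size=0.2, stratify=labels, random_state=0
)

print(len(train_nodes), len(test_nodes))
```

Stratifying keeps the class proportions similar in both sets, which matters when some classes have few nodes.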

Edge classification

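If your task is edge classification, have a look at the StellarGraph link prediction example "GCN on the Cora citation dataset". The code most relevant to the train-test split is: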

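from stellargraph.data import EdgeSplitter
from stellargraph.mapper import FullBatchLinkGenerator

# Define an edge splitter on the original graph G (a StellarGraph, e.g. Cora loaded as above):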
edge_splitter_test = EdgeSplitter(G)

# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G, and obtain the
# reduced graph G_test with the sampled links removed:
G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(
    p=0.1, method="global", keep_connected=True
)

# Define an edge splitter on the reduced graph G_test:
edge_splitter_train = EdgeSplitter(G_test)

# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G_test, and obtain the
# reduced graph G_train with the sampled links removed:
G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(
    p=0.1, method="global", keep_connected=True
)

# For training we create a generator on the G_train graph, and make an 
# iterator over the training links using the generator’s flow() method:

train_gen = FullBatchLinkGenerator(G_train, method="gcn")
train_flow = train_gen.flow(edge_ids_train, edge_labels_train)
test_gen = FullBatchLinkGenerator(G_test, method="gcn")
test_flow = test_gen.flow(edge_ids_test, edge_labels_test)
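The comments above describe sampling positive links together with the same number of negative links, i.e. node pairs that are not connected. A minimal sketch of that idea in plain networkx follows; the helper name sample_negative_edges is hypothetical and not part of StellarGraph, and it assumes an undirected graph with sortable node IDs. It builds edge_ids and edge_labels arrays with label 1 for real edges and 0 for sampled non-edges, roughly mirroring what EdgeSplitter.train_test_split returns.

```python
import random

import networkx as nx
import numpy as np


def sample_negative_edges(g, num, seed=0):
    """Hypothetical helper: sample `num` distinct node pairs that are not edges of g."""
    rng = random.Random(seed)
    nodes = list(g.nodes())
    negatives = set()
    while len(negatives) < num:
        u, v = rng.sample(nodes, 2)  # two distinct nodes, so no self-loops
        if not g.has_edge(u, v):
            negatives.add((min(u, v), max(u, v)))  # canonical order for undirected pairs
    return list(negatives)


g = nx.karate_club_graph()
rng = random.Random(1)

# Positive examples: a random 10% of the existing edges
positives = rng.sample(list(g.edges()), k=int(0.1 * g.number_of_edges()))
negatives = sample_negative_edges(g, len(positives))

# Link-prediction samples: label 1 for real edges, 0 for sampled non-edges
edge_ids = np.array(positives + negatives)
edge_labels = np.array([1] * len(positives) + [0] * len(negatives))
print(edge_ids.shape, edge_labels.sum())
```

Negatives are checked against the full graph, so a held-out positive edge can never be sampled as a negative.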

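The splitting algorithm behind the EdgeSplitter class is more involved, because it has to preserve the graph's structure while splitting, for example by keeping the graph connected. For more details, see the EdgeSplitter source code.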
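To make the keep_connected idea concrete, here is a much simplified sketch, not StellarGraph's actual algorithm: it removes randomly chosen edges one at a time and puts back any edge whose removal would disconnect the graph, until a fraction p of the edges has been removed. The function name split_edges_keep_connected is hypothetical, and it assumes an undirected, connected networkx graph. The usage at the bottom repeats the nested split from the example above (G to G_test to G_train).

```python
import random

import networkx as nx


def split_edges_keep_connected(g, p=0.1, seed=0):
    """Hypothetical helper: remove a fraction p of the edges from a copy of g
    without disconnecting it. Returns (reduced_graph, removed_edges)."""
    if not nx.is_connected(g):
        raise ValueError("expected a connected undirected graph")
    rng = random.Random(seed)
    reduced = g.copy()
    target = int(p * g.number_of_edges())
    candidates = list(g.edges())
    rng.shuffle(candidates)

    removed = []
    for u, v in candidates:
        if len(removed) == target:
            break
        data = dict(reduced.edges[u, v])  # keep edge attributes in case we restore it
        reduced.remove_edge(u, v)
        # The graph stays connected iff u can still reach v without this edge
        if nx.has_path(reduced, u, v):
            removed.append((u, v))
        else:
            reduced.add_edge(u, v, **data)  # the edge was a bridge; put it back
    return reduced, removed


g = nx.karate_club_graph()
g_test, test_edges = split_edges_keep_connected(g, p=0.1, seed=0)
g_train, train_edges = split_edges_keep_connected(g_test, p=0.1, seed=1)
print(nx.is_connected(g_test), nx.is_connected(g_train))
```

A bridge stays a bridge as more edges are removed, so skipping it once is safe; this greedy loop only falls short of the target when the graph has been reduced to a spanning tree.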

