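python - Creating embeddings with StellarGraph is not reproducible
Problem description
I'm using StellarGraph (a great graph neural network package) and am trying to create embeddings for a particular graph/feature set. Unfortunately, the embeddings come out different every time the graph is created and trained, even though exactly the same information is supplied each time.
Is this a bug, or am I using StellarGraph incorrectly?
Below is code that demonstrates the problem: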
import networkx as nx
import random
import numpy as np
import pandas as pd
import keras
import stellargraph as sg
from stellargraph.mapper import GraphSAGELinkGenerator, GraphSAGENodeGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.data import UnsupervisedSampler
# Establish random seed
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
# Create a graph from well-known karate club data
print(f"Creating graph")
graph = nx.karate_club_graph()
# Create features for each node
print(f"Creating features")
features = []
nodes = list(graph.nodes)
columns = ["c-" + str(x) for x in range(10)]
nodes.sort()
for node in nodes:
    f = {c: random.random() for c in columns}
    features.append(f)
features_df = pd.DataFrame(features)
print(f"features_df: \n{features_df}")
for i in range(2):
    print(f"----- Iteration: {i} -----")
    # Create the model and generators
    print(f"Creating the model and generators")
    Gs = sg.StellarGraph(graph, node_features=features_df)
    unsupervisedSamples = UnsupervisedSampler(Gs, nodes=graph.nodes(), length=5, number_of_walks=3, seed=RANDOM_SEED)
    train_gen = GraphSAGELinkGenerator(Gs, 50, [5, 5], seed=RANDOM_SEED).flow(unsupervisedSamples)
    graphsage = GraphSAGE(layer_sizes=[100, 100], generator=train_gen, bias=True, dropout=0.0, normalize="l2")
    x_inp_src, x_out_src = graphsage.node_model(flatten_output=False)
    x_inp_dst, x_out_dst = graphsage.node_model(flatten_output=False)
    x_inp = [x for ab in zip(x_inp_src, x_inp_dst) for x in ab]
    x_out = [x_out_src, x_out_dst]
    edge_embedding_method = "l2"
    prediction = link_classification(output_dim=1, output_act="sigmoid", edge_embedding_method=edge_embedding_method)(x_out)
    # Create and train the Keras model
    model = keras.Model(inputs=x_inp, outputs=prediction)
    learning_rate = 1e-2
    model.compile(
        optimizer=keras.optimizers.Adam(lr=learning_rate),
        loss=keras.losses.binary_crossentropy,
        metrics=[keras.metrics.binary_accuracy])
    _ = model.fit_generator(train_gen, epochs=5, verbose=2, use_multiprocessing=False, workers=1, shuffle=False)
    # Create the embeddings
    print(f"Creating the embeddings")
    nodes = list(graph.nodes)
    nodes.sort()
    print(f"Nodes: {nodes}")
    # Create a generator that serves up nodes for use in embedding prediction / creation
    node_gen = GraphSAGENodeGenerator(Gs, 50, [5, 5], seed=RANDOM_SEED).flow(nodes)
    embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)
    embeddings = embedding_model.predict_generator(node_gen, workers=4, verbose=1)
    embeddings = embeddings[:, 0, :]
    np.set_printoptions(threshold=10)
    print(f"embeddings: {embeddings.shape} \n{embeddings}")
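The code contains a number of debug (print) statements. Sample output is shown below. Note that the embeddings differ between iterations even though the inputs, graph configuration, model configuration and random seed values are the same.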
----- Iteration: 0 -----
:
:
1/1 [==============================] - 0s 58ms/step
embeddings: (34, 100)
[[-0.10566715 0.02253576 -0.18743701 ... -0.1028127 0.03689012
-0.02482301]
[-0.03171733 0.01606975 -0.08616363 ... -0.11775644 0.0429472
-0.02371055]
[-0.05802531 0.03910012 -0.10229243 ... -0.15050544 0.06637941
-0.01950052]
...
[ 0.03011296 0.08852117 -0.01836969 ... -0.154132 0.03844732
-0.08643046]
[ 0.01052345 -0.0123206 0.08913474 ... -0.11741614 0.03202919
-0.04432516]
[ 0.01951274 0.06263477 0.07959272 ... -0.10350229 0.05735112
-0.0368157 ]]
:
:
----- Iteration: 1 -----
embeddings: (34, 100)
[[ 0.11182436 -0.02642134 0.01168384 ... 0.10322241 -0.01680471
-0.03918815]
[ 0.02391489 0.02674667 -0.00091334 ... 0.12946768 -0.02389602
-0.01414653]
[ 0.08718258 -0.01711811 -0.05704292 ... 0.13477756 -0.00658288
-0.05889895]
...
[ 0.06843725 -0.13134597 -0.10870655 ... 0.11091235 -0.05146989
-0.06138216]
[-0.00593233 -0.05901312 -0.02113489 ... -0.01590953 -0.02516254
-0.02280537]
[ 0.00871993 -0.04059998 -0.07237951 ... -0.01590569 -0.00954109
-0.01116194]]
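Because the printed arrays are truncated with `...`, comparing runs by eye is unreliable. A minimal sketch for checking reproducibility programmatically, assuming the `embeddings` array from each iteration is kept (for example appended to a list); the helper name `embeddings_match` and the tolerance value are hypothetical choices for illustration:

```python
import numpy as np


def embeddings_match(a: np.ndarray, b: np.ndarray, atol: float = 1e-6) -> bool:
    """Return True if two embedding matrices have the same shape and
    agree element-wise within an absolute tolerance."""
    # A different shape means a different node set or layer size.
    if a.shape != b.shape:
        return False
    # A small absolute tolerance absorbs tiny floating-point noise
    # while still flagging genuinely different embeddings.
    return bool(np.allclose(a, b, rtol=0.0, atol=atol))


# Stand-in data with the question's shape (34 nodes x 100 dimensions).
rng = np.random.default_rng(0)
run_0 = rng.normal(size=(34, 100))
run_1 = run_0.copy()                 # an exact repeat of run_0
run_2 = rng.normal(size=(34, 100))   # an unrelated matrix
print(embeddings_match(run_0, run_1))
print(embeddings_match(run_0, run_2))
```

The first call prints `True` and the second `False`; applied to the two iterations above, it turns the visual comparison into a single yes/no answer.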
Solution
This was previously a bug in StellarGraph and has been fixed in v0.9.0: https://github.com/stellargraph/stellargraph/releases/tag/v0.9.0
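Unsupervised GraphSAGE has now been updated and tested for reproducibility. As long as all seeds are set, running the same pipeline should produce the same embeddings.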
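"All seeds" matters because each library keeps its own random state. The question's code calls only `random.seed(RANDOM_SEED)`, which does not touch NumPy's global generator, and TensorFlow likewise keeps its own random state (used, for example, for weight initialization). A minimal sketch of the NumPy half of that gap, using only the standard `random` module and `numpy`:

```python
import random

import numpy as np

RANDOM_SEED = 42


def draw_after_seeding_python_only():
    # Mirrors the question: only Python's built-in `random` module is seeded.
    random.seed(RANDOM_SEED)
    python_value = random.random()
    # NumPy has a separate global generator that `random.seed` does not reset.
    numpy_value = np.random.rand()
    return python_value, numpy_value


first = draw_after_seeding_python_only()
second = draw_after_seeding_python_only()
print(first[0] == second[0])  # the `random` stream restarts from the seed
print(first[1] == second[1])  # NumPy's stream just keeps advancing
```

The first comparison is `True` and the second `False`, which is why seeding one library is not enough.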
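For now, "make sure all seeds are set" for unsupervised GraphSAGE means:
- Fix the seeds of these external packages: numpy, tensorflow and random.
- Provide seeds when constructing the UnsupervisedSampler and GraphSAGELinkGenerator objects. These classes perform the random walks and the neighbourhood sampling.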
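A minimal sketch of both points, assuming TensorFlow 2.x (`tf.random.set_seed`; TensorFlow 1.x used `tf.set_random_seed`). `set_all_seeds` and `random_walks` are hypothetical helpers for illustration only. `random_walks` is a toy uniform walker that shows why a sampler takes its own `seed` (a private generator is immune to other code consuming the global one); it is not StellarGraph's implementation. In the real pipeline the second point corresponds to the `seed=RANDOM_SEED` arguments already passed to `UnsupervisedSampler` and `GraphSAGELinkGenerator` in the question's code.

```python
import random

import numpy as np

try:
    import tensorflow as tf  # optional: seeded only if installed
except ImportError:
    tf = None


def set_all_seeds(seed: int) -> None:
    """Seed the global generators of random, numpy and (if present) tensorflow."""
    random.seed(seed)
    np.random.seed(seed)
    if tf is not None:
        # TensorFlow 2.x name; older 1.x releases used tf.set_random_seed.
        if hasattr(tf.random, "set_seed"):
            tf.random.set_seed(seed)
        else:
            tf.set_random_seed(seed)


def random_walks(adjacency, start_nodes, length, number_of_walks, seed):
    """Hypothetical uniform random-walk sampler with its own seeded RNG."""
    rng = random.Random(seed)  # private generator, independent of global state
    walks = []
    for node in start_nodes:
        for _ in range(number_of_walks):
            walk = [node]
            while len(walk) < length:
                neighbours = adjacency[walk[-1]]
                if not neighbours:
                    break  # dead end: stop this walk early
                walk.append(rng.choice(neighbours))
            walks.append(walk)
    return walks


# Usage on a tiny hypothetical graph given as an adjacency dict.
toy_graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
set_all_seeds(42)
walks_a = random_walks(toy_graph, sorted(toy_graph), length=5, number_of_walks=3, seed=42)
random.random()  # other code consuming the global RNG does not affect the sampler
walks_b = random_walks(toy_graph, sorted(toy_graph), length=5, number_of_walks=3, seed=42)
print(walks_a == walks_b)
```

The final line prints `True`: with a dedicated seeded generator, the walks are identical no matter what else has drawn from the global `random` state in between.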