python - 使用 StellarGraph 的 Watch-your-step 模型无法在 GPU 上运行
问题描述
我正在尝试使用 StellarGraph 使用 WatchYourStep 算法训练一个大型图形嵌入。
出于某种原因,该模型仅在 CPU 上进行训练,而不使用 GPU。
使用:
- TensorFlow-GPU 2.3.1
- 有 2 个 GPU,cuda 10.1
- 在 nvidia-docker 容器内运行。
- 我知道 tesnorflow 确实找到了 GPU。(
tf.debugging.set_log_device_placement(True)
) - 我试图在
with tf.device('/GPU:0'):
- 我试图用
tf.distribute.MirroredStrategy()
. - 尝试卸载 tensorflow 并重新安装 tensorflow-gpu。
然而,在运行nvidia-smi时,我在 GPU 上看不到任何活动,而且训练速度非常慢。
如何调试这个?
def watch_your_step_model():
'''use the config to geenrate the WatchYourStep model'''
cfg = load_config()
generator = generator_for_watch_your_step()
num_walks = cfg['num_walks']
embedding_dimension = cfg['embedding_dimension']
learning_rate = cfg['learning_rate']
wys = WatchYourStep(
generator,
num_walks=num_walks,
embedding_dimension=embedding_dimension,
attention_regularizer=regularizers.l2(0.5),
)
x_in, x_out = wys.in_out_tensors()
model = Model(inputs=x_in, outputs=x_out)
model.compile(loss=graph_log_likelihood, optimizer=optimizers.Adam(learning_rate))
return model, generator, wys
def train_watch_your_step_model(epochs = 3000):
cfg = load_config()
batch_size = cfg['batch_size']
steps_per_epoch = cfg['steps_per_epoch']
callbacks, checkpoint_file = watch_your_step_callbacks(cfg)
# strategy = tf.distribute.MirroredStrategy()
# print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# with strategy.scope():
model, generator, wys = watch_your_step_model()
train_gen = generator.flow(batch_size=batch_size, num_parallel_calls=8)
train_gen.prefetch(20480000)
history = model.fit(
train_gen,
epochs=epochs,
verbose=1,
steps_per_epoch=steps_per_epoch,
callbacks = callbacks
)
copy_last_trained_wys_weights_to_data()
return history, checkpoint_file
with tf.device('/GPU:0'):
train_watch_your_step_model()
解决方案
我只是按照以下说明操作:https ://github.com/stellargraph/stellargraph/issues/546 。
它对我有用。
基本上,您必须从 stellargraph github 编辑文件 setup.py 并删除 tensorflow 要求(第 25 和 27 行https://github.com/stellargraph/stellargraph/blob/develop/setup.py)。
推荐阅读
- vue.js - 使用 v-for 在单击时反转布尔值?
- javascript - 什么相当于javascript中的reduce
- c++ - 即使包含库,对某些函数的未定义引用
- delphi - 如何比较枚举类型集
- python - 无法使用 python ctypes 加载 C++ 共享库
- azure - Azure App Insights 与 Gov Cloud Stage 和 Prod 上的 APIM 集成
- c - 逐字节打印
- php - Wordpress 数据库意外更改了 URL
- symfony - 编辑控制器中的 config.yml 文件
- split - 在 macOS bash 上使用 split -d