首页 > 解决方案 > 在提供训练数据时,Keras 似乎比 tensorflow 慢

问题描述

我目前正在将一个项目从 tensorflow 转换为 keras。

一切似乎都很好,我对使用 keras 构建模型的简单程度印象深刻。但是,使用 Keras 进行训练要慢得多,因为我的 GPU 使用率要低得多。

我正在使用Tensorflow 生成器数据集来加载我的训练数据。幸运的是,keras 似乎毫无问题地接受了这一点。

问题

然而,在使用 tensorflow 对数据集进行训练时,我的 GPU 平均利用率约为 70%。当我使用 Keras 使用相同的数据集生成器训练相同的网络时,我只获得了大约 35% 的 GPU 利用率

问题似乎是我有一个非常简单的网络,因此我需要尽可能快地将数据提供给 GPU,因为与实际进行反向传播相比,这里花费了很多时间。

使用 tensorflow 的关键似乎是使用 feed-dicts,而是使用我的数据集中的张量作为图形的输入。基本上,这可以减少到

x, y = iterator.get_next()                     # Get the dataset tensors
loss = tf.reduce_sum(tf.square(y - model_out)) # Use the y tensor directly for loss
# Use x as the input layer in my model <- Implememntation omitted

我想用 keras 实现同样的目标,因此我做了类似的事情,我将 x 设置为输入,将 y 设置为目标张量。(我能以某种方式摆脱将 y 放入目标张量的列表中吗?)

x, y = iterator.get_next()                                    # Get the dataset tensors

model_input = keras.Input(tensor=x)
# Build model with model_input as input layer and something as output layer. <- Implememntation omitted
model = tf.keras.Model(inputs=model_input, outputs=something) # Insert the dataset tensor directly as input

model.compile(loss='mean_squared_error',                                                                                                                   
              optimizer=#something,
              metrics=['accuracy'],                                                                                                                        
              target_tensors=[y]) # Input the dataset y tensor directly for use in the loss calculation

基本上应该将 x 设置为输入张量,将 y 设置为直接用于损失的张量,就像在 tensorflow 版本中一样。我现在可以model.fit在不显式提供 x 和 y 参数的情况下进行校准,因为它们直接在图中使用

model.fit(validation_data=validation_iterator,
          steps_per_epoch=5000,
          validation_steps=1)

对我来说,似乎我现在正在使用 keras 和 tensorflow 做同样的事情,但是,keras 的速度要慢得多,大约是纯 tensorflow 实现的 GPU 利用率的一半我在这里做错了什么,或者如果我应该接受这种减速想使用 keras 吗?

标签: pythontensorflowkeras

解决方案


我在 TensorFlow 1.13 上遇到了同样的问题,并通过升级到 TensorFlow 1.14 / 2.0.0 解决了这个问题。

为了进行完整性检查,我将 TensorFlow 图(原样)包装为 Keras 模型,并使用model.fit(). 使用 TensorFlow 1.13 时,相对于训练纯 TensorFlow 实现的吞吐量,我的吞吐量降低了 50%。在这两种情况下,我都使用了相同的tf.data.dataset输入管道。

使用 TensorFlow 1.14 版解决了这个问题(现在我在上述两种情况下都获得了相同的吞吐量)。后来我迁移到 TensorFlow 2.0.0 (alpha) 并且在这两种情况下也获得了相同的吞吐量。


推荐阅读