python - 在提供训练数据时,Keras 似乎比 tensorflow 慢
问题描述
我目前正在将一个项目从 tensorflow 转换为 keras。
一切似乎都很好,我对使用 keras 构建模型的简单程度印象深刻。但是,使用 Keras 进行训练要慢得多,因为我的 GPU 使用率要低得多。
我正在使用Tensorflow 生成器数据集来加载我的训练数据。幸运的是,keras 似乎毫无问题地接受了这一点。
问题
然而,在使用 tensorflow 对数据集进行训练时,我的 GPU 平均利用率约为 70%。当我使用 Keras 使用相同的数据集生成器训练相同的网络时,我只获得了大约 35% 的 GPU 利用率
问题似乎是我有一个非常简单的网络,因此我需要尽可能快地将数据提供给 GPU,因为与实际进行反向传播相比,这里花费了很多时间。
使用 tensorflow 的关键似乎是不使用 feed-dicts,而是使用我的数据集中的张量作为图形的输入。基本上,这可以减少到
x, y = iterator.get_next() # Get the dataset tensors
loss = tf.reduce_sum(tf.square(y - model_out)) # Use the y tensor directly for loss
# Use x as the input layer in my model <- Implememntation omitted
我想用 keras 实现同样的目标,因此我做了类似的事情,我将 x 设置为输入,将 y 设置为目标张量。(我能以某种方式摆脱将 y 放入目标张量的列表中吗?)
x, y = iterator.get_next() # Get the dataset tensors
model_input = keras.Input(tensor=x)
# Build model with model_input as input layer and something as output layer. <- Implememntation omitted
model = tf.keras.Model(inputs=model_input, outputs=something) # Insert the dataset tensor directly as input
model.compile(loss='mean_squared_error',
optimizer=#something,
metrics=['accuracy'],
target_tensors=[y]) # Input the dataset y tensor directly for use in the loss calculation
基本上应该将 x 设置为输入张量,将 y 设置为直接用于损失的张量,就像在 tensorflow 版本中一样。我现在可以model.fit
在不显式提供 x 和 y 参数的情况下进行校准,因为它们直接在图中使用
model.fit(validation_data=validation_iterator,
steps_per_epoch=5000,
validation_steps=1)
对我来说,似乎我现在正在使用 keras 和 tensorflow 做同样的事情,但是,keras 的速度要慢得多,大约是纯 tensorflow 实现的 GPU 利用率的一半我在这里做错了什么,或者如果我应该接受这种减速想使用 keras 吗?
解决方案
我在 TensorFlow 1.13 上遇到了同样的问题,并通过升级到 TensorFlow 1.14 / 2.0.0 解决了这个问题。
为了进行完整性检查,我将 TensorFlow 图(原样)包装为 Keras 模型,并使用model.fit()
. 使用 TensorFlow 1.13 时,相对于训练纯 TensorFlow 实现的吞吐量,我的吞吐量降低了 50%。在这两种情况下,我都使用了相同的tf.data.dataset
输入管道。
使用 TensorFlow 1.14 版解决了这个问题(现在我在上述两种情况下都获得了相同的吞吐量)。后来我迁移到 TensorFlow 2.0.0 (alpha) 并且在这两种情况下也获得了相同的吞吐量。
推荐阅读
- base64 - 是的 dataURI 验证
- python - 在 Pandas Dataframe 中组合列的值和状态条件
- c++ - 为什么当 std::vector 重定位存储时,复制 ctor 优于 move-ctor?
- php - 如何使用 API Gateway + Lambda + PHP 制作包含 zip 文件的响应
- dictionary - 带有嵌套列表的字典中的 For 循环
- python - 如何将所有图像数据转换为数组
- css - VueJS
- keras - 生成 300 * 300 * 3 图像的 GAN 的生成器和判别器模型的架构应该是什么?
- apache-spark-sql - 在 Spark Sql 中,如果我们有 when(A&B) 并且如果 A 被评估为 false,那么 B 还会被评估吗?
- python - 迭代空的 2d 与非空的 python