python - TensorFlow 功能太慢
问题描述
我创建了一个 TensorFlow 模型,它处理属于一个观察的 50 个不同图像。所以 A Batch 的形式为(32, 50, 128, 128, 1)
。模型定义为:
input = layers.Input((50, 128, 128, 1))
sub_models = []
for mcol in range(50):
x = layers.Conv2D(32, kernel_size=(3, 3), input_shape=(128, 128, 1))(input[:, mcol, :, :])
x = layers.MaxPool2D(pool_size=(2, 2))(x)
x = layers.Dropout(0.25)(x)
x = layers.Flatten()(x)
x = layers.Dense(128)(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(32)(x)
sub_models.append(x)
combined = tf.keras.layers.concatenate(sub_models)
z = layers.Dense(1024)(combined)
z = layers.Dense(512)(z)
z = layers.Dense(512)(z)
z = layers.Dense(2, activation="softmax")(z)
model = tf.keras.Model(input, z)
模型看起来像这样(输入更少的更简单的版本):
我的火车步骤如下:
with tf.GradientTape() as tape:
logits = model(x_batch_train[:, :50, :, :, None], training=True)
loss_value = loss(Y, logits)
问题是训练步骤非常慢,在 V100 GPU 上每一步都需要几秒钟。我认为问题出在for循环上。有没有办法以更智能的方式定义模型,并且需要更少的时间?
解决方案
如果您将数据重新格式化为(32、128、128、50),50 个通道,每个图像一个,您可以使用 groups 关键字参数(https://www.tensorflow.org/api_docs /python/tf/keras/layers/Conv2D#args )
import tensorflow as tf
from tensorflow.keras import layers
input = layers.Input((50, 128, 128, 1))
# Reshape data (50, 128, 128, 1) -> (50, 128, 128)
x = tf.keras.backend.squeeze(input, axis=-1)
# Transpose (50, 128, 128) -> (128, 128, 50)
x = layers.Permute((2, 3, 1), input_shape=(50, 128, 128))(x)
# NOTE! The groups = 50 part is what breaks up the network
x = layers.Conv2D(32 * 50,
kernel_size=(3, 3), input_shape=(128, 128, 50), groups=50)(x)
# Reshape to max pool 3D
# 126, 126, 50 * 32 -> 126, 126, 50, 32
x = layers.Reshape((126, 126, 50, 32))(x)
x = layers.MaxPool3D(pool_size=(2, 2, 1))(x)
x = layers.Dropout(0.25)(x)
# Change (63, 63, 50, 32) -> (50, 63 * 63 * 32)
x = layers.Permute((3, 1, 2, 4), input_shape=(63, 63, 50, 32))(x)
x = layers.Reshape((50, 63 * 63 * 32))(x)
x = layers.Dense(128)(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(32)(x)
# Join everything together as per the spec
combined = layers.Flatten()(x)
z = layers.Dense(1024)(combined)
z = layers.Dense(512)(z)
z = layers.Dense(512)(z)
z = layers.Dense(2, activation="softmax")(z)
model = tf.keras.Model(input, z)
话虽如此,循环本身不应成为速度瓶颈(确保您实际上是在 GPU 上运行),因为您只是在构建计算图,但这仍然应该加快速度。
推荐阅读
- ffmpeg - ffmpeg 输出文件大小的增长速度快于电影长度的线性增长
- r - 如何从 modelsummary 包中的 msummary 的 lmer() 模型中提取拟合优度统计信息
- rxjs - 想用rx.js合并多个数据源,支持增删数据源
- reactjs - 用道具反应测试浅快照问题
- scala - 是否有工具可以解释 Scala 程序中每个符号的含义以及如何解析它?
- c# - 查询字符串和字符串比较
- python - R retriculate 直接从 R 代码中执行 python 代码
- java - 在 Google Slides 中打开 ppt 或 pptx,在 Android Studio 中打开 google docs 中的 doc 或 word
- javascript - 将字符串数组转换为整数数组并将它们插入行
- c# - C# 使用多暗淡数组作为 DataRow (MSTest) 的输入