python - TensorBoard tf.cond 性能
问题描述
我写了一个非常简单的自定义层如下:
class Custom_Layer1234(keras.layers.Layer):
def __init__(self, inputname , units=45, input_dim=45):
super(Custom_Layer1234, self).__init__()
w_init = tf.random_normal_initializer()
b_init = tf.zeros_initializer()
self.w_0 = tf.Variable(initial_value=w_init(shape=(input_dim, units,)),
name='w0{}'.format(inputname), trainable=True,)
self.w_1 = tf.Variable(initial_value=w_init(shape=(input_dim, units,)),
name='w1{}'.format(inputname), trainable=True,)
self.b_0 = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"),
name='b0{}'.format(inputname), trainable=True)
self.b_1 = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"),
name='b1{}'.format(inputname), trainable=True)
@tf.function
def call(self, inputs):
diff_1 = inputs[0][1]
if diff_1 <= 0 :
y = tf.matmul(inputs, self.w_0) + self.b_0
else:
y = tf.matmul(inputs, self.w_1) + self.b_1
return tf.nn.relu(y)
然后,我将输入张量切成更小的张量,将其循环并将每个“切片”馈送到此自定义层之一,因此我怀疑在运行时基于每个输入 [i] 存在各种“不同分支”图。该模型运行速度很慢,而 GPU 利用率大部分时间低于 30%。
我打开 Tensorboard 来检查如何更快地训练模型并按照它的说明进行操作。我的输入管道没有问题。预处理也很好。我还尝试了 Tensorboard 的建议:
TF_GPU_THREAD_MODE=gpu_private
mixed-precision
没有人会缩短培训时间。我删除了层中的 if 条件并重新运行模型:
# if diff_1 <= 0 :
y = tf.matmul(inputs, self.w_0) + self.b_0
# else:
# y = tf.matmul(inputs, self.w_1) + self.b_1
并且速度明显更快,所以我猜测许多 tf.cond 会导致速度变慢。我在想是否有办法为每个输入[i]单独保存这些“路线”尽可能多的独特图形(tf.Graph),并且可以重复使用,这样就不需要重复进行这种计算(我的假设是否正确?)。或者 TensorFlow 已经在这样做了。我可以从对代码或运行模式或图形优化器 Grappler 的任何更改中受益吗?这样训练可以更快。
非常感谢
解决方案
在计算图中引入条件确实会显着降低 TF 性能。例如,使用带有mask
参数的 RNN 层会显着降低其性能。
我不是这方面的专家,但我对这个问题的理解如下。要更新网络的参数,TF 必须跟踪所有梯度。当没有条件语句时,可以对批次中的所有样本应用相同的公式。相反,对于条件语句,必须知道数据通过计算图的各个“路径”。我的猜测是,这会使矢量化(使用批处理)无效。
出于类似的原因,矢量化对于通过网络的数据前向传播也可能变得不那么有效。
考虑到代码优化,您可以尝试以下方法。制作两个不同的图层,layer_A
然后layer_B
,这将根据值进行您想要的计算diff_1
。然后,不要使用if
语句,而是尝试
# you may need to change indecies of inputs
result = tf.where(inputs[:,0,1]>0, layer_A(inputs), layer_B(inputs))
不确定这会更快,但我会尝试一下。
推荐阅读
- c# - 实时搜索数百万条记录
- c++ - 不能用()创建原子
- angular - Angular 9 this._datepicker._registerInput 不是 MatDatepickerInput.set [as matDatepicker] 的函数
- wordpress - 域链重定向
- mysql - 编写此 mySQL 查询的正确方法 select count(ag.merk) where merk = 'Ferrari' from auto_gegevens ag
- webrtc - 在 mac 基础上运行打开服务器
- python - mysql.connector.errors.DatabaseError: 1366 (HY000): Incorrect integer value: '' for column 'risklevel' at row 1
- android - 当播放器暂停和应用程序被杀死时,ExoPlayer 通知重新出现
- jupyter-lab - 在 Jupyter 实验室中遇到 ModuleNotFoundError
- php - Laravel Lucid 模拟工作