python - tf.GradientTape() 位置对模型训练时间的影响
问题描述
我正在尝试更新每个时期的权重,但我正在分批处理数据。问题是,为了规范化损失,我需要在训练循环之外记录 TensorFlow 变量(以进行跟踪和规范化)。但是当我这样做时,训练时间是巨大的。
我认为,它会将所有批次的变量累积到图中并在最后计算梯度。
我已经开始在 for 循环外和 for 循环内跟踪变量,后者比第一个更快。我对为什么会发生这种情况感到困惑,因为无论我做什么,我的模型的可训练变量和损失都保持不变。
# Very Slow
loss_value = 0
batches = 0
with tf.GradientTape() as tape:
for inputs, min_seq in zip(dataset, minutes_sequence):
temp_loss_value = my_loss_function(inputs, min_seq)
batches +=1
loss_value = loss_value + temp_loss_value
# The following line takes huge time.
grads = tape.gradient(loss_value, model.trainable_variables)
# Very Fast
loss_value = 0
batches = 0
for inputs, min_seq in zip(dataset, minutes_sequence):
with tf.GradientTape() as tape:
temp_loss_value = my_loss_function(inputs, min_seq)
batches +=1
loss_value = loss_value + temp_loss_value
# If I do the following line, the graph will break because this are out of tape's scope.
loss_value = loss_value / batches
# the following line takes huge time
grads = tape.gradient(loss_value, model.trainable_variables)
当我在 for 循环内声明 tf.GradientTape() 时,它非常快,但我在外面它很慢。
PS - 这是一个自定义损失,架构只包含一个大小为 10 的隐藏层。
我想知道,tf.GradientTape() 的位置的不同之处以及它应该如何用于在批处理数据集中更新每个时期的权重。
解决方案
磁带变量主要用于观察可训练的张量变量(记录变量的先前值和变化值),以便我们可以根据损失函数计算一个训练时期的梯度。它是此处用于记录变量状态的 python 上下文管理器构造的实现。关于 python 上下文管理器的优秀资源在这里. 因此,如果在循环内部,它将记录该前向传递的变量(权重),以便我们可以一次性计算所有这些变量的梯度(而不是像在没有像 tensorflow 这样的库的幼稚实现中那样基于堆栈的梯度传递) . 如果它在循环之外,它将记录所有时期的状态,并且根据 Tensorflow 源代码,如果使用 TF2.0,它也会刷新,这与模型开发人员必须负责刷新的 TF1.x 不同。在您的示例中,您没有设置任何作家,但如果设置了任何作家,它也会这样做。因此,对于上面的代码,它将继续记录(内部使用 Graph.add_to_collection 方法)所有权重,并且随着 epoch 的增加,您应该会看到速度变慢。减速率将与网络大小(可训练变量)和当前 epoch 数成正比。
所以把它放在循环内是正确的。此外,梯度应该应用在 for 循环内部而不是外部(在相同的缩进级别),否则你只在训练循环结束时应用梯度(在最后一个 epoch 之后)。我看到你的训练可能不适合当前梯度检索的位置(之后它被应用到你的代码中,尽管你在代码片段中省略了它)。
我刚刚发现的另一个关于梯度磁带的好资源。
推荐阅读
- python-3.x - 软件包安装失败 - OsX 中的 psycopg2
- python - 重命名多个子目录中的单个 .txt 文件
- c# - Discord:使用 .NET 库更改区域?
- vue.js - 将数据从 Vuex getter 传递到子组件的问题
- node.js - 有没有办法将 telnet 会话输出作为邮递员测试脚本的一部分?
- php - 检查数组php中的任何一个匹配项
- java - 如何识别已取消的 ScheduledFuture 是否实际上未取消?
- python - 使用python创建由2种颜色的框组成的图像
- winforms - 使用 OpenAL 和 VC++ 在 winforms 中绘制音频波形
- excel - 做直到循环。满足一个或另一个条件