tensorflow - 在具有多个 Keras 模型的 TF2 自定义训练循环中应用梯度的正确方法
问题描述
我正在努力使用涉及多个 Keras 模型的 GradientTape 实现自定义训练循环。我有 3 个网络,model_a
,model_b
和model_c
。我创建了一个列表来保存它们trainbale_weights
:
trainables = list()
trainables.append(model_a.trainable_weights) # CovNet
trainables.append(model_b.trainable_weights) # CovNet
trainables.append(model_c.trainable_weights) # Fully Connected Network
然后我计算损失并尝试将梯度应用为:
loss = 0.
optimizer = tf.keras.optimizers.Adam()
for _, (x, y) in enumerate(train_dataset):
with tf.GradientTape() as tape:
y = ...
loss = ... # custom loss function!
gradients = tape.gradient(loss, trainables)
optimizer.apply_gradients(zip(gradients, trainables))
但是我收到以下错误,我不确定错误出在哪里:
AttributeError: 'list' object has no attribute '_in_graph_mode'
如果我迭代渐变和可训练然后应用渐变它可以工作,但我不确定这是否是正确的方法。
for i in range(len(gradients)):
optimizer.apply_gradients(zip(gradients[i], trainables[i]))
解决方案
问题是tape.gradient
期望trainables
是可训练变量的平面列表,而不是列表列表。您可以通过将所有可训练权重连接到一个平面列表中来解决此问题:
trainables = model_a.trainable_weights + model_b.trainable_weights + model_c.trainable_weights
推荐阅读
- isabelle - Isabelle:数据类型的补充
- visual-c++ - 将字符串转换为 ASCII C++
- javascript - 登录后重定向不正确 - Spring Boot
- c++ - 查找数组中哪一行的平均温度最高
- javascript - 如何让我的文本出现在我的滑块顶部?
- amazon-web-services - AWS chmod 400 pem:没有这样的文件或目录问题
- javascript - 一旦屏幕通过特定点或 jsx 标签,如何获取滚动事件?
- c++ - C ++如何生成-32到32或-64到64之间的随机数并且不包括零值?
- javascript - 图像在 jQuery 中没有按顺序褪色?
- javascript - 在 JavaScript 中验证表单