首页 > 解决方案 > Keras 中的“model.trainable = False”是什么意思?

问题描述

我想在 Keras 中冻结一个预训练的网络。我base.trainable = False在文档中找到。但我不明白它是如何工作的。len(model.trainable_weights)我发现我有 30 个可训练的重量。这个怎么可能?该网络显示总可训练参数:16,812,353。冷冻后,我有 4 个可训练的重量。也许我不明白参数和重量之间的区别。不幸的是,我是深度学习的初学者。也许有人可以帮助我。

标签: pythonkerasdeep-learning

解决方案


默认情况下, KerasModel是可训练的- 您有两种方法可以冻结所有权重:

  1. model.trainable = False 编译模型之前
  2. for layer in model.layers: layer.trainable = False- 在编译之前和之后工作

(1) 必须在编译前完成,因为 Kerasmodel.trainable在编译时将其视为布尔标志,并在后台执行 (2)。完成上述任一操作后,您应该会看到:

print(model.trainable_weights)
# [] 

关于文档,可能已过时 - 请参阅上面的链接源代码,最新。


推荐阅读