首页 > 解决方案 > set_weights 停止在 tensorflow 2 上工作

问题描述

出于不相关的原因,从 keras --> tensorflow.keras 更新我们的代码库。

keras 2.3.1 张量流 2.1.0

此代码适用于 keras,但在 tf.keras 上失败:

weights = applications.VGG16(weights='imagenet', include_top=False).get_weights()
model.set_weights(tempweights)

错误:

You called `set_weights(weights)` on layer "model" with a  weight list of length 26, but the layer was expecting 32 weights. Provided weights: [array([[[[ 4.29470569e-01,  1.17273867e-01,  3.40...

使用 keras.applications.VGG16().get_weights() 不能修复它,导致完全相同的错误。

已经检查了这些看起来相似的 github 问题,但无法找到我的问题的原因:

https://github.com/keras-team/keras/issues/7229 https://github.com/keras-team/keras/issues/4753

以下作品:

model.weights[0] = weights[0]

看起来像 keras.model.set_weights() 有一些处理不匹配权重的行为,但是 tf.tensorflow.keras 没有这种行为?

确认的。行为发生了变化。Include_top=True 在我们的特定情况下修复了它,因为这使我们的模型匹配。当模型不匹配时,Keras 不会抛出错误,我不确定它在这种情况下究竟做了什么。虽然不打算调查。将解决方案留在这里。

标签: kerastensorflow2.0tf.keras

解决方案


推荐阅读