首页 > 解决方案 > 使用model.fit时如何将'training'参数传递给tf,keras.Model

问题描述

所以我有这个由子类化 API 编写的模型,调用签名看起来像 call(x, training),在执行 batchnorm 和 dropout 时需要训练参数来区分训练和非训练。当我使用 model.fit 时,如何让模型前向传递知道我处于训练模式或评估模式?

谢谢!

标签: pythontensorflowkeras

解决方案


实际上,在文档https://www.tensorflow.org/beta/guide/keras/custom_layers_and_models中,它说“某些层,特别是 BatchNormalization 层和 Dropout 层,在训练和推理期间具有不同的行为。对于这样的层,在调用方法中公开训练(布尔)参数是标准做法。

通过在调用中公开此参数,您可以启用内置的训练和评估循环(例如拟合)以在训练和推理中正确使用该层。”所以我认为训练参数是由 keras 自动传入的。我试图删除训练参数的默认值并且没有抛出任何错误,所以很可能是 keras 内置循环做了这件事。


推荐阅读