tensorflow - 如何创建仅在评估阶段生效的 keras 层(并且在训练期间是透明的)?
问题描述
我想在我的模型中添加一个层,该层在评估期间接受输入,应用一些转换(在这种情况下是量化,但可以是任何值)并将其作为输出返回。然而,这一层在训练期间必须是完全透明的,这意味着它必须返回相同的输入张量。
我写了以下函数
from keras.layers import Lambda
import keras.backend as K
def myquantize(x):
return K.in_test_phase( K.clip(K.round(x*(2**5))/(2**5),-3.9,3.9) , x)
然后我通过 Lambda 层使用它
y = keras.layers.Conv1D(**args1)
y = keras.layers.AveragePooling1D(pool_size=2)(y)
y = keras.layers.Lambda(myquantize)(y)
y = keras.layers.Conv1D(**args2)
#...
现在,原则上 K.in_test_phase 应该在训练期间返回 x,在测试期间返回那个表达式。但是,用这样的层训练网络会阻止网络学习(即训练损失在 3 个 epoch 后停止减少),而如果我删除它,网络会继续正常训练。我假设这一层在训练期间实际上并不像预期的那样透明。
解决方案
in_test_phase
有一个参数training
,您可以显式设置该参数以指示您是否正在训练。如果您没有显式设置它,则使用 的值learning_phase
。当您重置图形或调用模型的不同类型的拟合/预测/评估函数时,此值会不断变化。
由于您的完整代码不存在,您可以使用training
参数。在训练期间将其设置为 True。save_weights
然后使用模型的功能保存模型的权重。当您希望测试您的模型时,请将training
参数设置为 False。然后使用函数加载权重load_weights
,然后您可以进行相应的操作。
推荐阅读
- arrays - 我无法从对象中获取值。获取未定义的对象错误
- python - 使用 Python splitlines() 将文本文件转换为列表,同时还将一些行组合成列表中的单个项目
- php - php-express - 如何查找服务器信息
- c# - winspool.Drv ClosePrinter 实际上并没有打印,但我可以在队列中看到它 c#
- r - ggplot2中的小时刻度
- terraform - terraform 查找默认值必须与地图元素具有相同的类型
- node.js - 安装 nodemailer 导致找不到模块错误
- sql-server - 如何将此 SQL Server 公式转换为 Excel 公式?
- mongodb - mongo:未能使用文本索引来满足 $text 查询
- opengl-es - OpenGL ES 2.0 中的纹理裁剪替换?