tensorflow - 如何在 keras 中执行虚拟批量标准化(VBN)
问题描述
本文讨论VBN 。并实现了Here、Here和Here。我不想去核心/完整代码。我只想知道,如何使用 VBN 作为 keras 层,因为我不是很专业的 tensorflow/keras 编码器。我一般使用简单的批量归一化(BN)如下
model.add(BatchNormalization(momentum=0.8))
以类似的方式,如何在以下 keras 代码中使用 VBN 而不是 BN?
model.add(Dense(256,input_dim=self.input_dim))
model.add(LeakyReLU(alpha=.2))
model.add(BatchNormalization(momentum=0.8))%I want to replace this with VBN
model.add(Dense(512))
......
.......
解决方案
他们在第一个链接中说
该
__init__
API 旨在尽可能地模仿tf.compat.v1.layers.batch_normalization
。
因此,如果您查看https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization,
它会说您将此函数用作...
x_norm = tf.layers.batch_normalization(x, training=training)
因此,如果我理解得很好,
使用功能 API https://keras.io/getting-started/functional-api-guide/,
您可能应该执行以下操作:
layer_n = VBN(**kwargs, layer_n-1)
我希望它有帮助
推荐阅读
- pytorch - 在预训练的 VGG16 模型中冻结卷积层
- c# - 修改 JSON.Net 如何处理序列化类,而不直接访问 JSON.Net
- drools - drools中重复的pojo验证规则
- vmware - Windows 11 更新后 VMware 崩溃
- traefik - 有没有办法让 traefik 使用集群 ip 而不是 pod ip
- javascript - 浏览器窗口关闭后如何调用firebase函数
- c++ - 如果类实现类模板,则定义函数,而不考虑模板参数
- postgresql - 使用 Postrgres JSON 函数如何获取数组的聚合对象
- azure - Azure Monitor - 服务总线命名空间的图表和警报使用情况(当前大小/总大小)
- ios - DJI WaypointMission - 获取定位失败/无人机定位