python - 如何处理多个数据集的批量标准化?
问题描述
我正在处理生成合成数据以帮助训练我的模型的任务。这意味着训练是在合成数据+真实数据上进行的,并在真实数据上进行测试。
有人告诉我,批量归一化层可能试图在训练时找到对所有人都有好处的权重,这是一个问题,因为我的合成数据的分布并不完全等于真实数据的分布。因此,我们的想法是拥有批量标准化层权重的不同“副本”。这样神经网络就可以为合成数据和真实数据估计不同的权重,并且只使用真实数据的权重进行评估。
有人可以建议我在pytorch中实际实现它的好方法吗?我的想法如下,在数据集中的每个训练阶段之后,我都会遍历所有 batchnorm 层并保存它们的权重。然后在下一个纪元开始时,我将再次迭代加载正确的权重。这是一个好方法吗?不过,我不确定在测试时应该如何处理批量规范权重,因为批量规范对它的处理方式不同。
解决方案
听起来您担心的问题是,当为一批真实数据和合成数据计算批范数时,您的神经网络将学习效果很好的权重,然后在测试时它会计算一个批范数只是真实数据?
与其尝试跟踪多个批次规范,您可能只想为您的批次规范层设置track_running_stats
为True
,然后在测试时将其置于 eval 模式。这将导致它在训练时计算多个批次的运行均值和方差,然后它将在稍后的测试时使用该均值和方差,而不是查看测试批次的批次统计信息。
(无论如何,这通常是您想要的,因为根据您的用例,您可能会向部署的模型发送非常小的批次,因此您希望使用预先计算的均值和方差,而不是依赖这些小批次的统计数据.)
如果您真的想在测试时计算新的均值和方差,我要做的不是将包含真实数据和合成数据的单个批次传递到您的网络中,而是传递一批真实数据,然后传递一批合成数据,并在反向传播之前平均两个损失。(请注意,如果您这样做,则以后不应依赖运行均值和方差——您必须将其设置track_running_stats
为False
合理的值。这是因为运行均值和方差统计仅在预期它们对于每个批次大致相同时才有用,并且您通过在不同批次中输入不同类型的数据来极化值。)
推荐阅读
- python - 处理大文件时读取 HDF5 属性并转换为 Pandas 数据框的最快方法是什么?
- python - PyTorch 的 FFT 不保留线性
- webpack - aframe 资产未在 nuxt.js 中加载 gltf
- r - 如何计算一个值存在多少次并计算 R 中的百分比?
- angular - 角度材料:SassError:“@include mat”之后的无效 CSS:预期 1 个选择器或规则,是“.core();”
- python - 如何在 VS Code 中从 Python 3 更改为 Python 2?
- javascript - 猫鼬播种结合到 index.js 文件
- sql - SQL选择所有不等于某个id的行,并将id列替换为该值-不进行交叉连接
- ibm-cloud - 如何使用 python SDK 创建存储桶?
- python - 有没有办法让 selenium 忽略 timeoutexception 错误并在它发生时继续?