首页 > 解决方案 > pytorch中检查点之间的线性插值 - 批量标准化

问题描述

我想在两个 PyTorch 训练的模型检查点之间进行线性插值。对于除批量归一化之外的所有层,我加载指定的 dict 并简单地进行线性整合,如下所示:

def interpolate_state_dicts(state_dict_1, state_dict_2, weight):
return {key: (1 - weight) * state_dict_1[key] + weight * state_dict_2[key]
        for key in state_dict_1.keys()}

我不知道我们是否可以简单地对 BN 层参数(权重、偏差、运行均值、运行标准)做同样的事情?我想这并不是那么简单,因为平均值和标准是针对特定批次计算的。

标签: deep-learningpytorchinterpolation

解决方案


推荐阅读