首页 > 解决方案 > 如何将 caffe 的 BatchNorm 重量转换为 pytorch BathNorm?

问题描述

caffe 模型的 BathNorm 和 Scale 权重可以从 pycaffe 中读取,在 BatchNorm 中是三个权重,在 Scale 中是两个权重。我尝试使用如下代码将这些权重复制到 pytorch BatchNorm:

if 'conv3_final_bn' == name:
    assert len(blobs) == 3, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.running_mean'] = blobs[0].data
    torch_mod['conv3_final_bn.running_var'] = blobs[1].data
elif 'conv3_final_scale' == name:
    assert len(blobs) == 2, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.weight'] = blobs[0].data
    torch_mod['conv3_final_bn.bias'] = blobs[1].data

这两个 BatchNorm 的行为不同。我也尝试设置 conv3_final_bn.weight=1 和 conv3_final_bn.bias=0 来验证caffe的BN层,结果也不匹配。

我应该如何处理错误的匹配?

标签: caffepytorchpycaffe

解决方案


知道了!caffe 的 BatchNorm 中还有第三个参数。代码应该是:

if 'conv3_final_bn' == name:
    assert len(blobs) == 3, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.running_mean'] = blobs[0].data / blobs[2].data[0]
    torch_mod['conv3_final_bn.running_var'] = blobs[1].data / blobs[2].data[0]
elif 'conv3_final_scale' == name:
    assert len(blobs) == 2, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.weight'] = blobs[0].data
    torch_mod['conv3_final_bn.bias'] = blobs[1].data

推荐阅读