首页 > 解决方案 > 为感知损失计算 VGG 特征的正确方法

问题描述

在计算VGG Perceptual loss的时候,虽然没见过,但是感觉把 GT 图像的 VGG 特征的计算封装在里面就可以了torch.no_grad()

所以基本上我觉得以下就可以了,

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

# Now compute L1 loss

或者应该使用,

gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)

在这两种方法requires_grad中,都设置了 VGG 参数False并将 VGG 置于eval()模式。

第一种方法将节省大量 GPU 资源,并且感觉应该在数值上等于第二种方法,因为不需要通过 GT 图像进行反向传播。但在大多数实现中,我发现第二种方法用于计算 VGG 感知损失。

那么在 PyTorch 中实现 VGG 感知损失时,我们应该选择哪个选项呢?

标签: pythondeep-learningpytorchconv-neural-networkvgg-net

解决方案


第一种方式:

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

尽管 VGG 处于eval模式并且其参数保持固定,但您仍然需要通过它将梯度从特征损失传播到输出nw_op。但是,没有理由计算这些梯度 wrt gt


推荐阅读