首页 > 解决方案 > 如何将 PyTorch 子模块保持在 eval 模式?

问题描述

我有一个预训练模型,我将它与正在训练的模型结合使用。我希望预训练模型始终处于评估模式,但另一个模型将在评估和训练模式之间来回移动。不过,我仍然希望预训练模型成为另一个模型的子模块(例如,以便所有参数都保留在同一设备上)。有没有办法做到这一点?这是一个最小的例子:

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.fixed_module.training  # I'd like this to give an error

我知道我可以解决这个问题,例如,总是做

trainable.train()
trainable.fixed_module.eval()

但这很容易出错,并且不适用于现有代码。

标签: pythonpytorch

解决方案


一种解决方案是像这样覆盖train

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

    def train(self):
        super().train()
        self.fixed_module.eval()

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.fixed_module.training  # This gives an error now

推荐阅读