python - 如何将 PyTorch 子模块保持在 eval 模式?
问题描述
我有一个预训练模型,我将它与正在训练的模型结合使用。我希望预训练模型始终处于评估模式,但另一个模型将在评估和训练模式之间来回移动。不过,我仍然希望预训练模型成为另一个模型的子模块(例如,以便所有参数都保留在同一设备上)。有没有办法做到这一点?这是一个最小的例子:
from torch import nn
class FixedModule(nn.Module):
pass
class TrainableModule(nn.Module):
def __init__(self, fixed_module):
super().__init__()
self.fixed_module = fixed_module
fixed = FixedModule().eval()
assert not fixed.training
trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training
trainable.train()
assert trainable.fixed_module.training # I'd like this to give an error
我知道我可以解决这个问题,例如,总是做
trainable.train()
trainable.fixed_module.eval()
但这很容易出错,并且不适用于现有代码。
解决方案
一种解决方案是像这样覆盖train
:
from torch import nn
class FixedModule(nn.Module):
pass
class TrainableModule(nn.Module):
def __init__(self, fixed_module):
super().__init__()
self.fixed_module = fixed_module
def train(self):
super().train()
self.fixed_module.eval()
fixed = FixedModule().eval()
assert not fixed.training
trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training
trainable.train()
assert trainable.fixed_module.training # This gives an error now
推荐阅读
- ruby-on-rails - Jelastic CMD /入口点中的bundle exec puma错误
- ruby - 如何使用数组作为值对这个散列进行排序?
- tcl - TCL:将proc的输出重定向到文件
- xml - XSLT - 克隆 XML 过滤掉数组元素
- mongodb - 我有一个包含可能值 A 和 B 的数组的文档,如何根据它将其值添加到数组中?
- javascript - 每当添加反应时,如何执行代码块?(discord.js)
- c++ - 具有互斥锁缓存的类的移动构造函数的最佳实践
- python - 美汤刮-取空集
- r - 总结一个col并将最后一个索引处的值存储在R中的一个新列中
- graph - 用于绘制逻辑门和功能块的脚本语言