pytorch - 如何计算pytorch中的参数重要性?
问题描述
我想开发一个终身学习系统,所以我需要防止重要参数发生变化。我阅读了相关论文'Memory Aware Synapses: Learning what (not) to forget',提到了一种方法,我需要计算每个的梯度对应于每个输入图像的参数,那么我应该如何在 pytorch 中编写我的代码? “记忆感知突触:学习(不)忘记什么”
解决方案
您可以.backward()
在损失函数上使用标准优化程序和方法来做到这一点。
首先,按照链接中的定义进行缩放:
class Scaler:
def __init__(self, parameters, delta):
self.parameters = parameters
self.delta = delta
def step(self):
"""Multiplies gradients in place."""
for param in self.parameters:
if param.grad is None:
raise ValueError("backward() has to be called before running scaler")
param.grad *= self.delta
可以像使用它一样使用它optimizer.step()
,见下文(见评论):
model = torch.nn.Sequential(
torch.nn.Linear(10, 100), torch.nn.ReLU(), torch.nn.Linear(100, 1)
)
scaler = Scaler(model.parameters(), delta=0.001)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
X, y = torch.randn(64, 10), torch.randn(64)
# Optimization loop
EPOCHS = 10
for _ in range(EPOCHS):
output = model(X)
loss = criterion(output, y)
loss.backward() # Now model has the gradients
optimizer.step() # Optimize model's parameters
print(next(model.parameters()).grad)
scaler.step() # Scaler gradients
optimizer.zero_grad() # Zero gradient before next step
在scaler.step()
您param.grad
为每个参数内部提供渐变缩放之后(就像在Scaler
'sstep
方法中访问的那些),因此您可以对它们做任何您想做的事情。
推荐阅读
- vba - 对 Word VBA 中的 WdPasteOptions 枚举感到困惑
- java - 在spring设计中使用arrayblocking队列作为消费者生产者来处理巨大的文件
- sql - 如何检测 SQL Server 查询中的循环引用 - SQL Server 2017
- google-apps-script - Google Spreadsheet Hide and unhide rows based on cell values
- mysql - 在sql中获取某个数字
- spring-boot - 管理员如何使用 Spring Boot 锁定/解锁在 mysql 数据库中注册的用户?
- firebase - Firebase:收听路径的子项创建
- javascript - 信息:指定的 Android SDK 构建工具版本 (28.0.2) 被忽略
- r - 循环遍历R中列表上的项目
- python - 如何使用自定义名称创建数据框