python - 如何将 L1 正则化添加到 PyTorch NN 模型?
问题描述
在寻找在 PyTorch 模型中实现 L1 正则化的方法时,我遇到了这个问题,它现在已经 2 岁了,所以我想知道这个主题是否有任何新内容?
我还发现了这种处理丢失 l1 函数的最新方法。但是我不明白如何将它用于基本 NN,如下所示。
class FFNNModel(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dim, dropout_rate):
super(FFNNModel, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
self.drop_layer = nn.Dropout(p=self.dropout_rate)
self.fully = nn.ModuleList()
current_dim = input_dim
for h_dim in hidden_dim:
self.fully.append(nn.Linear(current_dim, h_dim))
current_dim = h_dim
self.fully.append(nn.Linear(current_dim, output_dim))
def forward(self, x):
for layer in self.fully[:-1]:
x = self.drop_layer(F.relu(layer(x)))
x = F.softmax(self.fully[-1](x), dim=0)
return x
我希望在训练之前简单地把它放好:
model = FFNNModel(30,5,[100,200,300,100],0.2)
regularizer = _Regularizer(model)
regularizer = L1Regularizer(regularizer, lambda_reg=0.1)
和
out = model(inputs)
loss = criterion(out, target) + regularizer.__add_l1()
有谁知道如何应用这些“即用型”类?
解决方案
我还没有运行有问题的代码,所以如果某些东西不能正常工作,请回复。通常,我会说您链接的代码不必要地复杂(可能是因为它试图通用并允许以下所有类型的正则化)。我想它的使用方式是
model = FFNNModel(30,5,[100,200,300,100],0.2)
regularizer = L1Regularizer(model, lambda_reg=0.1)
接着
out = model(inputs)
loss = criterion(out, target) + regularizer.regularized_all_param(0.)
您可以检查它regularized_all_param
是否会迭代模型的参数,如果它们的名称以 结尾weight
,它将累积它们的绝对值之和。由于某种原因,需要手动初始化缓冲区,这就是我们传入0.
.
确实,如果您希望有效地规范 L1 并且不需要任何花里胡哨,那么类似于您的第一个链接的更手动的方法将更具可读性。会这样
l1_regularization = 0.
for param in model.parameters():
l1_regularization += param.abs().sum()
loss = criterion(out, target) + l1_regularization
这确实是这两种方法的核心。您使用该Module.parameters
方法迭代所有模型参数并总结它们的 L1 范数,然后它成为您的损失函数中的一个术语。而已。您链接的 repo 提供了一些花哨的机制来将其抽象出来,但是从您的问题来看,它失败了:)
推荐阅读
- javascript - 使用css或js单击缩略图时隐藏字段
- apache-spark - 组合“n”个数据文件以制作单个 Spark Dataframe
- pdf - 将文档转换为 pdf 时,如何防止 Microsoft Word 的 PrintDate 域代码变为纯文本?
- linux - 如何转义 sed 的输入文件
- r - 左对齐数据框中的列
- python - 从查询中获取用户的投票
- javascript - WP CSS & JS 没有入队
- amazon-ec2 - 将来自不同主机的多个IP地址指向具有不同端口的相同域名
- autohotkey - 自动热键映射修改器
- java - 我不明白在这种情况下线程是如何工作的