首页 > 解决方案 > **kwargs 属性的含义放在机器学习模型的类中

问题描述

我想知道**kwargs我发现通常添加在某些机器学习模型类的构造函数中的属性的含义。例如考虑 PyTorch 中的神经网络:

class Model(nn.Module):

def __init__(self, input_dim, hidden_dim, output_dim, **kwargs)

是否**kwargs与稍后定义的额外参数相关联?

标签: machine-learningpytorch

解决方案


这并不特定于机器学习模型类,而是 Python 的一个特性。

你确实是对的,它对应于额外的关键字参数。它将本质上收集未在函数头中定义的剩余传递的命名参数,并将它们添加到字典变量kwargs中。这个变量实际上可以重命名为任何名称,习惯上'args'为可迭代的未命名参数(*args)和'kwargs'关键字参数(**kwargs)保留。

这增加了允许定义附加参数并将其传递给函数的灵活性,而无需在标头中明确说明它们的名称。一个常见的用例是扩展类时。这里我们实现了一个名为 的虚拟 3x3 2D 卷积层Conv3x3,它将扩展基本nn.Conv2d模块:

class Conv3x3(nn.Conv2d):
    def __init__(self, **kwargs):
       super().__init__(kernel_size=3, **kwargs)

如您所见,我们不需要命名所有参数,并且我们仍然保持与类初始化程序中相同的nn.Conv2d接口Conv3x3

>>> Conv3x3(in_channels=3, out_channels=16)
Conv3x3(3, 16, kernel_size=(3, 3), stride=(1, 1))

你可以用这两个结构做很多好事。您可以在这里找到其中的大部分内容。


推荐阅读