python-3.x - Module.parameters() 如何找到参数?
问题描述
我注意到,每当您创建一个新的网络扩展torch.nn.Module
时,您都可以立即调用net.parameters()
以查找与反向传播相关的参数。
import torch
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
net = MyNet()
print(list(net.parameters()))
但后来我想知道,这怎么可能?我只是将此Linear
图层对象分配给一个成员变量,但它没有记录在其他任何地方(或者是吗?)。必须能够以某种MyNet
方式跟踪使用的参数,但如何?
解决方案
真的很简单,只需通过元编程检查属性并检查它们的类型
class Example():
def __init__(self):
self.special_thing = nn.Parameter(torch.rand(2))
self.something_else = "ok"
def get_parameters(self):
for key, value in self.__dict__.items():
if type(value) == nn.Parameter:
print(key, "is a parameter!")
e = Example()
e.get_parameters()
# => special_thing is a parameter!
推荐阅读
- symfony - Twig/Symfony "markdown_to_html" 过滤器未找到 **debug:translation**
- reactjs - 在 Material-ui Autocomplete 组件上设置文本颜色、轮廓和填充
- python - Django `get_or_build` 类似于 `get_or_create`?
- javascript - shinyBS - 拖动时删除工具提示
- python - 如何以编程方式将可调用对象传递给 gunicorn 而不是参数
- php - 来自一个数据库的php sql查询插入另一个数据库
- python - 收到代理服务器错误时如何防止脚本被取消
- hive - 在来自 HIVE 的 JDBCTemplate 结果上使用 Java8 Stream
- mongodb - 无法将数组添加到 mongodb
- mongodb - 如何在解组 MongoDB 文档时忽略空值?