pytorch - 定义一个自定义变量 pytorch 张量
问题描述
我有一个类似的问题:
有一个形式的输入向量(n1, n2)
,我想要一个模型f(n1, n2) = f(n2, n1)
。我希望这个模型是线性的。例如2x2
我想学习的矩阵。该2x2
矩阵有 4 个权重,训练将尝试学习这些权重。然而,收缩可以巧妙地减少问题。
f(nvec) = W*nvec
,存在nvec=(n1, n2)
。因此,设X = [[0, 1],[1, 0]]
是一个2x2
翻转向量的矩阵,则f(n2,n1) = W*X*(n1, n2) = f(n1, n2) = W (n1, n2)
。
基本上,等式意味着W*X = W
,或者如果W = [[w1, w2], [w3, w4]]
,则等式意味着W = [[w1, w1], [w3, w3]]
,即矩阵只有 2 个自由参数。如何在 pytorch 中定义这样的模型?
谢谢。
解决方案
如果您的参数有特殊的结构,并且您不想使用默认层存储的默认参数,您可以自己定义一个层并用于nn.Parameter
存储您的参数。
例如,
class CustomLayer(nn.Module): # layer must be derived from nn.Module class
def __init__(self): # you can have arguments here...
super(CustomLayer, self).__init__()
self.params = nn.Parameter(data=torch.rand((2, 1), dtype=torch.float), requires_grad=True)
def forward(self, x):
w = self.params.repeat(1, 2) # create a 2x2 matrix from the two parameters
y = torch.bmm(x, w)
return y
推荐阅读
- xslt - 如何编号
带有 XSLT 的标签 - selenium - 将 OWASP ZAP 代理用于现有的 Selenium 测试套件
- netsuite - 提交后如何检查订单项是否有任何变化?
- java - Spring Boot中没有构造函数的依赖注入
- android - GnssNavigationMessage.STATUS_UNKNOWN?
- c# - 如何读取 CommandLineInterface-App 的返回值返回?
- spring - Thymeleaf 设置默认值
- node.js - 猫鼬,查找和更新符合条件的数组中的项目(仅应更改数组中的数学项目)
- javascript - 在WebView上启用javascript不起作用android studio
- solidity - 在 BSCScan.com 上验证合同