python - PyTorch中点操作导致的层权重
问题描述
我正在尝试实现一个网络,其中层的权重是通过张量运算来计算的。这是我对单层的代码,对所有 conv 和 fc 层重复:
class NOWANet(nn.Module):
def __init__(self, V, Wstacked):
super(NOWANet, self).__init__()
# first conv layer
# define layer
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
# define tensors that will be used to calculate weights and biases
self.conv1_W_V = nn.Parameter(V[0], requires_grad=True)
self.conv1_B_V = nn.Parameter(V[1], requires_grad=True)
self.conv1_W_stack = nn.Parameter(Wstacked[0], requires_grad=False)
self.conv1_B_stack = nn.Parameter(Wstacked[1], requires_grad=False)
# set layer weights and bias
self.conv1.weight = nn.Parameter(torch.tensordot(self.conv1_W_V,
self.conv1_W_stack,
dims=1, out=None),
requires_grad=True)
self.conv1.bias = nn.Parameter(torch.tensordot(self.conv1_B_V,
self.conv1_B_stack,
dims=1, out=None),
requires_grad=True)
这个想法是在 V 和 W_stack 张量之间执行一个张量点,然后应该将其用作权重和偏差(每个一个张量点)。问题是我只想优化 V 以便 W_stack 保持不变。当前编写的方式正确地初始化了权重,但是反向传播优化了层的实际权重,V 向量没有变化。
关于如何做到这一点的任何想法或建议?