首页 > 解决方案 > PyTorch中点操作导致的层权重

问题描述

我正在尝试实现一个网络,其中层的权重是通过张量运算来计算的。这是我对单层的代码,对所有 conv 和 fc 层重复:

class NOWANet(nn.Module):
    def __init__(self, V, Wstacked):
        super(NOWANet, self).__init__()

        # first conv layer
        # define layer
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)

        # define tensors that will be used to calculate weights and biases
        self.conv1_W_V = nn.Parameter(V[0], requires_grad=True)
        self.conv1_B_V = nn.Parameter(V[1], requires_grad=True)
        self.conv1_W_stack = nn.Parameter(Wstacked[0], requires_grad=False)
        self.conv1_B_stack = nn.Parameter(Wstacked[1], requires_grad=False)
        
        # set layer weights and bias
        self.conv1.weight = nn.Parameter(torch.tensordot(self.conv1_W_V, 
                                                         self.conv1_W_stack, 
                                                         dims=1, out=None),
                                         requires_grad=True)
        self.conv1.bias = nn.Parameter(torch.tensordot(self.conv1_B_V, 
                                                       self.conv1_B_stack, 
                                                       dims=1, out=None),
                                       requires_grad=True)

这个想法是在 V 和 W_stack 张量之间执行一个张量点,然后应该将其用作权重和偏差(每个一个张量点)。问题是我只想优化 V 以便 W_stack 保持不变。当前编写的方式正确地初始化了权重,但是反向传播优化了层的实际权重,V 向量没有变化。

关于如何做到这一点的任何想法或建议?

标签: pythonpytorch

解决方案


推荐阅读