首页 > 解决方案 > 如何在 Keras 中实现这种非全连接结构?

问题描述

我正在尝试在 Keras 中实现一个 3 层非全连接网络,用于对具有相同输出维度的 2 个子模型的输出权重进行分类。这应该使用一个二进制掩码矩阵来完成,该矩阵由 0 和 1 组成,每个神经元应该连接到隐藏层中的神经元(因此每个神经元有两个 1)。我附上了一张图表以及我正在使用的 PyTorch 代码。

在此处输入图像描述

class Weight_classifier(nn.Module):
def __init__(self, func):
    super(Weight_classifier, self).__init__()
    # self.weight_layer = nn.Linear(OUT_nodes[func]*3, OUT_nodes[func])
    self.weight_layer = MaskedLinear(OUT_nodes[func]*2, OUT_nodes[func], 'data/{}_maskmatrix.csv'.format(func), func).cuda()
    self.outlayer= nn.Linear(OUT_nodes[func], OUT_nodes[func])

def forward(self, weight_features):
    weight_out = self.weight_layer(weight_features)
    # weight_out = F.sigmoid(weight_out)
    weight_out = F.relu(weight_out)
    weight_out = F.sigmoid(self.outlayer(weight_out))
    return weight_out


class MaskedLinear(nn.Linear):
def __init__(self, in_features, out_features, relation_file, func, bias=True):
    super(MaskedLinear, self).__init__(in_features, out_features, bias)

    mask = self.readRelationFromFile(relation_file, func)
    self.register_buffer('mask', mask)
    self.iter = 0

def forward(self, input):
    masked_weight = self.weight * self.mask
    return F.linear(input, masked_weight, self.bias)

def readRelationFromFile(self, relation_file, func):
    mask = []
    with open(relation_file, 'r') as f:
        for line in f:
            l = [int(x) for x in line.strip().split(',')[OUT_nodes[func]:]]
            assert len(l) == OUT_nodes[func]*2
            for item in l:
                assert item == 1 or item == 0 
            mask.append(l)
    return Variable(torch.Tensor(mask))

谢谢

标签: pythontensorflowmachine-learningkeraspytorch

解决方案


推荐阅读