python - 如何在 Keras 中实现这种非全连接结构?
问题描述
我正在尝试在 Keras 中实现一个 3 层非全连接网络,用于对具有相同输出维度的 2 个子模型的输出权重进行分类。这应该使用一个二进制掩码矩阵来完成,该矩阵由 0 和 1 组成,每个神经元应该连接到隐藏层中的神经元(因此每个神经元有两个 1)。我附上了一张图表以及我正在使用的 PyTorch 代码。
class Weight_classifier(nn.Module):
def __init__(self, func):
super(Weight_classifier, self).__init__()
# self.weight_layer = nn.Linear(OUT_nodes[func]*3, OUT_nodes[func])
self.weight_layer = MaskedLinear(OUT_nodes[func]*2, OUT_nodes[func], 'data/{}_maskmatrix.csv'.format(func), func).cuda()
self.outlayer= nn.Linear(OUT_nodes[func], OUT_nodes[func])
def forward(self, weight_features):
weight_out = self.weight_layer(weight_features)
# weight_out = F.sigmoid(weight_out)
weight_out = F.relu(weight_out)
weight_out = F.sigmoid(self.outlayer(weight_out))
return weight_out
class MaskedLinear(nn.Linear):
def __init__(self, in_features, out_features, relation_file, func, bias=True):
super(MaskedLinear, self).__init__(in_features, out_features, bias)
mask = self.readRelationFromFile(relation_file, func)
self.register_buffer('mask', mask)
self.iter = 0
def forward(self, input):
masked_weight = self.weight * self.mask
return F.linear(input, masked_weight, self.bias)
def readRelationFromFile(self, relation_file, func):
mask = []
with open(relation_file, 'r') as f:
for line in f:
l = [int(x) for x in line.strip().split(',')[OUT_nodes[func]:]]
assert len(l) == OUT_nodes[func]*2
for item in l:
assert item == 1 or item == 0
mask.append(l)
return Variable(torch.Tensor(mask))
谢谢