python - How can scatter_std influence the value of the weights in a GNN programmed with pytorch?
问题描述
I'm trying to program a GNN using pytorch. I currently use scatter_add and scatter_mean to summerize the edge input for the node update, and then everything works fine. However if I also want to use scatter_std, things go wrong: when the model updates it's parameters after the first batch, all weights are set to 'nan'. I have no idea how scatter_std() can influence this, but apparently it does. Has anyone seen this before?
This is my node update:
class NodeModel_1(torch.nn.Module):
def __init__(self):
super(NodeModel_1, self).__init__()
self.node_mlp = nn.ModuleList()
for l in range(len(NN_node_layers_1)-1):
# Add linear layers to the neural networks
self.node_mlp.append(nn.Linear(NN_node_layers_1[l], NN_node_layers_1[l+1]))
def forward(self, node_attr, edge_index, edge_attr, u = None, batch = None):
row, col = edge_index
out = torch.cat([scatter_mean(edge_attr, col, dim=0, dim_size=node_attr.size(0)),scatter_add(edge_attr, row, dim=0, dim_size=node_attr.size(0)),scatter_std(edge_attr, row, dim=0, dim_size=node_attr.size(0), unbiased = False)], dim = 1)
out = torch.cat([out, node_attr],dim = 1)
for layer in self.node_mlp[:-1]:
out = F.elu(layer(out))
out = self.node_mlp[-1](out)
return out
And this is my network function:
def forward(self, data):
print(self.encoder_edge[0].weight)
node_attr, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
#First put data through the encoders
for layer in self.encoder_node[:-1]:
node_attr = F.elu(layer(node_attr))
node_attr = self.encoder_node[-1](node_attr)
for layer in self.encoder_edge[:-1]:
edge_attr = F.elu(layer(edge_attr))
edge_attr = self.encoder_edge[-1](edge_attr)
print(self.encoder_node[0].weight)
# Save initial input to feed to the neural network
node_attr_zero, edge_attr_zero = 1*node_attr, 1*edge_attr
# Update network for the first time
node_attr , edge_attr, u = self.network_1(node_attr, edge_index, edge_attr)
edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
node_attr = torch.cat([node_attr,node_attr_zero],dim =1)
node_attr , edge_attr, u = self.network_2(node_attr, edge_index, edge_attr)
edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
node_attr = torch.cat([node_attr,node_attr_zero],dim =1)
node_attr , edge_attr, u = self.network_3(node_attr, edge_index, edge_attr)
# Decode node information
for layer in self.decoder_node[:-1]:
node_attr = F.elu(layer(node_attr))
node_attr = self.decoder_node[-1](node_attr)
print(self.encoder_edge[0].weight)
return node_attr