首页 > 解决方案 > How can scatter_std influence the value of the weights in a GNN programmed with pytorch?

问题描述

I'm trying to program a GNN using pytorch. I currently use scatter_add and scatter_mean to summerize the edge input for the node update, and then everything works fine. However if I also want to use scatter_std, things go wrong: when the model updates it's parameters after the first batch, all weights are set to 'nan'. I have no idea how scatter_std() can influence this, but apparently it does. Has anyone seen this before?

This is my node update:

class NodeModel_1(torch.nn.Module):
def __init__(self):
    super(NodeModel_1, self).__init__()
    self.node_mlp = nn.ModuleList()
    
    for l in range(len(NN_node_layers_1)-1):
        # Add linear layers to the neural networks
        self.node_mlp.append(nn.Linear(NN_node_layers_1[l], NN_node_layers_1[l+1]))
    
def forward(self, node_attr, edge_index, edge_attr, u = None, batch = None):
    row, col = edge_index
    
    out = torch.cat([scatter_mean(edge_attr, col, dim=0, dim_size=node_attr.size(0)),scatter_add(edge_attr, row, dim=0, dim_size=node_attr.size(0)),scatter_std(edge_attr, row, dim=0, dim_size=node_attr.size(0), unbiased = False)], dim = 1)
    out = torch.cat([out, node_attr],dim = 1)

    for layer in self.node_mlp[:-1]:
        out =  F.elu(layer(out))
    out = self.node_mlp[-1](out)
    return out

And this is my network function:

    def forward(self, data):
    print(self.encoder_edge[0].weight)
    node_attr, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
    #First put data through the encoders
    
    for layer in self.encoder_node[:-1]:
        node_attr =  F.elu(layer(node_attr))
    node_attr = self.encoder_node[-1](node_attr)
    for layer in self.encoder_edge[:-1]:
        edge_attr =  F.elu(layer(edge_attr))
    edge_attr = self.encoder_edge[-1](edge_attr)
    
    print(self.encoder_node[0].weight)
    # Save initial input to feed to the  neural network 
    node_attr_zero, edge_attr_zero = 1*node_attr, 1*edge_attr

    # Update network for the first time
    node_attr , edge_attr, u = self.network_1(node_attr, edge_index, edge_attr)
    
    edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
    node_attr = torch.cat([node_attr,node_attr_zero],dim =1)

    node_attr , edge_attr, u = self.network_2(node_attr, edge_index, edge_attr)
    
    edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
    node_attr = torch.cat([node_attr,node_attr_zero],dim =1)
    node_attr , edge_attr, u = self.network_3(node_attr, edge_index, edge_attr)
    
    # Decode node information
    for layer in self.decoder_node[:-1]:
        node_attr =  F.elu(layer(node_attr))
    node_attr = self.decoder_node[-1](node_attr)
    print(self.encoder_edge[0].weight)
    return node_attr

标签: pythonpytorch

解决方案


推荐阅读