首页 > 解决方案 > 如何重置torch Sequential中的图层参数?

问题描述

在以下代码中(其原始形式来自此处):

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels)
                       )

    def forward(self, x, edge_index):

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)

我想重置里面每一层的参数Seq。要重置参数,此答案建议:

for layer in model.children():
   if hasattr(layer, 'reset_parameters'):
       layer.reset_parameters()

但是我如何Sequential在以下代码段中使用 ie 来做到这一点:

self.mlp = Seq(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels)
                       )

从第一个代码?

标签: pythonpython-3.xpytorch

解决方案


推荐阅读