首页 > 解决方案 > 如何获取pytorch网络的输入和最终输出

问题描述

是否有关于如何获取 pytorch 网络的输入和输出节点的 API?我尝试了 model.features(),但这无济于事。


示例:我得到一个 pytorch 网络,它在 netron 中的结构: network

Conv2d、MaxPool2d 和 Linear 可以轻松解析。我在获取输入节点和输出节点的名称和大小等信息时遇到了麻烦。

标签: neural-networkpytorch

解决方案


获取特定层的输入意味着前一层的输出。

因此,为了在前向传递和后向传递过程中获取任何信息,例如特定层的输出、梯度,或者如果您想修改其中任何一个,pytorch 中有一个称为Hooks的概念。

在此处查找更多信息https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks


推荐阅读