首页 > 解决方案 > pytorch 会在网络中使用 python 计算代码正确执行吗?

问题描述

以下面的假代码为例:

class():
    def forward(input):
        x = some_torch_layers(input)
        x = some_torch_layers(x)
        ...
        x = sum(x) # or numpy or other operations
        x = some_torch_layers(x)
        return x

pytorch 网络会运行良好吗?特别是,虽然sum(x)在后向过程中表现良好。

标签: machine-learningneural-networkdeep-learningcomputer-visionpytorch

解决方案


TL;DR

为了让 PyTorch “表现良好”,它需要通过网络传播梯度。PyTorch 不(也不能)知道如何区分任意 numpy 代码,它只能通过 PyTorch 张量操作传播梯度。
在您的示例中,梯度将在 numpy 处停止,sum因此只会训练最顶层的火炬层(numpy 操作和 之间criterion的层),其他层(输入和 numpy 操作之间的层)将具有零梯度,因此它们的参数将保持不变在整个培训过程中固定。


推荐阅读