首页 > 解决方案 > PyTorch 中的偏导数

问题描述

假设f(x)是一个权重为w的神经网络。如果我做

loss = -f(f(x))  
loss.backward()
self.optim.step()

那么这与朝着f'(f(x))df(x)/dw + df/dw(f(x))的方向前进是一样的。但是,如果我想忽略第二部分并且只朝f'(f(x))df(x)/dw的方向移动,那么在 PyTorch 中实现它的简单方法是什么?

标签: pytorchderivative

解决方案


据我了解,一种可能的解决方案就是拥有一个目标网络:

f_target.load_state_dict(f.state_dict())
f_target.eval()
loss = -f_target(f(x))
loss.backward()
self.optim.step()

推荐阅读