python - 如何从 Jax 中的损失函数返回值字典?
问题描述
假设您有一个损失函数,并且您想在训练时跟踪损失的各个子组件。这样做的最“jax”方式是什么?
def loss(params, x, y):
...
loss_1 = ...
loss_2 = ...
loss = loss_1 + 0.1 * loss_2
return dict(loss=loss, loss_1=loss_1, loss_2=loss_2)
@jax.jit
def update(params, tau, y):
f_value, grads = jax.value_and_grad(loss)(params, tau, y)
# something like this
您是否想只使用grad
一个函数来拉出损失,然后再次重新计算值?有没有办法value_and_grad
更有效地做到这一点?
解决方案
感谢@jakevdp本人促使我考虑一些替代的谷歌查询,事实证明,截至https://github.com/google/jax/pull/484,grad函数有一个aux选项。我认为这对于迁移到 jax 的 tensorflow 2 用户来说并不是很明显,因为您明确使用 GradientTape 的方式。
类似以下示例的内容显示了返回的辅助信息。它甚至似乎处理了一个 dict,这对于在更新循环中定期记录很有用。
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
theta = jax.random.normal(key, (10, 1))
y = np.random.randn(10, 1)
alpha = 0.01
def loss(theta, y):
loss_reg = jnp.sum(theta ** 2)
loss_data = jnp.sum((y - theta) ** 2)
loss = loss_data + alpha * loss_reg
return loss, dict(loss_reg=loss_reg, loss_data=loss_data)
grad, aux = jax.grad(loss, has_aux=True)(theta, y)
display(grad)
display(aux)
try:
jax.grad(loss)(theta, y)
except TypeError as e:
print(f'yes got error {e}')
输出:
DeviceArray([[-1.4899637 ],
[-0.71481365],
[-0.6030376 ],
[-0.8263864 ],
[-1.8103108 ],
[ 0.69435316],
[-1.5611547 ],
[-1.6380725 ],
[ 0.9838154 ],
[ 0.21186407]], dtype=float32)
{'loss_data': DeviceArray(3.3714797, dtype=float32),
'loss_reg': DeviceArray(2.658556, dtype=float32)}
yes got error Gradient only defined for scalar-output functions. Output was (DeviceArray(3.3980653, dtype=float32), {'loss_data': DeviceArray(3.3714797, dtype=float32), 'loss_reg': DeviceArray(2.658556, dtype=float32)}).
推荐阅读
- azure - 如何在 Azure 机器学习中使用历史数据集进行训练和预期数据集作为预测的输入
- swift - 标签不调整字体大小以适应宽度
- python - 配置文件中带有空格的python文件路径
- android - 我可以在真实设备中切换到 webview 但无法在模拟器上切换到 webview。使用的 Appium 版本是 1.7.2
- angular - 在身份验证中创建用户后无法将数据插入数据库
- javascript - 包含一个文件夹和其中的文件以使用电子构建器进行电子构建?
- sql - 从给定集合(> 80 个元素)中找到所有数字组合以达到给定最终总和的最佳性能方法
- javascript - Cloud Functions 返回未定义的、预期的 Promise 或值
- regex - 使用正则表达式或美丽的汤从 instagram 中获取某人的网站
- charts - Creating multi line graphs using Google Charts API