首页 > 解决方案 > 如何从 Jax 中的损失函数返回值字典?

问题描述

假设您有一个损失函数,并且您想在训练时跟踪损失的各个子组件。这样做的最“jax”方式是什么?

def loss(params, x, y):
    ...
    loss_1 = ...
    loss_2 = ...
    loss = loss_1 + 0.1 * loss_2
    return dict(loss=loss, loss_1=loss_1, loss_2=loss_2)


@jax.jit
def update(params, tau, y):
    f_value, grads = jax.value_and_grad(loss)(params, tau, y)
    # something like this

您是否想只使用grad一个函数来拉出损失,然后再次重新计算值?有没有办法value_and_grad更有效地做到这一点?

标签: pythonloggingjax

解决方案


感谢@jakevdp本人促使我考虑一些替代的谷歌查询,事实证明,截至https://github.com/google/jax/pull/484,grad函数有一个aux选项。我认为这对于迁移到 jax 的 tensorflow 2 用户来说并不是很明显,因为您明确使用 GradientTape 的方式。

类似以下示例的内容显示了返回的辅助信息。它甚至似乎处理了一个 dict,这对于在更新循环中定期记录很有用。

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
theta = jax.random.normal(key, (10, 1))
y = np.random.randn(10, 1)
alpha = 0.01

def loss(theta, y):
    loss_reg = jnp.sum(theta ** 2)
    loss_data = jnp.sum((y - theta) ** 2)
    loss = loss_data + alpha * loss_reg
    return loss, dict(loss_reg=loss_reg, loss_data=loss_data)

grad, aux = jax.grad(loss, has_aux=True)(theta, y)

display(grad)
display(aux)

try:
    jax.grad(loss)(theta, y)
except TypeError as e:
    print(f'yes got error {e}')

输出:

DeviceArray([[-1.4899637 ],
             [-0.71481365],
             [-0.6030376 ],
             [-0.8263864 ],
             [-1.8103108 ],
             [ 0.69435316],
             [-1.5611547 ],
             [-1.6380725 ],
             [ 0.9838154 ],
             [ 0.21186407]], dtype=float32)
{'loss_data': DeviceArray(3.3714797, dtype=float32),
 'loss_reg': DeviceArray(2.658556, dtype=float32)}
yes got error Gradient only defined for scalar-output functions. Output was (DeviceArray(3.3980653, dtype=float32), {'loss_data': DeviceArray(3.3714797, dtype=float32), 'loss_reg': DeviceArray(2.658556, dtype=float32)}).

推荐阅读