首页 > 解决方案 > jax vmap:强制执行正确的形状

问题描述

我正在使用vmap对部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:

dim = 2
def sum(x):
    a = np.ones((dim,))
    return np.dot(x, a)

num_samples = 100
samples = np.ones((num_samples, dim))

sum(samples[0]) # 2

使用 vmap:

sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2

但这可能会出错,在矢量化之后:

sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1

这里发生的是samples[0]具有形状的(2,)。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个 shape 数组(1,)。由于使用 广播a,结果输出再次具有形状(2,)并堆叠到(2,2)数组中。

这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。

是否可以强制执行正确的形状?

标签: pythonvectorizationjax

解决方案


“这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。”

请注意,这vmap正是它应该在这里做的事情,即对第零维进行矢量化,而 numpy 广播正在做它应该做的事情。当然,问题是用户给出了一个形状错误的数组,因为vmap期望 x 的第零维中的向量化输入。相反,用户应该写

sum(samples[0:1])

保持适当的形状。

换句话说:如果你打算将 vmap 应用到一个函数上,你就不能像一开始就没有应用 vmap 一样使用这个函数。您需要考虑函数的更改行为。

“是否可以强制执行正确的形状?”

vmap本身没有能力强制输入的形状。如果您特别担心用户给函数赋予错误的形状,您可以将其构建到原始函数中。例如,

def sum(x): 
    if (x.shape[-1] != dim): 
        raise Exception() 
    a = np.ones((dim,)) 
    return np.dot(x, a)

如果你不给它正确的形状,即使在应用之后也会破裂vmap


推荐阅读