python - jax vmap:强制执行正确的形状
问题描述
我正在使用vmap对部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
使用 vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
但这可能会出错,在矢量化之后:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
这里发生的是samples[0]
具有形状的(2,)
。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个 shape 数组(1,)
。由于使用 广播a
,结果输出再次具有形状(2,)
并堆叠到(2,2)
数组中。
这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。
是否可以强制执行正确的形状?
解决方案
“这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。”
请注意,这vmap
正是它应该在这里做的事情,即对第零维进行矢量化,而 numpy 广播正在做它应该做的事情。当然,问题是用户给出了一个形状错误的数组,因为vmap
期望 x 的第零维中的向量化输入。相反,用户应该写
sum(samples[0:1])
保持适当的形状。
换句话说:如果你打算将 vmap 应用到一个函数上,你就不能像一开始就没有应用 vmap 一样使用这个函数。您需要考虑函数的更改行为。
“是否可以强制执行正确的形状?”
vmap
本身没有能力强制输入的形状。如果您特别担心用户给函数赋予错误的形状,您可以将其构建到原始函数中。例如,
def sum(x):
if (x.shape[-1] != dim):
raise Exception()
a = np.ones((dim,))
return np.dot(x, a)
如果你不给它正确的形状,即使在应用之后也会破裂vmap
。
推荐阅读
- django - 如何将查询参数传递给 django rest api javascript 客户端
- c# - 从日期时间中减去未知间隔
- java - 迭代数组并打印不同的字符串输出
- angular - Ionic 4:创建模拟存储
- augmented-reality - 如何从华为智能手机中删除隐藏的 ARCore 应用程序?
- java - 模拟 Kafka 代理进行 Spring Boot JUnit 测试,无需直接调用
- https - 使用 ABAP 发送 HTTPS POST 请求
- java - 如何使用 Thread.sleep()?
- php - 无法在 ubuntu 18 中安装 Laravel Telescope
- apache-kafka - Spring Cloud Stream Kafka Stream 与原生 Kafka Stream 应用程序和生产者之间的 Avro 消息不兼容