首页 > 解决方案 > 用 jax 计算逐行(或逐轴)点积的最佳方法是什么?

问题描述

我有两个形状(N,M)的数值数组。我想计算一个逐行的点积。即产生一个形状为 (N,) 的数组,使得第 n 行是每个数组的第 n 行的点积。

我知道numpy的inner1d方法。使用 jax 执行此操作的最佳方法是什么?jax 有jax.numpy.inner,但这有别的作用。

标签: numpylinear-algebrajax

解决方案


你可以试试jax.numpy.einsum。这里使用 numpy einsum 的实现

import numpy as np
from numpy.core.umath_tests import inner1d

arr1 = np.random.randint(0,10,[5,5])
arr2 = np.random.randint(0,10,[5,5])

arr = np.inner1d(arr1,arr2)
arr
array([ 87, 200, 229,  81,  53])
np.einsum('...i,...i->...',arr1,arr2)
array([ 87, 200, 229,  81,  53])

推荐阅读