首页 > 解决方案 > Tensorflow:点积中的“轴”参数

问题描述

有人可以告诉我应该如何使用这个axis论点tf.tensordot吗?

我阅读了文档,但它很复杂,我仍然感到困惑。我看到另一个问题,在答案axistf.one_hot和答案中提出了一些关于此事的很好的见解,但这对我没有帮助tf.tensordot。我想你也可以给我一些见解。

例如,我知道我可以像这样对向量和张量进行点积:

my_vector = tf.random.uniform(shape=[n])
my_tensor = tf.random.uniform(shape=[m, n])

dp = tf.tensordot(my_tensor, my_vector, 1)

但是当我对它们进行批处理并向它们添加一个维度以使其具有形状(b, n)(b, m, n)获得 a(b, m, 1)时,现在我不知道如何在每批产品中添加一个点。

标签: pythontensorflowtensorflow2.0axisdot-product

解决方案


您想要执行的操作无法(以有效方式)使用tf.tensordot. 但是,该操作有一个专用函数 ,tf.linalg.matvec它将与开箱即用的批次一起使用。你也可以用tf.einsum, like做同样的事情tf.einsum('bmn,bn->bm', my_tensors, my_vectors)

关于tf.tensordot,通常它计算两个给定张量的“全部与全部”乘积,但匹配和减少一些轴。当没有给定轴时(您必须显式传递axes=[[], []]才能执行此操作),它会创建一个张量,其中两个输入的维度连接在一起。所以,如果你有my_tensors形状(b, m, n)my_vectors形状(b, n),你这样做:

res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])

你得到res与形状(b, m, n, b, n),这样res[p, q, r, s, t] == my_tensors[p, q, r] * my_vectors[s, t]

axes参数用于指定输入张量中“匹配”的维度。沿匹配轴的值相乘和相加(如点积),因此这些匹配的维度会从输出中减少。axes可以采取两种不同的形式:

  • 如果它是单个整数,N则第一个参数的最后一个N维度与 的第一个N维度匹配b。在您的示例中,这对应于具有 和元素n的维度。my_tensormy_vector
  • 如果它是一个列表,它必须包含两个子列表axes_aaxes_b,每个子列表具有相同数量N的整数。在这种形式中,您明确指出给定值的哪些维度是匹配的。因此,在您的示例中,您可以传递axes=[[1], [0]],这意味着“1将第一个参数 ( my_tensor) 的维度0与第二个参数 ( my_vector) 的维度相匹配”。

如果您现在my_tensors 有 with shape(b, m, n)my_vectorswith shape (b, n),那么您可能希望将2第一个的尺寸1与第二个的尺寸相匹配,这样您就可以通过axes=[[2], [1]]. 但是,这将为您提供一个res形状(b, m, b)为matrix和 vectorres[i, :, j]乘积的结果。然后,您可以只获取您想要的结果(那些 where ),或多或少地像my_tensors[i]my_vectors[j]i == jtf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2])))


推荐阅读