python - Tensorflow:点积中的“轴”参数
问题描述
有人可以告诉我应该如何使用这个axis
论点tf.tensordot
吗?
我阅读了文档,但它很复杂,我仍然感到困惑。我看到另一个问题,在答案axis
中tf.one_hot
和答案中提出了一些关于此事的很好的见解,但这对我没有帮助tf.tensordot
。我想你也可以给我一些见解。
例如,我知道我可以像这样对向量和张量进行点积:
my_vector = tf.random.uniform(shape=[n])
my_tensor = tf.random.uniform(shape=[m, n])
dp = tf.tensordot(my_tensor, my_vector, 1)
但是当我对它们进行批处理并向它们添加一个维度以使其具有形状(b, n)
并(b, m, n)
获得 a(b, m, 1)
时,现在我不知道如何在每批产品中添加一个点。
解决方案
您想要执行的操作无法(以有效方式)使用tf.tensordot
. 但是,该操作有一个专用函数 ,tf.linalg.matvec
它将与开箱即用的批次一起使用。你也可以用tf.einsum
, like做同样的事情tf.einsum('bmn,bn->bm', my_tensors, my_vectors)
。
关于tf.tensordot
,通常它计算两个给定张量的“全部与全部”乘积,但匹配和减少一些轴。当没有给定轴时(您必须显式传递axes=[[], []]
才能执行此操作),它会创建一个张量,其中两个输入的维度连接在一起。所以,如果你有my_tensors
形状(b, m, n)
和my_vectors
形状(b, n)
,你这样做:
res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])
你得到res
与形状(b, m, n, b, n)
,这样res[p, q, r, s, t] == my_tensors[p, q, r] * my_vectors[s, t]
。
该axes
参数用于指定输入张量中“匹配”的维度。沿匹配轴的值相乘和相加(如点积),因此这些匹配的维度会从输出中减少。axes
可以采取两种不同的形式:
- 如果它是单个整数,
N
则第一个参数的最后一个N
维度与 的第一个N
维度匹配b
。在您的示例中,这对应于具有 和元素n
的维度。my_tensor
my_vector
- 如果它是一个列表,它必须包含两个子列表
axes_a
和axes_b
,每个子列表具有相同数量N
的整数。在这种形式中,您明确指出给定值的哪些维度是匹配的。因此,在您的示例中,您可以传递axes=[[1], [0]]
,这意味着“1
将第一个参数 (my_tensor
) 的维度0
与第二个参数 (my_vector
) 的维度相匹配”。
如果您现在my_tensors
有 with shape(b, m, n)
和my_vectors
with shape (b, n)
,那么您可能希望将2
第一个的尺寸1
与第二个的尺寸相匹配,这样您就可以通过axes=[[2], [1]]
. 但是,这将为您提供一个res
形状(b, m, b)
为matrix和 vectorres[i, :, j]
乘积的结果。然后,您可以只获取您想要的结果(那些 where ),或多或少地像my_tensors[i]
my_vectors[j]
i == j
tf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2])))
推荐阅读
- c++ - C ++字符串文字相等性检查?
- ruby-on-rails - as_json 和动态方法的使用
- php - Foreach 到 PHP 中的变量
- swift - 预期返回“字符串”的函数中缺少返回
- javascript - 如何在继续之前等待 $.getScript() 中的函数完成?
- javascript - 在名称属性中使用索引并使用 vee 验证自定义错误消息?
- reactjs - TypeError 无法读取未定义的属性“baseTheme”
- google-maps - Flutter 的 flutter_google_map_view 小部件中的集群标记
- python - Fill empty enclosed parts from image
- database - 如何同步两个 docker 卷?