首页 > 解决方案 > 难以理解张量的广播行为

问题描述

我正在尝试对两个维度 (1,5,64) 和 (1,5) 的张量进行逐元素乘法。据我所知,尽管它们的尺寸不匹配,但广播应该允许这样做。所以,我使用这段代码:

x = tf.range(0,64*5)
x = tf.reshape(x, [1,5, 64])

y = tf.range(0,5)
y = tf.reshape(y, [1, 5])

prodct = x*y

这会导致此错误:

InvalidArgumentError: Incompatible shapes: [1,5,64] vs. [1,5] [Op:Mul]

但是,如果我将第一个张量重塑为尺寸(1,64,5),那么它可以工作。代码:

x = tf.range(0,64*5)
x = tf.reshape(x, [1,64, 5])

y = tf.range(0,5)
y = tf.reshape(y, [1, 5])

prodct = x*y

我不明白为什么第一个代码不起作用。

标签: numpytensorflowmultiplicationtensorarray-broadcasting

解决方案


一般广播规则,当对两个数组进行操作时,按元素比较它们的形状。它从尾随(即最右边的)尺寸开始,然后向左工作。两个维度兼容时

  • 他们是平等的,或者
  • 其中之一是 1

如果不满足这些条件,ValueError: operands could not be broadcast together则会引发异常,指示数组具有不兼容的形状。结果数组的大小是沿输入的每个轴不为 1 的大小。

也遵循同样的精神。查看文档以获取更多示例和详细信息。对于您的情况,最右边的维度不遵循规则并引发错误。

1, 5, 64
   1, 5

但这会起作用,因为它遵守规则。

1, 64, 5
   1,  5

代码

中供参考。

import numpy as np 
a = np.arange(64*5).reshape(1, 64, 5)
b = np.arange(5).reshape(1,5)
(a*b).shape
(1, 64, 5)

import tensorflow as tf 
x = tf.reshape(tf.range(0,64*5), [1, 64, 5])
y = tf.reshape(tf.range(0,5), [1, 5])
(x*y).shape
TensorShape([1, 64, 5])

推荐阅读