python - TensorFlow tensor_scatter_nd_add 索引
问题描述
我想要做的是将一些列插入 tf 张量z
z = tf.zeros((100, 50))
indices = tf.constant([[0], [2]])
updates = tf.ones((50, 2))
tf.tensor_scatter_nd_add(z, indices, updates)
是我想将两列一列插入列0
和2
轴 = 1(列)到z
. 但我得到了错误
InvalidArgumentError:索引和更新的外部尺寸必须匹配。索引形状:[2,1],更新形状:[50,2] [Op:TensorScatterAdd]
如何控制索引参数以防止这种情况发生?