首页 > 解决方案 > TensorFlow 2.0 分散添加

问题描述

我想在 TensorFlow 2.0 中实现以下设计。

给定一个memory形状张量[a, b, c]
一个indices形状张量 和一个形状张 量[a, 1]
updates[a, c]

我想在中的值memory指示的位置递增。indicesupdates

tf.tensor_scatter_nd_add似乎不起作用:

tf.tensor_scatter_nd_add(memory, indices, updates)返回{InvalidArgumentError}Inner dimensions of output shape must match inner dimensions of updates shape. Output: [a,b,c] updates: [a,c] [Op:TensorScatterAdd]

真的有必要updates拥有与 一样多的内部维度memory吗?在我的逻辑中,memory[indices](作为伪代码)应该已经是一个张量 shape [a, c]。而且,形状tf.gather_nd(params=memory, indices=indices, batch_dims=1)已经是[a, c]

你能推荐一个替代品吗?

谢谢。

标签: tensorflowtensorflow2.0

解决方案


我想你想要的是这样的:

import tensorflow as tf

a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])
updates = 10 * tf.reshape(tf.range(a * c, dtype=memory.dtype), [a, c])
print(updates.numpy())
# [[  0.  10.  20.  30.  40.]
#  [ 50.  60.  70.  80.  90.]
#  [100. 110. 120. 130. 140.]]

# Make indices for first dimension
ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
# Make full indices
indices_2 = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
# Scatter add
out = tf.tensor_scatter_nd_add(memory, indices_2, updates)
print(out.numpy())
# [[[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.  11.  21.  31.  41.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[ 51.  61.  71.  81.  91.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [101. 111. 121. 131. 141.]]]

推荐阅读