首页 > 解决方案 > 在tensorflow中交换张量的元素

问题描述

我在尝试交换具有可变长度的张量元素时遇到了令人惊讶的困难。据我了解,切片赋值仅支持变量,因此在运行以下代码时,我得到错误ValueError: Sliced assignment is only supported for variables

def add_noise(tensor):
  length = tf.size(tensor)

  i = tf.random_uniform((), 0, length-2, dtype=tf.int32)
  aux = tensor[i]
  tensor = tensor[i].assign(tensor[i+1])
  tensor = tensor[i+1].assign(aux)

  return tensor

with tf.Session() as sess:
  tensor = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32)
  print sess.run(add_noise(tensor))

如何交换张量中的元素?

标签: pythontensorflow

解决方案


您可以使用 TensorFlow分散函数scatter_nd来交换tensor元素。您还可以在一次scatter操作中实现多次交换。

tensor = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32)  # input
# let's swap 1st and 4th elements, and also 5th and 6th elements (in terms of 0-based indexing)
indices = tf.constant([[0], [4], [2], [3], [1], [5], [6]])  # indices mentioning the swapping pattern
shape = tf.shape(tensor)  # shape of the scattered_tensor, zeros will be injected if output shape is greater than input shape
scattered_tensor = tf.scatter_nd(indices, tensor, shape)

with tf.Session() as sess:
  print sess.run(scattered_tensor)
  # [0 4 2 3 1 6 5]

推荐阅读