首页 > 解决方案 > 在 TensorFlow 中,如何沿参差不齐的维度索引 RaggedTensor?

问题描述

我需要通过沿参差不齐的维度进行索引来获取参差不齐的张量中的值。一些索引有效 ( [:, :x],[:, -x:][:, x:y]),但不是直接索引 ( [:, x]):

R = tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]])
print(R[:, :2]) # RaggedTensor([[1, 2], [4, 5]])
print(R[:, 1:2]) # RaggedTensor([[2], [5]])
print(R[:, 1])  # ValueError: Cannot index into an inner ragged dimension.

文档解释了为什么失败:

RaggedTensors 支持多维索引和切片,但有一个限制:不允许索引到参差不齐的维度。这种情况是有问题的,因为指示的值可能存在于某些行中但不存在于其他行中。在这种情况下,我们是否应该(1)引发 IndexError 并不明显;(2) 使用默认值;或 (3) 跳过该值并返回一个比我们开始时行数更少的张量。遵循 Python 的指导原则(“面对歧义,拒绝猜测的诱惑”),我们目前不允许这种操作。

这是有道理的,但我如何实际实现选项 1、2 和 3?我必须将参差不齐的数组转换为张量的 Python 数组,然后手动迭代它们吗?有没有更有效的解决方案?无需通过 Python 解释器就可以在 TensorFlow 图中 100% 工作?

标签: pythontensorflowragged

解决方案


如果您有一个 2D RaggedTensor,那么您可以通过以下方式获得行为 (3):

def get_column_slice_v3(rt, column):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  return slice.flat_values

您可以通过添加 rt.nrows() == tf.size(slice.flat_values) 的断言来获得行为 (1):

def get_column_slice_v1(rt, column):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  with tf.assert_equal(rt.nrows(), tf.size(slice.flat_values):
    return tf.identity(slice.flat_values)

为了获得行为(2),我认为最简单的方法可能是连接一个默认值向量,然后再次切片:

def get_colum_slice_v2(rt, column, default=None):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  if default is None:
    defaults = tf.zeros([slice.nrows(), 1], slice.dtype)
  ele:
    defaults = tf.fill([slice.nrows(), 1], default)
  slice_plus_default = tf.concat([rt, defaults], axis=1)
  slice2 = slice_plus_defaults[:1]
  return slice2.flat_values

可以扩展这些以支持更高维的不规则张量,但逻辑会变得更复杂一些。此外,应该可以扩展这些以支持负列索引。


推荐阅读