首页 > 解决方案 > 张量切片会丢失 TensorFlow 中的形状信息

问题描述

我正在尝试动态切片张量以自动调整其形状以进行下一次迭代。然而,我意识到,当在图形模式下切片时,张量的形状信息会丢失,因此我无法进一步对其应用某些操作,这需要知道给定张量的形状。下面我附上了一个示例代码,在我的具体示例中,该opt_with_slicing函数vectorized_map在一个更大的函数中定义,该函数负责自动区分。由于原始函数太大而无法包含在此处,因此我对其进行了相应的简化;

a = tf.constant(np.linspace(0.,10.,11,endpoint=True)[::-1])
b = tf.ones((2,10))

def opt_with_slicing(x, some_cutoff: float):
    a, b = x

    new_size = tf.math.count_nonzero(
        tf.cast(a >= some_cutoff, dtype=tf.int32), dtype=tf.int32
    )
    
    tf.print(f"new size {new_size}, initial size {a.get_shape()}")
    
    test1 = b[:, :new_size]
    test2 = tf.slice(b, [0, 0], [b.get_shape()[0], new_size])
    
    tf.print(f"test1 shape {test1.get_shape()}, test2 shape {test2.get_shape()}")
    return test1, test2

tf.function(opt_with_slicing)([a,b], 5.)  

# Output:
# new size Tensor("count_nonzero/Cast_1:0", shape=(), dtype=int32), initial size (11,)
# test1 shape (2, None), test2 shape (2, None)
# (<tf.Tensor: shape=(2, 6), dtype=float32, numpy=
#  array([[1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1.]], dtype=float32)>,
#  <tf.Tensor: shape=(2, 6), dtype=float32, numpy=
#  array([[1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1.]], dtype=float32)>)

正如您从打印出来的 和 的形状信息中看到的那样test1test2由于这是一个动态操作,我无法知道new_size执行之前的形状信息。有没有办法在不破坏图形模式的情况下恢复函数的形状信息?

PS:我也试过了boolean_mask

mask = tf.greater_equal(a, some_cutoff)
masked_shape = tf.boolean_mask(a, mask).get_shape()[0]

masked_shape事实证明None也是如此。

系统信息:

标签: pythontensorflow

解决方案


推荐阅读