首页 > 解决方案 > TF2 在@tf.function 中枚举参差不齐的张量

问题描述

我的目标是将 RaggedTensor 沿其第一个维度拆分为张量字典。下面的示例代码可以正常工作,直到它被 @tf.function 修饰。然后,它会产生一个令人困惑的错误:ValueError: slice index 2 of dimension 0 out of bounds。对于 '{{node RaggedGetItem_2/strided_slice_2}} ....' 输入形状:[2]、[1]、[1]、[1] 和计算输入张量:输入[1] = <2>,输入[2] = <3>,输入[3] = <1>。

@tf.function
def fn():
    a = tf.constant([[1,2,3],
                 [4,5,6]])
    b = tf.constant([[7,8,9],
                 [10,11,12],
                 [13,14,15]])
    rt = tf.ragged.stack([a, b])
    
    d = {}
    for k,x in enumerate(rt):
        d.update({str(k):x})
    
    return d
    
fn()

有人可以解释发生了什么吗?它与python的副作用有关吗?

谢谢!

标签: pythontensorflowtensorflow2.0tensorflow2.xragged-tensors

解决方案


推荐阅读