首页 > 解决方案 > TensorFlow py_function 设置返回值形状

问题描述

我正在尝试设置从tf.py_function调用返回的 numpy 数组的形状。函数返回一个数组元组及其形状,所以我可以在里面设置它的形状tf.function。但是我不能在 set_shape 调用中使用返回的形状数组。我尝试将其转换为 int32 或通过索引等访问它,但都导致TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'strided_slice:0' shape=<unknown> dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'错误。

@tf.function
def samples(img, mask) -> tuple:
    xs, xs_shape = tf.py_function(process,[128, 128], [tf.float32, tf.int32])
    xs.set_shape(tf.TensorShape(xs_shape))
    return xs

ds = ds.map(samples)

标签: pythontensorflowkerastensorflow2.0

解决方案


这是一个人为的例子。我认为这可能是您的要求。

def f(x):
    print("print ", x)
    # Include Logic using 'x' which is a tensor.
    # Sample  return values
    return [tf.convert_to_tensor([1,2]), tf.convert_to_tensor([1,2]).shape]


@tf.function
def samples(x) -> tuple:
    xs, xs_shape = f(x)
    xs.set_shape(xs_shape)
    print( xs_shape)
    return xs

dataset = tf.data.Dataset.range(1, 6)
dataset = dataset.map(  samples )

推荐阅读