python - TensorFlow py_function 设置返回值形状
问题描述
我正在尝试设置从tf.py_function
调用返回的 numpy 数组的形状。函数返回一个数组元组及其形状,所以我可以在里面设置它的形状tf.function
。但是我不能在 set_shape 调用中使用返回的形状数组。我尝试将其转换为 int32 或通过索引等访问它,但都导致TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'strided_slice:0' shape=<unknown> dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
错误。
@tf.function
def samples(img, mask) -> tuple:
xs, xs_shape = tf.py_function(process,[128, 128], [tf.float32, tf.int32])
xs.set_shape(tf.TensorShape(xs_shape))
return xs
ds = ds.map(samples)
解决方案
这是一个人为的例子。我认为这可能是您的要求。
def f(x):
print("print ", x)
# Include Logic using 'x' which is a tensor.
# Sample return values
return [tf.convert_to_tensor([1,2]), tf.convert_to_tensor([1,2]).shape]
@tf.function
def samples(x) -> tuple:
xs, xs_shape = f(x)
xs.set_shape(xs_shape)
print( xs_shape)
return xs
dataset = tf.data.Dataset.range(1, 6)
dataset = dataset.map( samples )
推荐阅读
- swift - 隐藏字段未以 Eureka swift 形式提交其值
- javascript - Deleting multiple child components from parent component in react
- networking - Kubernetes services cannot reach each other anymore
- html - Justified Gallery issue
- python - 如何在for循环中遍历熊猫行
- android - How to change the color of radiobuttons dynamically?
- javascript - 编译多个 hash.location
- xcode - 使用 XcodeGen 自定义配置
- azure-service-fabric - Deploying an application with Application Parameters through FabricClient
- netlify - netlify 会阻止 php 表单处理吗?