首页 > 解决方案 > 如何使用 Tensorflow 功能 API 沿批量维度进行广播?

问题描述

在某些应用中,比如 slot attention(在Pytorch 中实现),有必要沿着批处理维度进行广播。但是,我看不到如何使用功能 API 执行此操作。例如,

import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, input.shape)

引发以下错误:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)

因此,我求助于子类tf.keras.Model化,但我想将我的代码保留在功能 API 中。有谁知道如何做到这一点?

标签: pythontensorflow

解决方案


最后通过使用找到了答案tf.keras.backend.shape

const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )

# Shape of const is now (None, 4)

推荐阅读