首页 > 解决方案 > 将标量输入到 Tensorflow 2 模型的正确方法

问题描述

在我的 Tensorflow 2 模型中,我希望我的批量大小是参数化的,这样我就可以动态构建具有适当批量大小的张量。我有以下代码:

batch_size_param = 128

tf_batch_size = tf.keras.Input(shape=(), name="tf_batch_size", dtype=tf.int32)
batch_indices = tf.range(0, tf_batch_size, 1)

md = tf.keras.Model(inputs={"tf_batch_size": tf_batch_size}, outputs=[batch_indices])
res = md(inputs={"tf_batch_size": batch_size_param})

代码在以下位置引发错误tf.range

ValueError: Shape must be rank 0 but is rank 1
 for 'limit' for '{{node Range}} = Range[Tidx=DT_INT32](Range/start, tf_batch_size, Range/delta)' with input shapes: [], [?], []

我认为问题在于tf.keras.Input自动尝试在第一维扩展输入数组的事实,因为它期望输入的部分形状没有批量大小,并将根据输入数组的形状附加批量大小,这在我的情况下是一个标量。我可以将标量值作为一个常量整数输入tf.range,但是这一次,在编译模型图后我将无法更改它。

有趣的是,即使我也检查了文档,我也没有找到一种仅将标量输入到 TF-2 模型中的正确方法。那么,处理这种情况的最佳方法是什么?

标签: pythontensorflowdeep-learning

解决方案


不要使用tf.keras.Input,只需通过子类化来定义模型。

import tensorflow as tf


class ScalarModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, x):
        return tf.range(0, x, 1)


print(ScalarModel()(10))
# tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)

推荐阅读