python - 将标量输入到 Tensorflow 2 模型的正确方法
问题描述
在我的 Tensorflow 2 模型中,我希望我的批量大小是参数化的,这样我就可以动态构建具有适当批量大小的张量。我有以下代码:
batch_size_param = 128
tf_batch_size = tf.keras.Input(shape=(), name="tf_batch_size", dtype=tf.int32)
batch_indices = tf.range(0, tf_batch_size, 1)
md = tf.keras.Model(inputs={"tf_batch_size": tf_batch_size}, outputs=[batch_indices])
res = md(inputs={"tf_batch_size": batch_size_param})
代码在以下位置引发错误tf.range
:
ValueError: Shape must be rank 0 but is rank 1
for 'limit' for '{{node Range}} = Range[Tidx=DT_INT32](Range/start, tf_batch_size, Range/delta)' with input shapes: [], [?], []
我认为问题在于tf.keras.Input
自动尝试在第一维扩展输入数组的事实,因为它期望输入的部分形状没有批量大小,并将根据输入数组的形状附加批量大小,这在我的情况下是一个标量。我可以将标量值作为一个常量整数输入tf.range
,但是这一次,在编译模型图后我将无法更改它。
有趣的是,即使我也检查了文档,我也没有找到一种仅将标量输入到 TF-2 模型中的正确方法。那么,处理这种情况的最佳方法是什么?
解决方案
不要使用tf.keras.Input
,只需通过子类化来定义模型。
import tensorflow as tf
class ScalarModel(tf.keras.Model):
def __init__(self):
super().__init__()
def call(self, x):
return tf.range(0, x, 1)
print(ScalarModel()(10))
# tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)
推荐阅读
- javascript - 如何查找节点模块导出的文件系统路径?
- jquery - 如何在带有动画的jquery中实现侧边栏切换
- macos - 从word文档中提取引用的宏
- orocrm - Oro CRM 在安装期间不忽略数据库的 parameters.yml
- c# - 任务WhenAll异常
- python - 如何在 AWS CodeDeploy 中自动重启我的 python 应用程序
- linux - 如果我更改了复制文件的方式,为什么我的 Docker 构建速度会加快?
- reactjs - 我可以将我的代码推送到我的私人服务器吗?
- jquery - 预加载器处于活动状态时如何使用 JQuery 禁用滚动
- sql - 对外部表执行查询但只返回原始表?