首页 > 解决方案 > 当作为参数传递给“密集”层时,为什么“input_shape”不包括批次维度?

问题描述

在 Keras 中,为什么input_shape当作为参数传递给层时不包括批次维度,Dense但在input_shape传递给build模型的方法时是否包括批次维度?

import tensorflow as tf
from tensorflow.keras.layers import Dense

if __name__ == "__main__":
    model1 = tf.keras.Sequential([Dense(1, input_shape=[10])])
    model1.summary()

    model2 = tf.keras.Sequential([Dense(1)])
    model2.build(input_shape=[None, 10])  # why [None, 10] and not [10]?
    model2.summary()

这是 API 设计的明智选择吗?如果是,为什么?

标签: pythontensorflowkeras

解决方案


您可以通过几种不同的方式指定模型的输入形状。例如,通过向模型的第一层提供以下参数之一:

  • batch_input_shape:第一个维度是批量大小的元组。
  • input_shape: 不包括批量大小的元组,例如,假设批量大小为Nonebatch_size,如果指定的话。
  • input_dim: 一个标量,表示输入的维度。

在所有这些情况下,Keras 都会在内部存储一个属性_batch_input_size来构建模型。

关于该build方法,我的猜测是这确实是一个有意识的选择——关于批量大小的信息可能对在某些(也许是未曾想到的)情况下构建模型有用。因此,包含批处理维度作为输入build的框架比不包含批处理维度的框架更通用和完整。尽管如此,我同意你的观点,即命名论点batch_input_shape而不是input_shape会使一切更加一致。


还值得一提的是,用户很少需要自己调用该build方法。这在需要时在内部发生。如今,甚至可以在创建模型时忽略参数(尽管这样的方法在模型构建之前将不起作用)。在这种情况下,Keras 能够从 的参数推断输入形状。input_shapesummaryxfit


推荐阅读