首页 > 解决方案 > 张量流 CNN MNIST 示例:批量大小在模型中的工作原理

问题描述

在 TensorFlow 的 CNN MNIST 示例中,我不明白批量大小是如何工作的,当他们调用模型时,他们将 bach 的大小指定为 100:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=100,
num_epochs=None,shuffle=True)
mnist_classifier.train(input_fn=train_input_fn,steps=20000,hooks=[logging_hook])

但是当模型被调用时:

def cnn_model_fn(features, labels, mode):
  # Input Layer
  # Reshape X to 4-D tensor: [batch_size, width, height, channels]
  # MNIST images are 28x28 pixels, and have one color channel
  input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

他们将 -1 放入批量大小,我在 tensorflow 教程中阅读,当他们告诉计算机推断该维度时,他们使用了 -1 我不明白的是,在我们放入 100 之前,现在因为 -1 不明白如何输入你能帮我解释一下模型的批量大小吗?谢谢你。

标签: tensorflowbatch-processingmnistconvolutional-neural-networktensorflow-estimator

解决方案


tl;博士

batch_size方法中的属性与函数中的属性tf.reshape()完全不同。batch_sizetf.estimator.inputs.numpy_input_fn

批量大小input_fn

batch_size该方法的属性tf.estimator.inputs.numpy_input_fn控制在特定时期(或时间实例)将训练或评估数据集的多少观察(或行或记录)。因此,在提供的示例中,batch_size = 100表示数据集中的 100 行(在本例中为图像)将在每个 epoch 由学习算法进行训练。

重塑张量

该方法tf.reshape用于改变张量的形状。该方法tf.reshape具有属性(tensor, shape)。从文档中,该shape属性有一个特殊的值-1,它推断该特定轴的维度以保留总大小。因此,从提供的示例[-1, 28, 28, 1]转换为[batch_size, row, column, channel]. A batch_sizeof -1 意味着 TensorFlow 将保持图像的大小,因为它被重新整形为所有图像的 784 个输入特征(即 28 * 28)的一维数组。


推荐阅读