python - 分批训练,但在 Tensorflow 中使用 tf.estimators 执行单个图像分类
问题描述
为图像分类构建卷积神经网络是我第一次涉足 Python 和 Tensorflow。我有一个要训练的大型图像数据库,我的应用程序是对单个图像进行快速、实时的分类。目前,为了执行单个图像的推理,我必须使训练批量大小等于 1。如果我不这样做(并且训练批量大小为 16),我会收到以下错误:
ValueError: Cannot feed value of shape (1, 132, 128, 1) for Tensor
'tf_reshape1:0', which has shape '(16, 132, 128, 1)'
我真的很想灵活地训练更大的批量,同时仍然能够对单个图像进行分类。
我的问题与此处提出的其他问题非常相似,例如批量训练但在 Tensorflow 中测试单个数据项?和Tensorflow:层大小取决于批量大小?但我缺乏 tf 经验意味着我无法将建议的解决方案实施到我的代码中。
我已经编写了用于图像分类的训练和推理代码,它们基于 tensorflow 网站上给出的用于分类 MNIST 数据库https://www.tensorflow.org/tutorials/estimators/cnn#train_eval_mnist的示例。我的代码使用 tf.estimators https://www.tensorflow.org/guide/estimators一个高级 tensorflow API。上面提出和回答的类似问题的解决方案建议修改我没有(故意)使用的 tf.placeholders。我已经复制了我在下面使用的输入函数和模型函数的代码。我确定我需要发布更多信息,但这也是我提出的第一个 SO 问题,如果我忘记了很多事情,我深表歉意。
训练:
bead_classifier = tf.estimator.Estimator(
model_fn=bead_model_fn,
model_dir=r'/tmp/trained_model')
# Train the model
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": training_imgs},
y=training_labels,
batch_size=16,
num_epochs=None,
shuffle=True)
bead_classifier.train(
input_fn=train_input_fn,
steps = 13000)
hooks=[logging_hook])
模型函数(仅前几层):
def bead_model_fn(features, labels, mode):
"""Model function for CNN."""
# Input Layer
# Reshape X to 4-D tensor: [batch_size, width, height, channels]
input_layer = tf.reshape(features["x"], [-1, 132, 128, 1], name =
'tf_reshape1')
# Convolutional Layer #1
conv1 = tf.layers.conv2d(
inputs=input_layer,
filters=4,
kernel_size=[2, 2],
padding="same",
activation=tf.nn.relu)
# Pooling Layer #1
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
# Convolutional Layer #2
conv2 = tf.layers.conv2d(
inputs=pool1,
filters=8,
kernel_size=[2, 2],
padding="same",
activation=tf.nn.relu)
解决方案
推荐阅读
- mysql - 选择关联表的所有行的数量并加入这些行。MySQL
- python - 检查 3D 点是否在基于正方形的金字塔中
- r - 如何在 r 中建立状态空间模型:将时变参数与 GARCH-M 模型相结合
- xslt - 有没有办法在 XSL 1.0 模板中使用 for 循环在以下场景中打印金额?
- blazor - 在自定义组件中包装 MudBlazor 组件 - @bind-Value 的问题
- python - 如何使用 lxml 解析 pymupdf 的 xml 提取?
- javascript - Accept.js 来自 Authorize.net 实现
- sharepoint - 从 SharePoint 格式的列调用即时流
- api - Google Maps API 的奇怪问题
- javascript - 使用 ElectronForge 而不是使用 create-react-app 获取“常量声明需要初始化值”错误