tensorflow - Tensorflow:批量分类图像
问题描述
我已经按照这个 TensorFlow 教程使用迁移学习方法对图像进行分类。使用在预训练的 MobileNet V2 模型之上添加的近 16,000 个手动分类图像(大约 40/60 分割为 1/0),我的模型在保留测试集上实现了 96% 的准确率。然后我保存了生成的模型。
接下来,我想使用这个经过训练的模型对新图像进行分类。为此,我按照下面描述的方式调整了教程代码的一部分(最后是#Retrieve abatch of the test set)。该代码有效,但是,它只处理一批 32 张图像,仅此而已(源文件夹中有数百张图像)。我在这里想念什么?请指教。
# Import libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import preprocessing
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
import numpy as np
import os
# Load saved model
model = tf.keras.models.load_model('/model')
# Re-compile model
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
# Define paths
PATH = 'Data/'
new_dir = os.path.join(PATH, 'New_images') # New_images must contain at least one class (sub-folder)
IMG_SIZE = (640, 640)
BATCH_SIZE = 32
new_dataset = image_dataset_from_directory(new_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
# Retrieve a batch of images from the test set
image_batch, label_batch = new_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)
print('Predictions:\n', predictions.numpy())
len(new_dataset) # equals 25, i.e., there are 25 batches
解决方案
替换此代码:
# Retrieve a batch of images from the test set
image_batch, label_batch = new_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
有了这个:
predictions = model.predict(new_dataset,batch_size=BATCH_SIZE).flatten()
tf.data.Dataset
对象可以直接传递给方法predict()
。参考
推荐阅读
- multithreading - 数据库是否可以阻止 Scala 线程中的并行表访问?
- python - “张量流”没有属性“会话”
- android - 多个活动的导航抽屉,无需选择导航抽屉活动
- kubernetes - 在代理服务器后面运行 BookInfo 示例失败,调用 webhook “pilot.validation.istio.io”失败
- javascript - 直接通过其 id 访问元素时的扩展问题,而不是 getElementById
- c++ - “ostream &os”有什么用?
- stripe-payments - 如何从 Stripe 向我的客户卡汇款?
- javascript - vuetify v-select 值在 Created 和 onChange 上是否不同?
- javascript - 创建一个 java 脚本函数来进行 api 调用并检索数据
- algorithm - 求未加权无向图的直径