python - 如何使用 Tensorflow Cifar10 教程代码进行推断?
问题描述
我是 TensorFlow 的绝对初学者。
如果我有一张图片(或一组图片)想尝试使用 Cifar10 TensorFlow 教程中的代码进行分类,我该怎么做?
我完全不知道从哪里开始。
解决方案
- 完全按照教程使用基本 CIFAR10 数据集训练模型。
- 使用您自己的输入创建一个新图表 - 可能最容易使用 a
tf.placeholder
并按照下面的方式提供数据,但还有很多其他方法。 - 开始一个会话,加载之前保存的权重。
- 运行会话(
feed_dict
如果您使用的是 aplaceholder
如上所述)。
.
import tensorflow as tf
train_dir = '/tmp/cifar10_train' # or use FLAGS as in the train example
batch_size = 8
height = 32
width = 32
image = tf.placeholder(shape=(batch_size, height, width, 3), dtype=tf.uint8)
std_img = tf.image.per_image_standardization(image)
logits = cifar10.inference(std_img)
predictions = tf.argmax(logits, axis=-1)
def get_image_data_batches():
n_batchs = 100
for i in range(n_batchs):
yield (np.random.uniform(size=(batch_size, height, width, 3)*255).astype(np.uint8)
def do_stuff_with(logit_vals, prediction_vals):
pass
with tf.Session() as sess:
# restore variables
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(train_dir))
# run inference
for batch_data in get_image_data_batches():
logit_vals, prediction_vals = sess.run([logits, predictions], feed_dict={image: image_data})
do_stuff_with(logit_vals, prediction_vals)
有更好的方法可以将数据放入图表(请参阅 参考资料tf.data.Dataset
),但我相信tf.placeholder
s 是学习和开始运行某些东西的最简单方法。
另请查看tf.estimator.Estimator
s 以获得更清洁的会话管理方式。它与本教程中的完成方式非常不同,而且灵活性稍差,但对于标准网络,它们可以节省您编写大量样板代码。
推荐阅读
- python - 如何使用 Selenium 和 Python 在文本区域内发送文本
- python-3.x - 为什么这些修改后的冒泡排序算法在 python3 中不产生相同的顺序?
- python - 如何从多个列表中生成字典?
- c - 您如何枚举 libav for v4l2 中的格式、分辨率和帧速率?
- javascript - 为什么 Node.js 的异步特性使其不适合视频/图像处理?
- python - 使用 pyinstaller 创建的可执行文件不运行
- flutter - 颤动中的淡入或不透明
- laravel - Laravel in_arrary 显示标签
- r - R 查看器窗格未显示 D3 交互式图表
- python - 如何最好地将函数应用于数据框(应用()或不应用)?