首页 > 解决方案 > 如何使用 Tensorflow Cifar10 教程代码进行推断?

问题描述

我是 TensorFlow 的绝对初学者。

如果我有一张图片(或一组图片)想尝试使用 Cifar10 TensorFlow 教程中的代码进行分类,我该怎么做?

我完全不知道从哪里开始。

标签: pythontensorflow

解决方案


  1. 完全按照教程使用基本 CIFAR10 数据集训练模型。
  2. 使用您自己的输入创建一个新图表 - 可能最容易使用 atf.placeholder并按照下面的方式提供数据,但还有很多其他方法。
  3. 开始一个会话,加载之前保存的权重。
  4. 运行会话(feed_dict如果您使用的是 aplaceholder如上所述)。

.

import tensorflow as tf

train_dir = '/tmp/cifar10_train'  # or use FLAGS as in the train example
batch_size = 8
height = 32
width = 32

image = tf.placeholder(shape=(batch_size, height, width, 3), dtype=tf.uint8)
std_img = tf.image.per_image_standardization(image)
logits = cifar10.inference(std_img)
predictions = tf.argmax(logits, axis=-1)

def get_image_data_batches():
    n_batchs = 100
    for i in range(n_batchs):
        yield (np.random.uniform(size=(batch_size, height, width, 3)*255).astype(np.uint8)

def do_stuff_with(logit_vals, prediction_vals):
    pass

with tf.Session() as sess:
    # restore variables
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(train_dir))
    # run inference
    for batch_data in get_image_data_batches():
        logit_vals, prediction_vals = sess.run([logits, predictions], feed_dict={image: image_data})
        do_stuff_with(logit_vals, prediction_vals)

有更好的方法可以将数据放入图表(请参阅 参考资料tf.data.Dataset),但我相信tf.placeholders 是学习和开始运行某些东西的最简单方法。

另请查看tf.estimator.Estimators 以获得更清洁的会话管理方式。它与本教程中的完成方式非常不同,而且灵活性稍差,但对于标准网络,它们可以节省您编写大量样板代码。


推荐阅读