MXNet Learning: Prediction Results - Recognizing a Single Image

Mu001999 2016-12-26 01:07

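This uses FeedForward.load and FeedForward.predict from mx.model to load a trained MNIST model and classify a single handwritten-digit image.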

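import mxnet as mx
import numpy as np
from PIL import Image  # Pillow requires this form; the bare "import Image" only worked with the old PIL

# Class labels in the order the model outputs them: digits 0-9
synsets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def predict(img_url, model, synsets):
    # Load the image, convert it to 8-bit grayscale and shrink it to the 28x28 MNIST input size
    img = Image.open(img_url)
    img = img.convert('L')
    img = img.resize((28, 28), Image.LANCZOS)  # LANCZOS is the filter formerly exposed as Image.ANTIALIAS
    # Layout (batch, channel, height, width) = (1, 1, 28, 28), pixel values scaled to [0, 1]
    img = np.asarray(img, dtype=np.uint8)
    img = img.reshape(1, 1, 28, 28).astype(np.float32) / 255
    val = mx.io.NDArrayIter(data=img, batch_size=1)
    # predict() returns one row of class probabilities per input image; take the only row
    res = model.predict(X=val)[0]
    for i in range(len(synsets)):
        print("%d: %.2f" % (synsets[i], res[i]))


# Load the checkpoint saved with prefix 'MNIST_MXNet' at epoch 100
model = mx.model.FeedForward.load('MNIST_MXNet', 100)
while True:
    img_url = input("Enter the img_url: ")
    if not img_url:  # an empty line ends the loop
        break
    predict(img_url, model, synsets)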
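The preprocessing inside predict() and the way its output is read can be checked without MXNet installed. The sketch below uses only NumPy and assumes the image has already been turned into a 28x28 grayscale array (which is what the PIL convert/resize calls produce). `to_batch` and `top_prediction` are hypothetical helper names for illustration, not MXNet functions. Taking the argmax is one way to reduce the ten printed probabilities to a single predicted digit; the script itself just prints all ten.

```python
import numpy as np

# Digit classes in output order, mirroring `synsets` in the script
SYNSETS = list(range(10))


def to_batch(gray):
    """Hypothetical helper: 28x28 uint8 grayscale -> float32 batch of shape (1, 1, 28, 28) in [0, 1]."""
    arr = np.asarray(gray, dtype=np.uint8)
    if arr.shape != (28, 28):
        raise ValueError("expected a 28x28 grayscale image, got shape %s" % (arr.shape,))
    # (batch, channel, height, width), the same reshape predict() performs
    return arr.reshape(1, 1, 28, 28).astype(np.float32) / 255


def top_prediction(probs, synsets=SYNSETS):
    """Hypothetical helper: return (label, probability) of the highest-scoring class in one output row."""
    probs = np.asarray(probs, dtype=np.float64).ravel()
    i = int(np.argmax(probs))
    return synsets[i], float(probs[i])


# Usage with a made-up probability row (not real model output)
label, p = top_prediction([0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
print(label, p)  # 2 0.9
```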

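The model was saved during training with model.save('MNIST_MXNet', 100), so the prefix and epoch passed to FeedForward.load must be the same 'MNIST_MXNet' and 100.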
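FeedForward.save(prefix, epoch) writes its checkpoint through mx.model.save_checkpoint, which names the files `<prefix>-symbol.json` (network structure) and `<prefix>-<epoch, 4 digits>.params` (weights). A minimal sketch of that naming convention, useful for checking that the files load() expects are actually present; `checkpoint_files` and `missing_checkpoint_files` are hypothetical helpers, not part of MXNet.

```python
import os


def checkpoint_files(prefix, epoch):
    """Hypothetical helper: file names following MXNet's save_checkpoint convention."""
    return "%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch)


def missing_checkpoint_files(prefix, epoch):
    """Hypothetical helper: list whichever expected checkpoint files are absent from disk."""
    return [f for f in checkpoint_files(prefix, epoch) if not os.path.exists(f)]


print(checkpoint_files("MNIST_MXNet", 100))
# ('MNIST_MXNet-symbol.json', 'MNIST_MXNet-0100.params')
```

With the values used in this article, FeedForward.load('MNIST_MXNet', 100) therefore looks for MNIST_MXNet-symbol.json and MNIST_MXNet-0100.params in the working directory.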
