这里用到了 mx.model 模块里的 FeedForward.load（加载已保存的模型）和 predict（对输入做预测）：
import os
import mxnet as mx
import numpy as np
from collections import namedtuple

try:
    import Image  # legacy standalone PIL layout used by the original script
except ImportError:
    from PIL import Image  # Pillow only ships the namespaced package

Batch = namedtuple('Batch', ['data'])
synsets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter under
# its current name, so fall back to it when the old constant is missing.
_RESAMPLE = getattr(Image, 'ANTIALIAS', None)
if _RESAMPLE is None:
    _RESAMPLE = Image.LANCZOS

# Keep the prompt working on both Python 2 (raw_input) and Python 3 (input).
try:
    _read_line = raw_input
except NameError:
    _read_line = input


def predict(img_url, model, synsets):
    """Print one "<label>: <score>" line per class for a single image.

    The image at ``img_url`` is opened, converted to 8-bit grayscale,
    resized to 28x28, scaled into [0, 1] and fed to ``model.predict`` as a
    batch of shape (1, 1, 28, 28).

    Unlike the previous version, the resized image is NOT written back to
    ``img_url``: the file on disk is left untouched.

    Args:
        img_url: Path of the image file to classify.
        model: Object exposing ``predict(X=...)``; the script passes the
            result of ``mx.model.FeedForward.load``.
        synsets: Sequence of integer labels, paired positionally with the
            first row of the model output.

    Returns:
        None. Results are printed to stdout.
    """
    img = Image.open(img_url)
    img = img.convert('L')
    img = img.resize((28, 28), _RESAMPLE)

    pixels = np.asarray(img, dtype=np.uint8)
    batch = pixels.reshape(1, 1, 28, 28).astype(np.float32) / 255

    val = mx.io.NDArrayIter(data=batch)
    # Only one sample was submitted, so take the first output row.
    # NOTE(review): presumably these are per-class probabilities - confirm
    # against the network's output layer.
    res = model.predict(X=val)[0]

    # zip() stops at the shorter sequence, so a label list and output row of
    # different lengths no longer raise IndexError.
    for label, score in zip(synsets, res):
        print("%d: %.2f" % (label, score))


model = mx.model.FeedForward.load('MNIST_MXNet', 100)
while True:
    try:
        img_url = _read_line("Enter the img_url: ")
    except (EOFError, KeyboardInterrupt):
        # End of input or Ctrl-C: leave the prompt loop without a traceback.
        break
    try:
        predict(img_url, model, synsets)
    except (IOError, OSError) as err:
        # A mistyped or unreadable path should not end the session and
        # force the model to be loaded again.
        print("Cannot read image %r: %s" % (img_url, err))
训练后保存模型时用的是 model.save('MNIST_MXNet',100)，所以加载时传入的前缀 'MNIST_MXNet' 和 epoch 100 必须与保存时一致。