From 996d90fe7b540edec18b3132696aee6153002348 Mon Sep 17 00:00:00 2001 From: Pracheer Gupta Date: Thu, 25 May 2017 22:33:31 -0700 Subject: [PATCH] Pre-trained model tutorial fixes. (#6453) Before the change on running the tutorial for the first time: "UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])". It also showed probability of >>1 due to incorrect usage of np.argsort(). --- docs/tutorials/basic/ndarray.md | 2 +- docs/tutorials/python/predict_image.md | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/basic/ndarray.md b/docs/tutorials/basic/ndarray.md index 63a5c48071b5..34333d8d4f99 100644 --- a/docs/tutorials/basic/ndarray.md +++ b/docs/tutorials/basic/ndarray.md @@ -10,7 +10,7 @@ to `numpy.ndarray`. Like the corresponding NumPy data structure, MXNet's So you might wonder, why not just use NumPy? MXNet offers two compelling advantages. First, MXNet's `NDArray` supports fast execution on a wide range of hardware configurations, including CPU, GPU, and multi-GPU machines. _MXNet_ -also scales to distribute systems in the cloud. Second, MXNet's NDArray +also scales to distributed systems in the cloud. Second, MXNet's `NDArray` executes code lazily, allowing it to automatically parallelize multiple operations across the available hardware. diff --git a/docs/tutorials/python/predict_image.md b/docs/tutorials/python/predict_image.md index 6d6e295fc1c4..90db3896e54a 100644 --- a/docs/tutorials/python/predict_image.md +++ b/docs/tutorials/python/predict_image.md @@ -24,9 +24,10 @@ occurances of `mx.cpu()` with `mx.gpu()` to accelerate the computation. ```python sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0) -mod = mx.mod.Module(symbol=sym, context=mx.cpu()) -mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))]) -mod.set_params(arg_params, aux_params) +mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None) +mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], + label_shapes=mod._label_shapes) +mod.set_params(arg_params, aux_params, allow_missing=True) with open('synset.txt', 'r') as f: labels = [l.rstrip() for l in f] ``` @@ -68,8 +69,8 @@ def predict(url): prob = mod.get_outputs()[0].asnumpy() # print the top-5 prob = np.squeeze(prob) - prob = np.argsort(prob)[::-1] - for i in prob[0:5]: + a = np.argsort(prob)[::-1] + for i in a[0:5]: print('probability=%f, class=%s' %(prob[i], labels[i])) ```