diff --git a/docs/tutorials/python/predict_image.md b/docs/tutorials/python/predict_image.md index a9a0d29010c1..8be98d991366 100644 --- a/docs/tutorials/python/predict_image.md +++ b/docs/tutorials/python/predict_image.md @@ -69,6 +69,7 @@ def get_image(url, show=False): img = mx.image.imresize(img, 224, 224) # resize img = img.transpose((2, 0, 1)) # Channel first img = img.expand_dims(axis=0) # batchify + img = img.astype('float32') # for gpu context return img def predict(url):