diff --git a/example/notebooks/alexnet.ipynb b/example/notebooks/alexnet.ipynb index b7bb6bf266c2..c030d873cd08 100644 --- a/example/notebooks/alexnet.ipynb +++ b/example/notebooks/alexnet.ipynb @@ -29,7 +29,6 @@ }, "outputs": [], "source": [ - "%matplotlib inline\n", "import mxnet as mx" ] }, @@ -402,7 +401,7 @@ } ], "source": [ - "mx.visualization.plot_network(\"AlexNet\", softmax)" + "mx.viz.plot_network(\"AlexNet\", softmax)" ] }, { @@ -425,28 +424,32 @@ "# We set batch size for to 256\n", "batch_size = 256\n", "# We need to set correct path to image record file\n", - "# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n", + "# For ```mean_image```. if it doesn't exist, the iterator will generate one\n", + "# On HDD, single thread is able to process 800 images / sec\n", "# the input shape is in format (channel, height, width)\n", "# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n", "# rand_mirror option make source image randomly mirrored\n", "# We use 2 threads to processing our data\n", "train_dataiter = mx.io.ImageRecordIter(\n", + " shuffle=True,\n", " path_imgrec=\"./Data/ImageNet/train.rec\",\n", " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", " rand_crop=True,\n", " rand_mirror=True,\n", - " input_shape=(3, 224, 224),\n", + " data_shape=(3, 224, 224),\n", " batch_size=batch_size,\n", - " nthread=2)\n", + " prefetch_buffer=4,\n", + " preprocess_threads=2)\n", "# similarly, we can declare our validation iterator\n", "val_dataiter = mx.io.ImageRecordIter(\n", " path_imgrec=\"./Data/ImageNet/val.rec\",\n", " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", " rand_crop=False,\n", " rand_mirror=False,\n", - " input_shape=(3, 224, 224),\n", + " data_shape=(3, 224, 224),\n", " batch_size=batch_size,\n", - " nthread=2)" + " prefetch_buffer=4,\n", + " preprocess_threads=2)" ] }, { @@ -531,7 +534,7 @@ "# When we use data iterator, we don't need to set y because label comes from data iterator directly\n", "# In this case, eval_data is also a data iterator\n", "# We will use accuracy to measure our model's performace\n", - "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n", + "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc')\n", "# You need to wait for a while to get the result" ] }, diff --git a/tests/python/common/get_data.py b/tests/python/common/get_data.py index 270132e448b8..65e8ac59ad6f 100644 --- a/tests/python/common/get_data.py +++ b/tests/python/common/get_data.py @@ -14,18 +14,14 @@ def GetMNIST_pkl(): def GetMNIST_ubyte(): if not os.path.isdir("data/"): os.system("mkdir data/") - if not os.path.exists('data/train-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/train-images-idx3-ubyte.gz") - if not os.path.exists('data/train-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/train-labels-idx1-ubyte.gz") - if not os.path.exists('data/t10k-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/t10k-images-idx3-ubyte.gz") - if not os.path.exists('data/t10k-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/t10k-labels-idx1-ubyte.gz") + if (not os.path.exists('data/train-images-idx3-ubyte')) or \ + (not os.path.exists('data/train-labels-idx1-ubyte')) or \ + (not os.path.exists('data/t10k-images-idx3-ubyte')) or \ + (not os.path.exists('data/t10k-labels-idx1-ubyte')): + os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P data/") + os.chdir("./data") + os.system("unzip -u mnist.zip") + os.chdir("..") # download cifar def GetCifar10(): @@ -34,5 +30,5 @@ def GetCifar10(): if not os.path.exists('data/cifar10.zip'): os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P data/") os.chdir("./data") - os.system("unzip cifar10.zip") + os.system("unzip -u cifar10.zip") os.chdir("..")