Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

move mnist to ualberta server #117

Merged
merged 1 commit into from
Sep 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions example/notebooks/alexnet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import mxnet as mx"
]
},
Expand Down Expand Up @@ -402,7 +401,7 @@
}
],
"source": [
"mx.visualization.plot_network(\"AlexNet\", softmax)"
"mx.viz.plot_network(\"AlexNet\", softmax)"
]
},
{
Expand All @@ -425,28 +424,32 @@
"# We set batch size for to 256\n",
"batch_size = 256\n",
"# We need to set correct path to image record file\n",
"# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n",
"# For ```mean_image```. if it doesn't exist, the iterator will generate one\n",
"# On HDD, single thread is able to process 800 images / sec\n",
"# the input shape is in format (channel, height, width)\n",
"# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n",
"# rand_mirror option make source image randomly mirrored\n",
"# We use 2 threads to processing our data\n",
"train_dataiter = mx.io.ImageRecordIter(\n",
" shuffle=True,\n",
" path_imgrec=\"./Data/ImageNet/train.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=True,\n",
" rand_mirror=True,\n",
" input_shape=(3, 224, 224),\n",
" data_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)\n",
" prefetch_buffer=4,\n",
" preprocess_threads=2)\n",
"# similarly, we can declare our validation iterator\n",
"val_dataiter = mx.io.ImageRecordIter(\n",
" path_imgrec=\"./Data/ImageNet/val.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=False,\n",
" rand_mirror=False,\n",
" input_shape=(3, 224, 224),\n",
" data_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)"
" prefetch_buffer=4,\n",
" preprocess_threads=2)"
]
},
{
Expand Down Expand Up @@ -531,7 +534,7 @@
"# When we use data iterator, we don't need to set y because label comes from data iterator directly\n",
"# In this case, eval_data is also a data iterator\n",
"# We will use accuracy to measure our model's performace\n",
"model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n",
"model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc')\n",
"# You need to wait for a while to get the result"
]
},
Expand Down
22 changes: 9 additions & 13 deletions tests/python/common/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@ def GetMNIST_pkl():
def GetMNIST_ubyte():
if not os.path.isdir("data/"):
os.system("mkdir data/")
if not os.path.exists('data/train-images-idx3-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/")
os.system("gunzip data/train-images-idx3-ubyte.gz")
if not os.path.exists('data/train-labels-idx1-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/")
os.system("gunzip data/train-labels-idx1-ubyte.gz")
if not os.path.exists('data/t10k-images-idx3-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/")
os.system("gunzip data/t10k-images-idx3-ubyte.gz")
if not os.path.exists('data/t10k-labels-idx1-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/")
os.system("gunzip data/t10k-labels-idx1-ubyte.gz")
if (not os.path.exists('data/train-images-idx3-ubyte')) or \
(not os.path.exists('data/train-labels-idx1-ubyte')) or \
(not os.path.exists('data/t10k-images-idx3-ubyte')) or \
(not os.path.exists('data/t10k-labels-idx1-ubyte')):
os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P data/")
os.chdir("./data")
os.system("unzip -u mnist.zip")
os.chdir("..")

# download cifar
def GetCifar10():
Expand All @@ -34,5 +30,5 @@ def GetCifar10():
if not os.path.exists('data/cifar10.zip'):
os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P data/")
os.chdir("./data")
os.system("unzip cifar10.zip")
os.system("unzip -u cifar10.zip")
os.chdir("..")