Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #117 from antinucleon/master
Browse files Browse the repository at this point in the history
move mnist to ualberta server
  • Loading branch information
antinucleon committed Sep 21, 2015
2 parents f69c194 + 8ad680b commit 99d5925
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
19 changes: 11 additions & 8 deletions example/notebooks/alexnet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import mxnet as mx"
]
},
Expand Down Expand Up @@ -402,7 +401,7 @@
}
],
"source": [
"mx.visualization.plot_network(\"AlexNet\", softmax)"
"mx.viz.plot_network(\"AlexNet\", softmax)"
]
},
{
Expand All @@ -425,28 +424,32 @@
"# We set batch size for to 256\n",
"batch_size = 256\n",
"# We need to set correct path to image record file\n",
"# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n",
"# For ```mean_image```. if it doesn't exist, the iterator will generate one\n",
"# On HDD, single thread is able to process 800 images / sec\n",
"# the input shape is in format (channel, height, width)\n",
"# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n",
"# rand_mirror option make source image randomly mirrored\n",
"# We use 2 threads to processing our data\n",
"train_dataiter = mx.io.ImageRecordIter(\n",
" shuffle=True,\n",
" path_imgrec=\"./Data/ImageNet/train.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=True,\n",
" rand_mirror=True,\n",
" input_shape=(3, 224, 224),\n",
" data_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)\n",
" prefetch_buffer=4,\n",
" preprocess_threads=2)\n",
"# similarly, we can declare our validation iterator\n",
"val_dataiter = mx.io.ImageRecordIter(\n",
" path_imgrec=\"./Data/ImageNet/val.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=False,\n",
" rand_mirror=False,\n",
" input_shape=(3, 224, 224),\n",
" data_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)"
" prefetch_buffer=4,\n",
" preprocess_threads=2)"
]
},
{
Expand Down Expand Up @@ -531,7 +534,7 @@
"# When we use data iterator, we don't need to set y because label comes from data iterator directly\n",
"# In this case, eval_data is also a data iterator\n",
"# We will use accuracy to measure our model's performace\n",
"model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n",
"model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc')\n",
"# You need to wait for a while to get the result"
]
},
Expand Down
22 changes: 9 additions & 13 deletions tests/python/common/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@ def GetMNIST_pkl():
def GetMNIST_ubyte():
if not os.path.isdir("data/"):
os.system("mkdir data/")
if not os.path.exists('data/train-images-idx3-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/")
os.system("gunzip data/train-images-idx3-ubyte.gz")
if not os.path.exists('data/train-labels-idx1-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/")
os.system("gunzip data/train-labels-idx1-ubyte.gz")
if not os.path.exists('data/t10k-images-idx3-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/")
os.system("gunzip data/t10k-images-idx3-ubyte.gz")
if not os.path.exists('data/t10k-labels-idx1-ubyte'):
os.system("wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/")
os.system("gunzip data/t10k-labels-idx1-ubyte.gz")
if (not os.path.exists('data/train-images-idx3-ubyte')) or \
(not os.path.exists('data/train-labels-idx1-ubyte')) or \
(not os.path.exists('data/t10k-images-idx3-ubyte')) or \
(not os.path.exists('data/t10k-labels-idx1-ubyte')):
os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P data/")
os.chdir("./data")
os.system("unzip -u mnist.zip")
os.chdir("..")

# download cifar
def GetCifar10():
Expand All @@ -34,5 +30,5 @@ def GetCifar10():
if not os.path.exists('data/cifar10.zip'):
os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P data/")
os.chdir("./data")
os.system("unzip cifar10.zip")
os.system("unzip -u cifar10.zip")
os.chdir("..")

0 comments on commit 99d5925

Please sign in to comment.