-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist_dataset.py
48 lines (39 loc) · 1.55 KB
/
mnist_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import struct
import array as arr
from numpy import array, zeros, uint8, int8
def read(digits, dataset="training", path="."):
    """Load the MNIST images and labels of the requested digits.

    * Adapted from: http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
    * Source: http://g.sweyla.com/blog/2012/mnist-numpy/
    * MNIST: http://yann.lecun.com/exdb/mnist/

    **Parameters**

    :digits: list; digits we want to load
    :dataset: string; 'training' or 'testing'
    :path: string; path to the data set files

    **Returns**

    A tuple ``(images, labels)`` where ``images`` is a ``uint8`` array of
    shape ``(N, rows, cols)`` and ``labels`` is an ``int8`` array of shape
    ``(N, 1)``; ``N`` is the number of samples whose label is in `digits`,
    and ``rows``/``cols`` are taken from the image file header.

    **Raises**

    :ValueError: if `dataset` is neither 'training' nor 'testing'
    """
    # Compare by value. `is` tests object identity, which only matches when
    # the caller's string happens to be the same interned object.
    if dataset == "training":
        fname_img = os.path.join(path, 'train-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')
    elif dataset == "testing":
        fname_img = os.path.join(path, 't10k-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # `with` guarantees the files are closed even if parsing fails.
    with open(fname_lbl, 'rb') as flbl:
        # The 8-byte header (two big-endian uint32) is parsed only to
        # validate its length; the values themselves are not needed.
        struct.unpack(">II", flbl.read(8))
        lbl = arr.array("b", flbl.read())

    with open(fname_img, 'rb') as fimg:
        # 16-byte header: four big-endian uint32; the first is unused.
        _, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = arr.array("B", fimg.read())

    # Indices of the samples whose label is one of the requested digits.
    ind = [k for k in range(size) if lbl[k] in digits]
    N = len(ind)

    images = zeros((N, rows, cols), dtype=uint8)
    labels = zeros((N, 1), dtype=int8)

    # Every image occupies `pixels` consecutive bytes in the flat buffer.
    pixels = rows * cols
    for i, k in enumerate(ind):
        start = k * pixels
        images[i] = array(img[start:start + pixels]).reshape((rows, cols))
        labels[i] = lbl[k]

    return images, labels