-
Notifications
You must be signed in to change notification settings - Fork 4
/
data_mnist.py
44 lines (36 loc) · 1.55 KB
/
data_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import absolute_import
from __future__ import print_function
from future.standard_library import install_aliases
install_aliases()
import os
import gzip
import struct
import array
import numpy as np
from urllib.request import urlretrieve
def download(url, filename):
    """Download ``url`` to ``data/<filename>`` unless that file already exists.

    The ``data`` directory (relative to the current working directory) is
    created on demand.  The payload is first written to a temporary
    ``<filename>.part`` file and only renamed to its final name once the
    transfer has finished, so an interrupted download is never mistaken for
    a complete file by the ``os.path.isfile`` check on a later call.

    Args:
        url: Location to fetch; passed unchanged to ``urlretrieve``.
        filename: Name of the file to create inside ``data``.

    Raises:
        OSError: If ``data`` cannot be created as a directory.
        Exception: Whatever ``urlretrieve`` raises when the transfer fails;
            the partial file is removed before the exception propagates.
    """
    # EAFP instead of exists()+makedirs(): avoids failing when another
    # process creates the directory between the check and the call.
    try:
        os.makedirs('data')
    except OSError:
        if not os.path.isdir('data'):
            raise

    out_file = os.path.join('data', filename)
    if os.path.isfile(out_file):
        return

    tmp_file = out_file + '.part'
    try:
        urlretrieve(url, tmp_file)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.isfile(tmp_file):
            os.remove(tmp_file)
        raise
    # Publish the file under its final name only after a complete transfer.
    os.rename(tmp_file, out_file)
def mnist():
    """Download (if necessary) and load the MNIST dataset.

    The four gzip-compressed IDX files are fetched into ``data/`` through
    ``download`` and decoded into NumPy arrays.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        Image arrays are ``uint8`` with shape ``(num_data, rows, cols)`` as
        read from each file's header; label arrays are one-dimensional
        ``uint8``.

    Raises:
        ValueError: If a file's header carries an unexpected magic number or
            its payload size disagrees with the header.
    """
    # NOTE(review): plain-HTTP source and no checksum verification of the
    # downloaded files -- confirm this host is still serving them.
    base_url = 'http://yann.lecun.com/exdb/mnist/'

    # Magic numbers from the IDX header: 0x0801 = unsigned-byte data with one
    # dimension (labels), 0x0803 = unsigned-byte data with three (images).
    label_magic = 2049
    image_magic = 2051

    def parse_labels(filename):
        """Decode a gzip-compressed IDX label file into a 1-D uint8 array."""
        with gzip.open(filename, 'rb') as fh:
            magic, num_data = struct.unpack(">II", fh.read(8))
            if magic != label_magic:
                raise ValueError('%s: bad label magic number %d (expected %d)'
                                 % (filename, magic, label_magic))
            labels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        if labels.shape[0] != num_data:
            raise ValueError('%s: header announces %d labels but %d were read'
                             % (filename, num_data, labels.shape[0]))
        return labels

    def parse_images(filename):
        """Decode a gzip-compressed IDX image file into a 3-D uint8 array."""
        with gzip.open(filename, 'rb') as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            if magic != image_magic:
                raise ValueError('%s: bad image magic number %d (expected %d)'
                                 % (filename, magic, image_magic))
            pixels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        # reshape() raises ValueError itself if the payload size does not
        # match the dimensions announced in the header.
        return pixels.reshape(num_data, rows, cols)

    train_images_file = 'train-images-idx3-ubyte.gz'
    train_labels_file = 'train-labels-idx1-ubyte.gz'
    test_images_file = 't10k-images-idx3-ubyte.gz'
    test_labels_file = 't10k-labels-idx1-ubyte.gz'

    for filename in [train_images_file, train_labels_file,
                     test_images_file, test_labels_file]:
        download(base_url + filename, filename)

    # Build the local paths the same way download() does.
    train_images = parse_images(os.path.join('data', train_images_file))
    train_labels = parse_labels(os.path.join('data', train_labels_file))
    test_images = parse_images(os.path.join('data', test_images_file))
    test_labels = parse_labels(os.path.join('data', test_labels_file))

    return train_images, train_labels, test_images, test_labels