-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare.py
63 lines (51 loc) · 1.96 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
SVM適用前のデータの前処理を行います.
MNISTファイル(gzip)を、CSVファイルに変換します.
"""
import os
import gzip
import struct
def csv_image(fname, type_):
"""
画像データを出力します.
@param {String} fname - MNISTのファイル名
@param {String} type_ - one of { training | test }
"""
print("%s processing..." % fname)
# 画像データをGzipファイルから読み取ります.
with gzip.open(os.path.join("mnist", fname), "rb") as f:
_, cnt, rows, cols = struct.unpack(">IIII", f.read(16))
# 画像読み込み
images = []
for i in range(cnt):
binarys = f.read(rows * cols)
images.append(",".join([str(b) for b in binarys]))
# CSV結果として出力します.
with open(os.path.join("csv", type_ + "_image.csv"), "w") as f:
f.write("\n".join(images))
def csv_label(fname, type_):
"""
ラベルデータを出力します.
@param {String} fname - MNISTのファイル名
@param {String} type_ - one of { training | test }
"""
print("%s processing..." % fname)
# ラベルデータをGzipファイルから読み取ります.
with gzip.open(os.path.join("mnist", fname), "rb") as f:
_, cnt = struct.unpack(">II", f.read(8))
labels = []
for i in range(cnt):
label = str(struct.unpack("B", f.read(1))[0])
labels.append(label)
# CSV結果として出力します.
with open(os.path.join("csv", type_ + "_label.csv"), "w") as f:
f.write("\n".join(labels))
if __name__ == "__main__":
if not os.path.exists("csv"):
os.mkdir("csv")
# トレーニングデータ.
csv_image("train-images-idx3-ubyte.gz", "training")
csv_label("train-labels-idx1-ubyte.gz", "training")
# テストデータ.
csv_image("t10k-images-idx3-ubyte.gz", "test")
csv_label("t10k-labels-idx1-ubyte.gz", "test")