-
Notifications
You must be signed in to change notification settings - Fork 382
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1bde91c
commit b4f4454
Showing
3 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
''Trains a simple convnet on the MNIST dataset. | ||
Gets to 99.25% test accuracy after 12 epochs | ||
(there is still a lot of margin for parameter tuning). | ||
16 seconds per epoch on a GRID K520 GPU. | ||
'' | ||
|
||
from __future__ import print_function | ||
import keras | ||
from keras.datasets import mnist | ||
from keras.models import Sequential | ||
from keras.layers import Dense, Dropout, Flatten | ||
from keras.layers import Conv2D, MaxPooling2D | ||
from keras import backend as K | ||
|
||
batch_size = 128 | ||
num_classes = 10 | ||
epochs = 12 | ||
|
||
# input image dimensions | ||
img_rows, img_cols = 28, 28 | ||
|
||
# the data, split between train and test sets | ||
(x_train, y_train), (x_test, y_test) = mnist.load_data() | ||
|
||
if K.image_data_format() == 'channels_first': | ||
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) | ||
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) | ||
input_shape = (1, img_rows, img_cols) | ||
else: | ||
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) | ||
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) | ||
input_shape = (img_rows, img_cols, 1) | ||
|
||
x_train = x_train.astype('float32') | ||
x_test = x_test.astype('float32') | ||
x_train /= 255 | ||
x_test /= 255 | ||
print('x_train shape:', x_train.shape) | ||
print(x_train.shape[0], 'train samples') | ||
print(x_test.shape[0], 'test samples') | ||
|
||
# convert class vectors to binary class matrices | ||
y_train = keras.utils.to_categorical(y_train, num_classes) | ||
y_test = keras.utils.to_categorical(y_test, num_classes) | ||
|
||
model = Sequential() | ||
model.add(Conv2D(32, kernel_size=(3, 3), | ||
activation='relu', | ||
input_shape=input_shape)) | ||
model.add(Conv2D(64, (3, 3), activation='relu')) | ||
model.add(MaxPooling2D(pool_size=(2, 2))) | ||
model.add(Dropout(0.25)) | ||
model.add(Flatten()) | ||
model.add(Dense(128, activation='relu')) | ||
model.add(Dropout(0.5)) | ||
model.add(Dense(num_classes, activation='softmax')) | ||
|
||
model.compile(loss=keras.losses.categorical_crossentropy, | ||
optimizer=keras.optimizers.Adadelta(), | ||
metrics=['accuracy']) | ||
|
||
model.fit(x_train, y_train, | ||
batch_size=batch_size, | ||
epochs=epochs, | ||
verbose=1, | ||
validation_data=(x_test, y_test)) | ||
score = model.evaluate(x_test, y_test, verbose=0) | ||
print('Test loss:', score[0]) | ||
print('Test accuracy:', score[1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
|
||
### 卷积神经网络CNN实现手写数字识别 | ||
|
||
在学习机器学习的时候,首当其冲的就是准备一份通用的数据集,方便与其他的算法进行比较。 | ||
|
||
### MNIST简介 | ||
|
||
MNIST数据集原网址:http://yann.lecun.com/exdb/mnist/ | ||
|
||
 | ||
|
||
数据集是这样的一些手写数字 | ||
|
||
**问题:通过某个算法将0-9的数字进行分类** | ||
|
||
### 下载 | ||
Github源码下载:数据集(源文件+解压文件+字体图像jpg格式),py源码文件 | ||
文件目录 | ||
```python | ||
/utils/data_util.py 用于加载MNIST数据集方法文件 | ||
/utils/test.py 用于测试的文件,一个简单的KNN测试MNIST数据集 | ||
/data/train-images.idx3-ubyte 训练集X | ||
/dataset/train-labels.idx1-ubyte 训练集y | ||
/dataset/data/t10k-images.idx3-ubyte 测试集X | ||
/dataset/data/t10k-labels.idx1-ubyte 测试集y | ||
``` | ||
|
||
### 结构解释 | ||
MNIST数据集解释 | ||
将MNIST文件解压后,发现这些文件并不是标准的图像格式。这些图像数据都保存在二进制文件中。每个样本图像的宽高为28*28。 | ||
|
||
mnist的结构如下,选取train-images | ||
```python | ||
[code]TRAINING SET IMAGE FILE (train-images-idx3-ubyte): | ||
|
||
[offset] [type] [value] [description] | ||
0000 32 bit integer 0x00000803(2051) magic number | ||
0004 32 bit integer 60000 number of images | ||
0008 32 bit integer 28 number of rows | ||
0012 32 bit integer 28 number of columns | ||
0016 unsigned byte ?? pixel | ||
0017 unsigned byte ?? pixel | ||
........ | ||
xxxx unsigned byte ?? pixel | ||
``` | ||
|
||
首先该数据是以二进制存储的,我们读取的时候要以’rb’方式读取;其次,真正的数据只有[value]这一项,其他的[type]等只是来描述的,并不真正在数据文件里面。也就是说,在读取真实数据之前,我们要读取4个 | ||
|
||
32 bit integer | ||
.由[offset]我们可以看出真正的pixel是从0016开始的,一个int 32位,所以在读取pixel之前我们要读取4个 32 bit integer,也就是magic number, number of images, number of rows, number of columns. 当然,在这里使用struct.unpack_from()会比较方便. | ||
|
||
### 算法实现 |