#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-convnet.py
from tensorpack import tfv1 as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import summary
"""
MNIST ConvNet example.
About 0.6% validation error after 30 epochs.
"""
IMAGE_SIZE = 28


class Model(ModelDesc):
    # See tutorial at https://tensorpack.readthedocs.io/tutorial/training-interface.html#with-modeldesc-and-trainconfig
    def inputs(self):
        """
        Define all the inputs (with type, shape, name) that the graph will need.
        """
        return [tf.TensorSpec((None, IMAGE_SIZE, IMAGE_SIZE), tf.float32, 'input'),
                tf.TensorSpec((None,), tf.int32, 'label')]
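    # Note: the names 'input' and 'label' above become the placeholder names in the
    # graph, and the dataflow built in get_data() is expected to yield components in
    # this same order (image first, then label).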

    def build_graph(self, image, label):
        """This function should build the model which takes the input variables (defined above)
        and returns the cost at the end."""

        # In TensorFlow, inputs to the convolution function are assumed to be
        # NHWC. Add a single channel here.
        image = tf.expand_dims(image, 3)

        image = image * 2 - 1   # center the pixel values at zero
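        # dataset.Mnist yields images with pixel values in [0, 1] (an assumption this
        # example relies on), so this maps them to [-1, 1].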

        # The context manager `argscope` sets the default options for all the layers under
        # this context. Here we use 32-channel convolutions with 3x3 kernels.
        # See tutorial at https://tensorpack.readthedocs.io/tutorial/symbolic.html
        with argscope(Conv2D, kernel_size=3, activation=tf.nn.relu, filters=32):
            # LinearWrap is just syntactic sugar.
            # See tutorial at https://tensorpack.readthedocs.io/tutorial/symbolic.html
            logits = (LinearWrap(image)
                      .Conv2D('conv0')
                      .MaxPooling('pool0', 2)
                      .Conv2D('conv1')
                      .Conv2D('conv2')
                      .MaxPooling('pool1', 2)
                      .Conv2D('conv3')
                      .FullyConnected('fc0', 512, activation=tf.nn.relu)
                      .Dropout('dropout', rate=0.5)
                      .FullyConnected('fc1', 10, activation=tf.identity)())
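            # For reference, the chain above is roughly equivalent to calling the layers
            # directly under the same argscope (a sketch, not executed here):
            #   l = Conv2D('conv0', image)
            #   l = MaxPooling('pool0', l, 2)
            #   l = Conv2D('conv1', l)
            #   l = Conv2D('conv2', l)
            #   l = MaxPooling('pool1', l, 2)
            #   l = Conv2D('conv3', l)
            #   l = FullyConnected('fc0', l, 512, activation=tf.nn.relu)
            #   l = Dropout('dropout', l, rate=0.5)
            #   logits = FullyConnected('fc1', l, 10, activation=tf.identity)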

        # a vector of length B with the loss of each sample
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

        correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')
        accuracy = tf.reduce_mean(correct, name='accuracy')
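        # The tensor names 'cross_entropy_loss' and 'accuracy' given above are what the
        # InferenceRunner's ScalarStats callback (configured below) looks up at validation time.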

        # This will monitor training error & accuracy (in a moving-average fashion). The value will be automatically
        # 1. written to tensorboard
        # 2. written to stat.json
        # 3. printed after each epoch
        # You can also just call `tf.summary.scalar`. But the moving summary has some other benefits.
        # See tutorial at https://tensorpack.readthedocs.io/tutorial/summary.html
        train_error = tf.reduce_mean(1 - correct, name='train_error')
        summary.add_moving_summary(train_error, accuracy)

        # Use a regex to find parameters to apply weight decay to.
        # Here we apply weight decay to the W (weight matrix) of all fc layers.
        # If you don't like regex, you can certainly define the cost in any other way.
        wd_cost = tf.multiply(1e-5,
                              regularize_cost('fc.*/W', tf.nn.l2_loss),
                              name='regularize_loss')
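        # In this model the pattern 'fc.*/W' should match the weight matrices of the
        # 'fc0' and 'fc1' layers defined above (an observation, not enforced anywhere).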
        total_cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, total_cost)

        # monitor histograms of all weights (of conv and fc layers) in tensorboard
        summary.add_param_summary(('.*/W', ['histogram', 'rms']))
        # the function should return the total cost to be optimized
        return total_cost

    def optimizer(self):
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=468 * 10,
            decay_rate=0.3, staircase=True, name='learning_rate')
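        # 468 is roughly the number of steps in one MNIST epoch at batch size 128
        # (60000 // 128), so the learning rate decays by 0.3 about every 10 epochs.
        # This ties the schedule to the batch size chosen in get_data().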
        # This will also put the summary in tensorboard and stat.json, and print it in the terminal,
        # but this time without a moving average.
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)


def get_data():
    # We don't need any fancy data loading for this simple example.
    # See dataflow tutorial at https://tensorpack.readthedocs.io/tutorial/dataflow.html
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    train = PrintData(train)
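    # PrintData just logs the shapes and types of a few datapoints (e.g. an image batch
    # of shape [128, 28, 28] plus a label batch of shape [128]), which is handy for
    # sanity-checking the pipeline.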
    return train, test


if __name__ == '__main__':
    # automatically setup the directory train_log/mnist-convnet for logging
    logger.auto_set_dir()

    dataset_train, dataset_test = get_data()

    # How many iterations you want in each epoch.
    # This len(data) is the default value.
    steps_per_epoch = len(dataset_train)
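    # With the 60000-image MNIST training set and batch size 128, this is about 468
    # steps per epoch, matching the decay_steps used in Model.optimizer() above.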

    # get the config which contains everything necessary for training
    config = TrainConfig(
        model=Model(),
        # The input source for training. FeedInput is slow; this is just for demo purposes.
        # In practice it's best to use QueueInput or others.
        # See tutorial at https://tensorpack.readthedocs.io/tutorial/extend/input-source.html
        data=FeedInput(dataset_train),
        # We use a few simple callbacks in this demo.
        # See tutorial at https://tensorpack.readthedocs.io/tutorial/callback.html
        callbacks=[
            ModelSaver(),   # save the model after every epoch
            InferenceRunner(    # run inference (for validation) after every epoch
                dataset_test,   # the DataFlow instance used for validation
                ScalarStats(    # produce `val_accuracy` and `val_cross_entropy_loss`
                    ['cross_entropy_loss', 'accuracy'], prefix='val')),
            # MaxSaver needs to come after InferenceRunner to obtain its score
            MaxSaver('val_accuracy'),  # save the model with the highest accuracy
        ],
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )
    # Use a simple trainer in this demo.
    # More trainers with multi-GPU or distributed functionality are available.
    # See tutorial at https://tensorpack.readthedocs.io/tutorial/trainer.html
    launch_train_with_config(config, SimpleTrainer())