
Commit a9dd0b8

replace step_per_epoch by steps_per_epoch in examples
1 parent 543d929

23 files changed: +37 −43 lines
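The change is a pure keyword rename: every `TrainConfig(...)` call in the examples now passes `steps_per_epoch` instead of the old `step_per_epoch`. For scripts still written against the old spelling, a small migration helper along the lines of the hypothetical sketch below (neither the function nor its name comes from this commit or from tensorpack) can rename the keyword in a kwargs dict before it is handed to `TrainConfig`:

```python
# Hypothetical migration helper, shown only to illustrate the rename.

def migrate_trainconfig_kwargs(kwargs):
    """Return a copy of kwargs with 'step_per_epoch' renamed to 'steps_per_epoch'."""
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    if 'step_per_epoch' in kwargs and 'steps_per_epoch' not in kwargs:
        kwargs['steps_per_epoch'] = kwargs.pop('step_per_epoch')
    return kwargs


if __name__ == '__main__':
    old_style = {'step_per_epoch': 5000, 'max_epoch': 100}
    # dict equality ignores key order, so this checks only the renamed contents
    assert migrate_trainconfig_kwargs(old_style) == {'steps_per_epoch': 5000, 'max_epoch': 100}
```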

docs/casestudies/colorize.md (+2 −2)

@@ -298,9 +298,9 @@ def get_config():
     return TrainConfig(
         dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=Callbacks([StatPrinter(), PeriodicCallback(ModelSaver(), 3)])]),
+        callbacks=[PeriodicCallback(ModelSaver(), 3)],
         model=Model(),
-        step_per_epoch=dataset.size(),
+        steps_per_epoch=dataset.size(),
         max_epoch=100,
     )
 ```

examples/A3C-Gym/train-atari.py (+1 −1)

@@ -219,7 +219,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.5),
         model=M,
-        step_per_epoch=STEP_PER_EPOCH,
+        steps_per_epoch=STEP_PER_EPOCH,
         max_epoch=1000,
     )

examples/CTC-TIMIT/train-timit.py (+2 −2)

@@ -89,7 +89,7 @@ def get_data(path, isTrain, stat_file):
 
 
 def get_config(ds_train, ds_test):
-    step_per_epoch = ds_train.size()
+    steps_per_epoch = ds_train.size()
 
     lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)

@@ -105,7 +105,7 @@ def get_config(ds_train, ds_test):
             InferenceRunner(ds_test, [ScalarStats('error')]), 2),
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=70,
     )

examples/Char-RNN/char-rnn.py (+2 −2)

@@ -103,7 +103,7 @@ def get_config():
 
     ds = CharRNNData(param.corpus, 100000)
     ds = BatchData(ds, param.batch_size)
-    step_per_epoch = ds.size()
+    steps_per_epoch = ds.size()
 
     lr = symbolic_functions.get_scalar_var('learning_rate', 2e-3, summary=True)

@@ -115,7 +115,7 @@ def get_config():
             ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=50,
     )

examples/DeepQNetwork/DQN.py (+1 −1)

@@ -190,7 +190,7 @@ def get_config():
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.6),
         model=M,
-        step_per_epoch=STEP_PER_EPOCH,
+        steps_per_epoch=STEP_PER_EPOCH,
     )

examples/DoReFa-Net/alexnet-dorefa.py (+1 −1)

@@ -247,7 +247,7 @@ def get_config():
             ClassificationError('wrong-top5', 'val-error-top5')])
         ],
         model=Model(),
-        step_per_epoch=10000,
+        steps_per_epoch=10000,
         max_epoch=100,
     )

examples/DoReFa-Net/svhn-digit-dorefa.py (+2 −2)

@@ -147,7 +147,7 @@ def get_config():
     data_train = AugmentImageComponent(data_train, augmentors)
     data_train = BatchData(data_train, 128)
     data_train = PrefetchDataZMQ(data_train, 5)
-    step_per_epoch = data_train.size()
+    steps_per_epoch = data_train.size()
 
     augmentors = [imgaug.Resize((40, 40))]
     data_test = AugmentImageComponent(data_test, augmentors)

@@ -169,7 +169,7 @@ def get_config():
             [ScalarStats('cost'), ClassificationError()])
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=200,
     )

examples/GAN/DCGAN-CelebA.py (+1 −1)

@@ -112,7 +112,7 @@ def get_config():
         callbacks=[ModelSaver()],
         session_config=get_default_sess_config(0.5),
         model=Model(),
-        step_per_epoch=300,
+        steps_per_epoch=300,
         max_epoch=200,
     )

examples/GAN/Image2Image.py (+1 −1)

@@ -173,7 +173,7 @@ def get_config():
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
         ],
         model=Model(),
-        step_per_epoch=dataset.size(),
+        steps_per_epoch=dataset.size(),
         max_epoch=300,
     )

examples/GAN/InfoGAN-mnist.py (+1 −1)

@@ -161,7 +161,7 @@ def get_config():
         callbacks=[ModelSaver()],
         session_config=get_default_sess_config(0.5),
         model=Model(),
-        step_per_epoch=500,
+        steps_per_epoch=500,
         max_epoch=100,
     )

examples/HED/hed.py (+2 −2)

@@ -166,7 +166,7 @@ def view_data():
 def get_config():
     logger.auto_set_dir()
     dataset_train = get_data('train')
-    step_per_epoch = dataset_train.size() * 40
+    steps_per_epoch = dataset_train.size() * 40
     dataset_val = get_data('val')
 
     lr = get_scalar_var('learning_rate', 3e-5, summary=True)

@@ -181,7 +181,7 @@ def get_config():
             BinaryClassificationStats('prediction', 'edgemap4d'))
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=100,
     )

examples/Inception/inception-bn.py (+2 −2)

@@ -153,7 +153,7 @@ def get_config():
     logger.auto_set_dir()
     # prepare dataset
     dataset_train = get_data('train')
-    step_per_epoch = 5000
+    steps_per_epoch = 5000
     dataset_val = get_data('val')
 
     lr = get_scalar_var('learning_rate', 0.045, summary=True)

@@ -172,7 +172,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.99),
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=80,
     )

examples/Inception/inceptionv3.py (+1 −1)

@@ -281,7 +281,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.9),
         model=Model(),
-        step_per_epoch=5000,
+        steps_per_epoch=5000,
         max_epoch=100,
     )

examples/PennTreebank/PTB-LSTM.py (+2 −2)

@@ -110,11 +110,11 @@ def get_config():
     data3, wd2id = get_PennTreeBank()
     global VOCAB_SIZE
     VOCAB_SIZE = len(wd2id)
-    step_per_epoch = (data3[0].shape[0] // BATCH - 1) // SEQ_LEN
+    steps_per_epoch = (data3[0].shape[0] // BATCH - 1) // SEQ_LEN
 
     train_data = TensorInput(
         lambda: ptb_producer(data3[0], BATCH, SEQ_LEN),
-        step_per_epoch)
+        steps_per_epoch)
     val_data = TensorInput(
         lambda: ptb_producer(data3[1], BATCH, SEQ_LEN),
         (data3[1].shape[0] // BATCH - 1) // SEQ_LEN)
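For concreteness, the `steps_per_epoch` expression above can be checked with plain numbers. The values below (a ~929k-token training split, `BATCH = 20`, `SEQ_LEN = 35`) are typical PTB settings used only for illustration; they are not taken from this diff:

```python
# Worked example of (data3[0].shape[0] // BATCH - 1) // SEQ_LEN
# with illustrative values; the token count and hyper-parameters are assumptions.
num_train_tokens, BATCH, SEQ_LEN = 929589, 20, 35
steps_per_epoch = (num_train_tokens // BATCH - 1) // SEQ_LEN
print(steps_per_epoch)  # 1327 parameter updates per epoch
```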

examples/ResNet/cifar10-resnet.py (+2 −2)

@@ -134,7 +134,7 @@ def get_config():
 
     # prepare dataset
     dataset_train = get_data('train')
-    step_per_epoch = dataset_train.size()
+    steps_per_epoch = dataset_train.size()
     dataset_test = get_data('test')
 
     lr = get_scalar_var('learning_rate', 0.01, summary=True)

@@ -149,7 +149,7 @@ def get_config():
             [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
         ],
         model=Model(n=NUM_UNITS),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=400,
     )

examples/ResNet/imagenet-resnet.py (+1 −1)

@@ -199,7 +199,7 @@ def get_config():
             HumanHyperParamSetter('learning_rate'),
         ],
         model=Model(),
-        step_per_epoch=5000,
+        steps_per_epoch=5000,
         max_epoch=110,
     )

examples/ResNet/svhn-resnet.py (+2 −2)

@@ -63,7 +63,7 @@ def get_config():
 
     # prepare dataset
     dataset_train = get_data('train')
-    step_per_epoch = dataset_train.size()
+    steps_per_epoch = dataset_train.size()
     dataset_test = get_data('test')
 
     lr = get_scalar_var('learning_rate', 0.01, summary=True)

@@ -78,7 +78,7 @@ def get_config():
             [(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
         ],
         model=Model(n=18),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=500,
     )

examples/SimilarityLearning/mnist-embeddings.py (+2 −2)

@@ -133,7 +133,7 @@ def get_config(model):
     logger.auto_set_dir()
 
     dataset = model.get_data()
-    step_per_epoch = dataset.size()
+    steps_per_epoch = dataset.size()
 
     lr = symbf.get_scalar_var('learning_rate', 1e-4, summary=True)

@@ -145,7 +145,7 @@ def get_config(model):
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(10, 1e-5), (20, 1e-6)])
         ],
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=20,
     )

examples/SpatialTransformer/mnist-addition.py (+2 −2)

@@ -148,7 +148,7 @@ def get_config():
     logger.auto_set_dir()
 
     dataset_train, dataset_test = get_data(True), get_data(False)
-    step_per_epoch = dataset_train.size() * 5
+    steps_per_epoch = dataset_train.size() * 5
 
     lr = symbf.get_scalar_var('learning_rate', 5e-4, summary=True)

@@ -163,7 +163,7 @@ def get_config():
         ],
         session_config=get_default_sess_config(0.5),
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=500,
     )

examples/cifar-convnet.py (+2 −2)

@@ -107,7 +107,7 @@ def get_config(cifar_classnum):
 
     # prepare dataset
    dataset_train = get_data('train', cifar_classnum)
-    step_per_epoch = dataset_train.size()
+    steps_per_epoch = dataset_train.size()
     dataset_test = get_data('test', cifar_classnum)
 
     sess_config = get_default_sess_config(0.5)

@@ -130,7 +130,7 @@ def lr_func(lr):
         ],
         session_config=sess_config,
         model=Model(cifar_classnum),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=150,
     )

examples/mnist-convnet.py (+2 −2)

@@ -124,7 +124,7 @@ def get_config():
 
     dataset_train, dataset_test = get_data()
     # how many iterations you want in each epoch
-    step_per_epoch = dataset_train.size()
+    steps_per_epoch = dataset_train.size()
 
     lr = tf.train.exponential_decay(
         learning_rate=1e-3,

@@ -148,7 +148,7 @@ def get_config():
             [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=100,
     )

examples/svhn-digit-convnet.py (+2 −2)

@@ -89,7 +89,7 @@ def get_config():
     logger.auto_set_dir()
 
     data_train, data_test = get_data()
-    step_per_epoch = data_train.size()
+    steps_per_epoch = data_train.size()
 
     lr = tf.train.exponential_decay(
         learning_rate=1e-3,

@@ -107,7 +107,7 @@ def get_config():
             [ScalarStats('cost'), ClassificationError()])
         ],
         model=Model(),
-        step_per_epoch=step_per_epoch,
+        steps_per_epoch=steps_per_epoch,
         max_epoch=350,
     )

tensorpack/models/model_desc.py (+1 −7)

@@ -5,7 +5,6 @@
 
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
-import inspect
 import pickle
 import six

@@ -102,12 +101,7 @@ def build_graph(self, model_inputs):
             model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
                 InputVars of this model.
         """
-        if len(inspect.getargspec(self._build_graph).args) == 3:
-            logger.warn("[DEPRECATED] _build_graph(self, input_vars, is_training) is deprecated! \
-                Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
-            self._build_graph(model_inputs, get_current_tower_context().is_training)
-        else:
-            self._build_graph(model_inputs)
+        self._build_graph(model_inputs)
 
     @abstractmethod
     def _build_graph(self, inputs):
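The removed branch was a backward-compatibility shim for the old two-argument `_build_graph(self, input_vars, is_training)` signature. After this change, subclasses must define the single-argument form and read the training flag from the tower context, as the deleted warning message itself suggested. A minimal sketch of that pattern follows; the model class, its input shape, and the top-level imports are illustrative assumptions about the tensorpack API of this era, not code from the commit:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputVar          # assumed top-level exports
from tensorpack.tfutils.tower import get_current_tower_context


class MyModel(ModelDesc):                           # hypothetical example model
    def _get_input_vars(self):
        return [InputVar(tf.float32, (None, 28, 28, 1), 'input')]

    def _build_graph(self, inputs):                 # single-argument form only
        image = inputs[0]
        # The training flag now comes from the tower context,
        # not from a second argument to _build_graph:
        is_training = get_current_tower_context().is_training
        keep_prob = tf.constant(0.5 if is_training else 1.0, name='keep_prob')
        tf.nn.dropout(image, keep_prob, name='dropped')
```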
