Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Track epoch metric separately
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Sep 4, 2018
1 parent 8e4aeee commit 00ec53e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/mxnet/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,13 @@ def __call__(self, param):
name_value = param.eval_metric.get_name_value()
if self.auto_reset:
param.eval_metric.reset()
msg = 'Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec'
msg += '\t%s=%f'*len(name_value)
logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
msg += '\t%s=%f'*len(name_value)
logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ()))
else:
msg = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
msg += '\t%s=%f'*len(name_value)
logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
else:
logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
param.epoch, count, speed)
Expand Down
11 changes: 9 additions & 2 deletions python/mxnet/module/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,9 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',

if validation_metric is None:
validation_metric = eval_metric
epoch_eval_metric = eval_metric
if not isinstance(eval_metric, metric.EvalMetric):
epoch_eval_metric = metric.create(eval_metric)
eval_metric = metric.create(eval_metric)

################################################################################
Expand All @@ -514,6 +516,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
epoch_eval_metric.reset()
nbatch = 0
data_iter = iter(train_data)
end_of_batch = False
Expand All @@ -529,8 +532,12 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
self.update_metric(eval_metric,
[db.label for db in data_batch],
pre_sliced=True)
self.update_metric(epoch_eval_metric,
[db.label for db in data_batch],
pre_sliced=True)
else:
self.update_metric(eval_metric, data_batch.label)
self.update_metric(epoch_eval_metric, data_batch.label)

try:
# pre fetch next batch
Expand All @@ -543,7 +550,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
monitor.toc_print()

if end_of_batch:
eval_name_vals = eval_metric.get_name_value()
eval_name_vals = epoch_eval_metric.get_name_value()

if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
Expand All @@ -555,7 +562,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',

# one epoch of training is finished
for name, val in eval_name_vals:
self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
self.logger.info('Epoch[%d] Train-%s (averaged over entire epoch)=%f', epoch, name, val)
toc = time.time()
self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

Expand Down

0 comments on commit 00ec53e

Please sign in to comment.