Track epoch metric separately (apache#12182)
vandanavk authored and sandeep-krishnamurthy committed Sep 19, 2018
1 parent 73d8897 commit ce6525a
Showing 2 changed files with 15 additions and 4 deletions.
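For context, a minimal sketch of the training setup this commit affects (the network, data, and hyperparameters below are made up for illustration): with `Speedometer(auto_reset=True)` as the batch-end callback, the interval metric is reset every `frequent` batches, so `fit()` now tracks a separate epoch-level copy of the metric for the end-of-epoch `Epoch[N] Train-...` summary.

```python
import mxnet as mx

# Toy data: 100 samples, 10 features, binary labels.
data = mx.nd.random.uniform(shape=(100, 10))
label = mx.nd.array([i % 2 for i in range(100)])
train_iter = mx.io.NDArrayIter(data, label, batch_size=10)

# Tiny network: one fully connected layer with softmax output.
x = mx.sym.Variable('data')
net = mx.sym.FullyConnected(x, num_hidden=2)
net = mx.sym.SoftmaxOutput(net, name='softmax')

mod = mx.mod.Module(net)
mod.fit(train_iter,
        eval_metric='acc',
        # auto_reset defaults to True: the interval metric is reset
        # every `frequent` batches by this callback.
        batch_end_callback=mx.callback.Speedometer(batch_size=10, frequent=5),
        num_epoch=2)
```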
10 changes: 7 additions & 3 deletions python/mxnet/callback.py
@@ -165,9 +165,13 @@ def __call__(self, param):
                     name_value = param.eval_metric.get_name_value()
                     if self.auto_reset:
                         param.eval_metric.reset()
-                        msg = 'Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec'
-                        msg += '\t%s=%f'*len(name_value)
-                        logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
+                        msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
+                        msg += '\t%s=%f'*len(name_value)
+                        logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ()))
+                    else:
+                        msg = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
+                        msg += '\t%s=%f'*len(name_value)
+                        logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
                 else:
                     logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                                  param.epoch, count, speed)
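For illustration, here is a minimal sketch of the log line the new `auto_reset` branch produces, with made-up values standing in for `self.frequent`, `count`, `speed`, and the metric: the batch range `[count-frequent, count]` replaces the bare batch index.

```python
import logging

logging.basicConfig(level=logging.INFO)

# Hypothetical values: a Speedometer with frequent=50 firing at batch 100,
# reporting one metric. These stand in for the attributes used in the diff.
frequent, count, speed = 50, 100, 1234.56
name_value = [('accuracy', 0.85)]   # what param.eval_metric.get_name_value() returns

# auto_reset branch: the metric was just reset, so the message labels the
# batch *range* the values were computed over, not just the end index.
msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
msg += '\t%s=%f' * len(name_value)
logging.info(msg, 0, count - frequent, count, speed, *sum(name_value, ()))
# -> INFO:root:Epoch[0] Batch [50-100]  Speed: 1234.56 samples/sec  accuracy=0.850000
```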
9 changes: 8 additions & 1 deletion python/mxnet/module/base_module.py
@@ -22,6 +22,7 @@
 import time
 import logging
 import warnings
+import copy
 import numpy as np

 from .. import metric
@@ -507,13 +508,15 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
         validation_metric = eval_metric
         if not isinstance(eval_metric, metric.EvalMetric):
             eval_metric = metric.create(eval_metric)
+        epoch_eval_metric = copy.deepcopy(eval_metric)

         ################################################################################
         # training loop
         ################################################################################
         for epoch in range(begin_epoch, num_epoch):
             tic = time.time()
             eval_metric.reset()
+            epoch_eval_metric.reset()
             nbatch = 0
             data_iter = iter(train_data)
             end_of_batch = False
@@ -529,8 +532,12 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                     self.update_metric(eval_metric,
                                        [db.label for db in data_batch],
                                        pre_sliced=True)
+                    self.update_metric(epoch_eval_metric,
+                                       [db.label for db in data_batch],
+                                       pre_sliced=True)
                 else:
                     self.update_metric(eval_metric, data_batch.label)
+                    self.update_metric(epoch_eval_metric, data_batch.label)

                 try:
                     # pre fetch next batch
@@ -543,7 +550,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                     monitor.toc_print()

                 if end_of_batch:
-                    eval_name_vals = eval_metric.get_name_value()
+                    eval_name_vals = epoch_eval_metric.get_name_value()

                 if batch_end_callback is not None:
                     batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
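As a sketch of why `fit()` now deep-copies the metric: `eval_metric` is the object handed to batch-end callbacks, and `Speedometer(auto_reset=True)` resets it every reporting interval, so before this change the end-of-epoch summary reflected only the final interval. The separate `epoch_eval_metric` keeps accumulating across the whole epoch. A minimal illustration using `mxnet.metric` directly (the data here is made up):

```python
import copy
import mxnet as mx

eval_metric = mx.metric.create('acc')
epoch_eval_metric = copy.deepcopy(eval_metric)

# Two toy samples per "batch": labels and predicted class probabilities.
labels = [mx.nd.array([1, 0])]
preds = [mx.nd.array([[0.2, 0.8], [0.9, 0.1]])]

for _ in range(3):                        # three batches in one epoch
    eval_metric.update(labels, preds)
    epoch_eval_metric.update(labels, preds)
    eval_metric.reset()                   # what Speedometer(auto_reset=True) does

print(eval_metric.get_name_value())        # [('accuracy', nan)] -- just reset
print(epoch_eval_metric.get_name_value())  # [('accuracy', 1.0)] -- whole epoch
```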
