From 7f06bfa866335ee41635429b4e66eaf5ad685cfc Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Wed, 20 Mar 2019 18:44:00 -0700 Subject: [PATCH] =?UTF-8?q?Fixed=20issue=20where=20the=20estimator=20was?= =?UTF-8?q?=20printing=20beyond=20the=20dataset=20size=20=E2=80=A6=20(#144?= =?UTF-8?q?64)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI --- python/mxnet/gluon/estimator/estimator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index 159f7e220427..c160115ac807 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -242,7 +242,11 @@ def fit(self, train_data, self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] try: - self.train_stats['step'] = "{}/{}".format(batch_size * (i + 1), len(train_data._dataset)) + completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \ + else batch_size * (i + 1) + # We need to check if this is the last batch in the current epoch and select + # the value to print appropriately + self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset)) except AttributeError: self.train_stats['step'] = i