diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 4ddef69ebd65..595045860c33 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -180,8 +180,10 @@ def prepare_loss_and_metrics(self):
         Based on loss functions and training metrics in estimator
         Create metric wrappers to record loss values,
         Create copies of train loss/metric objects to record validation values
 
-        Returns train_metrics and val_metrics
+        Returns
+        -------
+        train_metrics, val_metrics
         """
         if any(not hasattr(self, attribute) for attribute in
                ['train_metrics', 'val_metrics']):
@@ -264,10 +266,21 @@ def fit_batch(self, train_batch,
             Data and label of a batch from the training data loader.
         batch_axis : int, default 0
             Batch axis to split the training data into devices.
+
+        Returns
+        -------
+        data : list of NDArray
+            Sharded data from the batch.
+        label : list of NDArray
+            Sharded label from the batch.
+        pred : list of NDArray
+            Prediction of each sharded batch.
+        loss : list of NDArray
+            Loss of each sharded batch.
         """
         data, label = self._get_data_and_label(train_batch, self.context, batch_axis)
 
-        batch_size = data.shape[batch_axis]
+        batch_size = train_batch[0].shape[batch_axis]
 
         with autograd.record():
             pred = [self.net(x) for x in data]
@@ -278,6 +291,8 @@ def fit_batch(self, train_batch,
 
         self.trainer.step(batch_size)
 
+        return data, label, pred, loss
+
     def fit(self, train_data,
             val_data=None,
             epochs=None,
@@ -346,7 +361,7 @@ def fit(self, train_data,
 
                 for handler in batch_begin:
                     handler.batch_begin(estimator_ref, batch=batch)
 
-                self.fit_batch(batch, batch_axis)
+                _, label, pred, loss = self.fit_batch(batch, batch_axis)
 
                 # batch end
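
With fit_batch now returning the sharded data, labels, predictions, and losses, callers can reuse the results of the forward pass directly. Below is a minimal usage sketch, not part of this patch: the toy net, trainer settings, and data loader are assumptions for illustration; only Estimator, prepare_loss_and_metrics, and fit_batch come from the patched API.

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.contrib.estimator import Estimator

    # Assumed setup for illustration: a toy net, loss, and trainer.
    net = gluon.nn.Dense(10)
    net.initialize()
    est = Estimator(net=net,
                    loss=gluon.loss.SoftmaxCrossEntropyLoss(),
                    trainer=gluon.Trainer(net.collect_params(), 'sgd',
                                          {'learning_rate': 0.1}))

    # Synthetic data, just to drive the loop.
    X = mx.nd.random.uniform(shape=(100, 20))
    y = mx.nd.random.randint(0, 10, shape=(100,)).astype('float32')
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y),
                                       batch_size=10)

    # train_metrics/val_metrics as documented in the first hunk above.
    train_metrics, val_metrics = est.prepare_loss_and_metrics()

    for batch in train_data:
        # Each returned value is a list with one NDArray per device shard.
        data, label, pred, loss = est.fit_batch(batch, batch_axis=0)
        # e.g. average the per-shard losses for logging
        batch_loss = sum(l.mean().asscalar() for l in loss) / len(loss)

Returning the shards, rather than discarding them after the trainer step, lets batch-end event handlers update metrics from the same forward pass instead of recomputing predictions over the batch.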