From 3b1b1853b9859599999fd3bc8a5c54fa8353b71f Mon Sep 17 00:00:00 2001
From: Lai Wei
Date: Mon, 29 Jul 2019 23:29:37 -0700
Subject: [PATCH] add softmax

---
 python/mxnet/gluon/contrib/estimator/estimator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 9f52e64852e7..5e3804784ba8 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -25,6 +25,7 @@
 from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
 from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
 from ...data import DataLoader
+from ...loss import SoftmaxCrossEntropyLoss
 from ...loss import Loss as gluon_loss
 from ...trainer import Trainer
 from ...utils import split_and_load
@@ -184,8 +185,8 @@ def prepare_loss_and_metrics(self):
         """
         if any(not hasattr(self, attribute) for attribute in
                ['train_metrics', 'val_metrics']):
-            # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
-            if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+            # Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
+            if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]):
                 self.train_metrics = [Accuracy()]
             self.val_metrics = []
             for loss in self.loss:
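
Note (not part of the patch): estimator.py uses relative imports and, as the first hunk shows, never imports the top-level gluon package, so the old gluon.loss.SoftmaxCrossEntropyLoss reference could not resolve; importing the class directly avoids that. Below is a minimal standalone sketch of the metric-defaulting check the second hunk touches, assuming MXNet 1.x where mxnet.gluon.loss.SoftmaxCrossEntropyLoss and mxnet.metric.Accuracy are available; the variable names losses and train_metrics are illustrative, not taken from the module.

# Standalone sketch of the defaulting behavior (assumes MXNet 1.x installed).
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
from mxnet.metric import Accuracy

losses = [SoftmaxCrossEntropyLoss()]   # losses handed to the estimator
train_metrics = []                     # no user-supplied training metrics

# Same check as in prepare_loss_and_metrics(): with no explicit train
# metrics and a SoftmaxCrossEntropyLoss present, default to Accuracy().
if not train_metrics and any(isinstance(l, SoftmaxCrossEntropyLoss) for l in losses):
    train_metrics = [Accuracy()]

print(train_metrics)   # one Accuracy metric instance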