Skip to content

Commit

Permalink
add softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
roywei authored and piyushghai committed Jul 30, 2019
1 parent 75ec743 commit 3b1b185
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from ...data import DataLoader
from ...loss import SoftmaxCrossEntropyLoss
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
Expand Down Expand Up @@ -184,8 +185,8 @@ def prepare_loss_and_metrics(self):
"""
if any(not hasattr(self, attribute) for attribute in
['train_metrics', 'val_metrics']):
# Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
# Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
Expand Down

0 comments on commit 3b1b185

Please sign in to comment.