diff --git a/python/mxnet/gluon/metric.py b/python/mxnet/gluon/metric.py index 08ee1411824c..766a95a4d6db 100644 --- a/python/mxnet/gluon/metric.py +++ b/python/mxnet/gluon/metric.py @@ -1352,9 +1352,12 @@ class :math:`k`. Index of invalid label to ignore when counting. By default, sets to -1. If set to `None`, it will include all entries. - axis : int (default -1) + axis : int, default -1 The axis from prediction that was used to compute softmax. By default use the last axis. + from_logits : boolean, default False + Whether `pred` is expected to be a logits tensor. + By default, we assume that `pred` encodes a probability distribution. name : str Name of this metric instance for display. output_names : list of str, or None @@ -1373,12 +1376,13 @@ class :math:`k`. >>> print ce.get() ('cross-entropy', 0.57159948348999023) """ - def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='cross-entropy', - output_names=None, label_names=None): + def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False, + name='cross-entropy', output_names=None, label_names=None): super(CrossEntropy, self).__init__( name, output_names=output_names, label_names=label_names) self.ignore_label = ignore_label self.axis = axis + self.from_logits = from_logits self.eps = eps def update(self, labels, preds): @@ -1400,6 +1404,8 @@ def update(self, labels, preds): assert label.size == pred.size/pred.shape[-1], \ "shape mismatch: %s vs. %s"%(label.shape, pred.shape) label = label.reshape((label.size,)) + if self.from_logits: + pred = ndarray.softmax(pred, axis=self.axis) pred = ndarray.pick(pred.as_in_context(label.ctx), label.astype(dtype='int32'), axis=self.axis) label = label.as_np_ndarray() pred = pred.as_np_ndarray() @@ -1469,11 +1475,11 @@ class Perplexity(CrossEntropy): >>> print perp.get() ('Perplexity', 1.7710976285155853) """ - def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='perplexity', - output_names=None, label_names=None): + def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False, + name='perplexity', output_names=None, label_names=None): super(Perplexity, self).__init__( - name=name, eps=eps, ignore_label=ignore_label, axis=axis, - output_names=output_names, label_names=label_names) + eps=eps, ignore_label=ignore_label, axis=axis, from_logits=from_logits, + name=name, output_names=output_names, label_names=label_names) def get(self): if self.num_inst == 0: @@ -1482,77 +1488,6 @@ def get(self): return (self.name, math.exp(self.sum_metric/self.num_inst)) -@register -@alias('nll_loss') -@use_np -class NegativeLogLikelihood(EvalMetric): - """Computes the negative log-likelihood loss. - - The negative log-likelihoodd loss over a batch of sample size :math:`N` is given by - - .. math:: - -\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}), - - where :math:`K` is the number of classes, :math:`y_{nk}` is the prediceted probability for - :math:`k`-th class for :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample - :math:`n` belongs to class :math:`k`. - - Parameters - ---------- - eps : float - Negative log-likelihood loss is undefined for predicted value is 0, - so predicted values are added with the small constant. - name : str - Name of this metric instance for display. - output_names : list of str, or None - Name of predictions that should be used when updating with update_dict. - By default include all predictions. - label_names : list of str, or None - Name of labels that should be used when updating with update_dict. - By default include all labels. - - Examples - -------- - >>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])] - >>> labels = [mx.nd.array([0, 1, 1])] - >>> nll_loss = mx.gluon.metric.NegativeLogLikelihood() - >>> nll_loss.update(labels, predicts) - >>> print nll_loss.get() - ('nll-loss', 0.57159948348999023) - """ - def __init__(self, eps=1e-12, name='nll-loss', - output_names=None, label_names=None): - super(NegativeLogLikelihood, self).__init__( - name, eps=eps, - output_names=output_names, label_names=label_names) - self.eps = eps - - def update(self, labels, preds): - """Updates the internal evaluation result. - - Parameters - ---------- - labels : list of `NDArray` - The labels of the data. - - preds : list of `NDArray` - Predicted values. - """ - labels, preds = check_label_shapes(labels, preds, True) - - for label, pred in zip(labels, preds): - label = label.as_np_ndarray() - pred = pred.as_np_ndarray().as_in_ctx(label.ctx) - - label = label.reshape(-1) - num_examples = pred.shape[0] - assert label.shape[0] == num_examples, (label.shape[0], num_examples) - prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)] - nll = (-numpy.log(prob + self.eps)).sum() - self.sum_metric += nll - self.num_inst += num_examples - - @register @alias('pearsonr') @use_np diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 88b9d9cedce2..d66f4e97d708 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -39,19 +39,25 @@ def test_metrics(): check_metric('perplexity', axis=-1) check_metric('pearsonr') check_metric('pcc') - check_metric('nll_loss') + check_metric('ce') check_metric('loss') composite = mx.gluon.metric.create(['acc', 'f1']) check_metric(composite) -def test_nll_loss(): - metric = mx.gluon.metric.create('nll_loss') +def test_ce(): + metric = mx.gluon.metric.create('ce') pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]) label = mx.nd.array([2, 1]) metric.update([label], [pred]) _, loss = metric.get() expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2 assert loss == expected_loss + metric = mx.gluon.metric.create('ce', from_logits=True) + pred = mx.nd.log(pred) + metric.update([label], [pred]) + _, loss = metric.get() + np.testing.assert_almost_equal(loss, expected_loss) + def test_acc(): pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]]) @@ -159,7 +165,7 @@ def test_multiclass_f1(): macroF1.update([label11, label12], [pred11, pred12]) assert microF1.num_inst == 6 assert macroF1.num_inst == 6 - + # from sklearn.metrics import f1_score # overall_pred = [0, 1, 2, 0, 1, 2] # overall_label = [0, 2, 1, 0, 0, 1] @@ -167,7 +173,7 @@ def test_multiclass_f1(): fmicro = 0.3333333333333333 #f1_score(overall_label, overall_pred, average="micro") np.testing.assert_almost_equal(microF1.get()[1], fmicro) np.testing.assert_almost_equal(macroF1.get()[1], fmacro) - + @xfail_when_nonstandard_decimal_separator def test_multilabel_f1(): microF1 = mx.gluon.metric.create("f1", class_type="multilabel", average="micro") @@ -183,7 +189,7 @@ def test_multilabel_f1(): macroF1.update([label], [pred]) microF1.update([label], [pred]) assert macroF1.get()[1] == 0.5 # one class is 1.0, the other is 0. (divided by 0) - np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3) + np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3) macroF1.reset() microF1.reset() @@ -209,7 +215,7 @@ def test_mcc(): microMCC = mx.gluon.metric.create("mcc") assert np.isnan(microMCC.get()[1]) - + # check divide by zero pred = mx.nd.array([[0.9, 0.1], [0.8, 0.2]])