
remove NLL in gluon.metric #18794

Merged
merged 1 commit on Jul 27, 2020
91 changes: 13 additions & 78 deletions python/mxnet/gluon/metric.py
@@ -1352,9 +1352,12 @@ class :math:`k`.
Index of invalid label to ignore when
counting. By default, set to -1.
If set to `None`, it will include all entries.
axis : int (default -1)
axis : int, default -1
The axis from prediction that was used to
compute softmax. By default use the last axis.
from_logits : boolean, default False
Whether `pred` is expected to be a logits tensor.
By default, we assume that `pred` encodes a probability distribution.
name : str
Name of this metric instance for display.
output_names : list of str, or None
@@ -1373,12 +1376,13 @@ class :math:`k`.
>>> print ce.get()
('cross-entropy', 0.57159948348999023)
"""
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='cross-entropy',
output_names=None, label_names=None):
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False,
name='cross-entropy', output_names=None, label_names=None):
super(CrossEntropy, self).__init__(
name, output_names=output_names, label_names=label_names)
self.ignore_label = ignore_label
self.axis = axis
self.from_logits = from_logits
self.eps = eps

def update(self, labels, preds):
@@ -1400,6 +1404,8 @@ def update(self, labels, preds):
assert label.size == pred.size/pred.shape[-1], \
"shape mismatch: %s vs. %s"%(label.shape, pred.shape)
label = label.reshape((label.size,))
if self.from_logits:
pred = ndarray.softmax(pred, axis=self.axis)
pred = ndarray.pick(pred.as_in_context(label.ctx), label.astype(dtype='int32'), axis=self.axis)
label = label.as_np_ndarray()
pred = pred.as_np_ndarray()
@@ -1469,11 +1475,11 @@ class Perplexity(CrossEntropy):
>>> print perp.get()
('Perplexity', 1.7710976285155853)
"""
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='perplexity',
output_names=None, label_names=None):
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False,
name='perplexity', output_names=None, label_names=None):
super(Perplexity, self).__init__(
name=name, eps=eps, ignore_label=ignore_label, axis=axis,
output_names=output_names, label_names=label_names)
eps=eps, ignore_label=ignore_label, axis=axis, from_logits=from_logits,
name=name, output_names=output_names, label_names=label_names)

def get(self):
if self.num_inst == 0:
@@ -1482,77 +1488,6 @@ def get(self):
return (self.name, math.exp(self.sum_metric/self.num_inst))


@register
@alias('nll_loss')
@use_np
class NegativeLogLikelihood(EvalMetric):
"""Computes the negative log-likelihood loss.

The negative log-likelihood loss over a batch of sample size :math:`N` is given by

.. math::
-\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}),

where :math:`K` is the number of classes, :math:`y_{nk}` is the predicted probability for
:math:`k`-th class for :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample
:math:`n` belongs to class :math:`k`.

Parameters
----------
eps : float
Negative log-likelihood loss is undefined when the predicted value is 0,
so a small constant is added to the predicted values.
name : str
Name of this metric instance for display.
output_names : list of str, or None
Name of predictions that should be used when updating with update_dict.
By default include all predictions.
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.

Examples
--------
>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels = [mx.nd.array([0, 1, 1])]
>>> nll_loss = mx.gluon.metric.NegativeLogLikelihood()
>>> nll_loss.update(labels, predicts)
>>> print nll_loss.get()
('nll-loss', 0.57159948348999023)
"""
def __init__(self, eps=1e-12, name='nll-loss',
output_names=None, label_names=None):
super(NegativeLogLikelihood, self).__init__(
name, eps=eps,
output_names=output_names, label_names=label_names)
self.eps = eps

def update(self, labels, preds):
"""Updates the internal evaluation result.

Parameters
----------
labels : list of `NDArray`
The labels of the data.

preds : list of `NDArray`
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.as_np_ndarray()
pred = pred.as_np_ndarray().as_in_ctx(label.ctx)

label = label.reshape(-1)
num_examples = pred.shape[0]
assert label.shape[0] == num_examples, (label.shape[0], num_examples)
prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)]
nll = (-numpy.log(prob + self.eps)).sum()
self.sum_metric += nll
self.num_inst += num_examples


@register
@alias('pearsonr')
@use_np
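With `NegativeLogLikelihood` removed, the same quantity is reported by `CrossEntropy` (alias `ce`), which this diff also teaches to accept raw scores via `from_logits`. A minimal sketch of the replacement usage, assuming the post-merge `mxnet.gluon.metric` API and the NDArray inputs from the docstring examples above:

```python
import mxnet as mx

# Per-class probabilities (rows sum to 1) and integer class labels,
# matching the CrossEntropy docstring example above.
pred = mx.nd.array([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
label = mx.nd.array([0, 1, 1])

# Former `nll_loss` usage maps directly onto the cross-entropy metric.
ce = mx.gluon.metric.CrossEntropy()       # or: mx.gluon.metric.create('ce')
ce.update([label], [pred])
print(ce.get())                           # ('cross-entropy', ~0.5716), as in the docstring

# With raw scores instead of probabilities, pass from_logits=True and the
# metric applies softmax along `axis` before picking the labelled class.
logits = mx.nd.log(pred + 1e-12)          # log-probabilities keep the value (up to float error)
ce_logits = mx.gluon.metric.CrossEntropy(from_logits=True)
ce_logits.update([label], [logits])
print(ce_logits.get())                    # matches the probability path
```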
20 changes: 13 additions & 7 deletions tests/python/unittest/test_metric.py
@@ -39,19 +39,25 @@ def test_metrics():
check_metric('perplexity', axis=-1)
check_metric('pearsonr')
check_metric('pcc')
check_metric('nll_loss')
check_metric('ce')
check_metric('loss')
composite = mx.gluon.metric.create(['acc', 'f1'])
check_metric(composite)

def test_nll_loss():
metric = mx.gluon.metric.create('nll_loss')
def test_ce():
metric = mx.gluon.metric.create('ce')
pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]])
label = mx.nd.array([2, 1])
metric.update([label], [pred])
_, loss = metric.get()
expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2
assert loss == expected_loss
metric = mx.gluon.metric.create('ce', from_logits=True)
pred = mx.nd.log(pred)
metric.update([label], [pred])
_, loss = metric.get()
np.testing.assert_almost_equal(loss, expected_loss)


def test_acc():
pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
@@ -159,15 +165,15 @@ def test_multiclass_f1():
macroF1.update([label11, label12], [pred11, pred12])
assert microF1.num_inst == 6
assert macroF1.num_inst == 6

# from sklearn.metrics import f1_score
# overall_pred = [0, 1, 2, 0, 1, 2]
# overall_label = [0, 2, 1, 0, 0, 1]
fmacro = 0.26666666666666666 #f1_score(overall_label, overall_pred, average="macro")
fmicro = 0.3333333333333333 #f1_score(overall_label, overall_pred, average="micro")
np.testing.assert_almost_equal(microF1.get()[1], fmicro)
np.testing.assert_almost_equal(macroF1.get()[1], fmacro)

@xfail_when_nonstandard_decimal_separator
def test_multilabel_f1():
microF1 = mx.gluon.metric.create("f1", class_type="multilabel", average="micro")
@@ -183,7 +189,7 @@ def test_multilabel_f1():
macroF1.update([label], [pred])
microF1.update([label], [pred])
assert macroF1.get()[1] == 0.5 # one class is 1.0, the other is 0. (divided by 0)
np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3)
np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3)
macroF1.reset()
microF1.reset()

@@ -209,7 +215,7 @@ def test_mcc():
microMCC = mx.gluon.metric.create("mcc")

assert np.isnan(microMCC.get()[1])

# check divide by zero
pred = mx.nd.array([[0.9, 0.1],
[0.8, 0.2]])
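`Perplexity` forwards `from_logits` to `CrossEntropy` unchanged, and its `get()` simply exponentiates the running cross-entropy, so a quick sanity check looks like this (a sketch under the same API assumptions as the example above):

```python
import math
import mxnet as mx

pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]])
label = mx.nd.array([2, 1])

ce = mx.gluon.metric.create('ce')
perp = mx.gluon.metric.create('perplexity')
ce.update([label], [pred])
perp.update([label], [pred])

# Perplexity is exp(cross-entropy) over the same samples.
assert math.isclose(perp.get()[1], math.exp(ce.get()[1]), rel_tol=1e-5)

# The from_logits path agrees when the logits are log-probabilities.
perp_logits = mx.gluon.metric.create('perplexity', from_logits=True)
perp_logits.update([label], [mx.nd.log(pred)])
assert math.isclose(perp_logits.get()[1], perp.get()[1], rel_tol=1e-5)
```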