diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index d8dca753bda4..a70ffdb88987 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -18,6 +18,7 @@ import mxnet as mx import numpy as np import json +import math from common import with_seed from copy import deepcopy @@ -56,6 +57,21 @@ def _create_pred_label(): label[:shape[0] // 2] = 0 return pred, label + def _compare_metric_result(m1, m2): + # Compare names + assert m1[0] == m2[0] + # Compare values + if isinstance(m1[1], (list, tuple)): + assert len(m1[1]) == len(m2[1]) + for r1, r2 in zip(m1[1], m2[1]): + assert r1 == r2 or \ + (math.isnan(r1) and + math.isnan(r2)) + else: + assert m1[1] == m2[1] or \ + (math.isnan(m1[1]) and + math.isnan(m2[1])) + shape = kwargs.pop('shape', (10,10)) use_same_shape = kwargs.pop('use_same_shape', False) m1 = mx.metric.create(metric, *args, **kwargs) @@ -78,7 +94,7 @@ def _create_pred_label(): pred, label = _create_pred_label() m1.update([label], [pred]) m2.update([label], [pred]) - assert m1.get() == m2.get() + _compare_metric_result(m1.get(), m2.get()) @with_seed() def test_global_metric():