Skip to content

Commit

Permalink
Fix flaky test test_global_metric (apache#15756)
Browse files Browse the repository at this point in the history
* Fix flaky test test_global_metric

* Retrigger CI

* retrigger CI

* retrigger CI

* retrigger CI
  • Loading branch information
ptrendx authored and Ubuntu committed Aug 20, 2019
1 parent 1906f88 commit 90e3d3f
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mxnet as mx
import numpy as np
import json
import math
from common import with_seed
from copy import deepcopy

Expand Down Expand Up @@ -56,6 +57,21 @@ def _create_pred_label():
label[:shape[0] // 2] = 0
return pred, label

def _compare_metric_result(m1, m2):
# Compare names
assert m1[0] == m2[0]
# Compare values
if isinstance(m1[1], (list, tuple)):
assert len(m1[1]) == len(m2[1])
for r1, r2 in zip(m1[1], m2[1]):
assert r1 == r2 or \
(math.isnan(r1) and
math.isnan(r2))
else:
assert m1[1] == m2[1] or \
(math.isnan(m1[1]) and
math.isnan(m2[1]))

shape = kwargs.pop('shape', (10,10))
use_same_shape = kwargs.pop('use_same_shape', False)
m1 = mx.metric.create(metric, *args, **kwargs)
Expand All @@ -78,7 +94,7 @@ def _create_pred_label():
pred, label = _create_pred_label()
m1.update([label], [pred])
m2.update([label], [pred])
assert m1.get() == m2.get()
_compare_metric_result(m1.get(), m2.get())

@with_seed()
def test_global_metric():
Expand Down

0 comments on commit 90e3d3f

Please sign in to comment.