From f477cfa6bead62f80d11e7539153fc48230c8ace Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 18 Sep 2019 19:46:09 +0530 Subject: [PATCH] assert_allclose -> rtol=1e-10 --- tests/python/unittest/test_ndarray.py | 33 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 7091abf7308a..7be59df6efda 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -21,6 +21,7 @@ from itertools import permutations, combinations_with_replacement import os import pickle as pkl +import functools from nose.tools import assert_raises, raises from common import with_seed, assertRaises, TemporaryDirectory from mxnet.test_utils import almost_equal @@ -1887,14 +1888,16 @@ def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_excep check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False) -@with_seed() -def test_update_ops_mutation(): - def assert_mutate(x, y, op): +def _test_update_ops_mutation_impl(): + assert_allclose = functools.partial( + np.testing.assert_allclose, rtol=1e-10) + + def assert_mutate(x, y): np.testing.assert_raises( - AssertionError, np.testing.assert_allclose, x, y) + AssertionError, assert_allclose, x, y) - def assert_unchanged(x, y, op): - np.testing.assert_allclose(x, y) + def assert_unchanged(x, y): + assert_allclose(x, y) def test_op(op, num_inputs, mutated_inputs, **kwargs): for dim in range(1, 7): @@ -1919,9 +1922,9 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs): for idx, (pre_array, post_array) in \ enumerate(zip(pre_arrays, post_arrays)): if idx in mutated_inputs: - assert_mutate(pre_array, post_array, op) + assert_mutate(pre_array, post_array) else: - assert_unchanged(pre_array, post_array, op) + assert_unchanged(pre_array, post_array) test_op(mx.nd.signsgd_update, 2, [0], ** {'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3, @@ -1952,6 +1955,20 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs): {'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3}) +@with_seed() +def test_update_ops_mutation(): + _test_update_ops_mutation_impl() + + +# Problem : +# https://github.com/apache/incubator-mxnet/pull/15768#issuecomment-532046408 +@with_seed(412298777) +def test_update_ops_mutation_failed_seed(): + # The difference was -5.9604645e-08 which was + # lower than then `rtol` of 1e-07 + _test_update_ops_mutation_impl() + + def test_large_int_rounding(): large_integer = 50000001