Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[fix] Update test_update_ops_mutation tolerance #16198

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from itertools import permutations, combinations_with_replacement
import os
import pickle as pkl
import functools
from nose.tools import assert_raises, raises
from common import with_seed, assertRaises, TemporaryDirectory
from mxnet.test_utils import almost_equal
Expand Down Expand Up @@ -1887,14 +1888,16 @@ def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_excep
check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)


@with_seed()
def test_update_ops_mutation():
def assert_mutate(x, y, op):
def _test_update_ops_mutation_impl():
assert_allclose = functools.partial(
np.testing.assert_allclose, rtol=1e-10)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you making it more sensitive. Will that cause more failures going forward?
I think you need different comparison function which says check for equality upto this precision -
Something like - assert_array_almost_equal

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sxjscience @kshitij12345 sorry there was race between I writing this and Xingjian merging the change. Can you please look at this comment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to lower the rtol because sometimes it fails the assert_raise check

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC , with the current check is if difference is less than 1e-10 , test will fail. Is this correct ?
I think it should be I don't care about 7th(or 8th precision or anything above) precision, it can go arbitrary lower than what I care about.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are assert_raise, which means that if the difference is larger than rtol, assert_allclose will fail and it will raise the AssertionError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Vikas-kum @sxjscience Thanks for taking a look.

The new tolerance of assert_allclose (numpy doc mentions to use assert_allclose over assert_almost_equal) is now 1e-10, which means that it is more stricter now when checking for NDArray not being changed.

A consequence of this as it is used in assert_mutated is, it also allows smaller updates, i.e. -5.9604645e-08 (failure case update) to be valid change .

Thus now we have stricter check on whether any element of NDArray has changed or not (with stricter tolerance) while allowing smaller valid updates (i.e. anything above rtol of 1e-10 will be a valid update).

I think it should be I don't care about 7th(or 8th precision or anything above) precision, it can go arbitrary lower than what I care about.

(I don't know much about this but) With the instability of floating point number (rounding errors), I guess it is better to have some tolerance like 1e-10. I don't know if it is good idea to go to Machine Epsilon (lowest possible) as rtol for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kshitij12345 I think it should be safe to remove the test_mutate check.


def assert_mutate(x, y):
np.testing.assert_raises(
AssertionError, np.testing.assert_allclose, x, y)
AssertionError, assert_allclose, x, y)

def assert_unchanged(x, y, op):
np.testing.assert_allclose(x, y)
def assert_unchanged(x, y):
assert_allclose(x, y)

def test_op(op, num_inputs, mutated_inputs, **kwargs):
for dim in range(1, 7):
Expand All @@ -1919,9 +1922,9 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
for idx, (pre_array, post_array) in \
enumerate(zip(pre_arrays, post_arrays)):
if idx in mutated_inputs:
assert_mutate(pre_array, post_array, op)
assert_mutate(pre_array, post_array)
else:
assert_unchanged(pre_array, post_array, op)
assert_unchanged(pre_array, post_array)

test_op(mx.nd.signsgd_update, 2, [0], **
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3,
Expand Down Expand Up @@ -1952,6 +1955,20 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3})


@with_seed()
def test_update_ops_mutation():
_test_update_ops_mutation_impl()


# Problem :
# https://github.com/apache/incubator-mxnet/pull/15768#issuecomment-532046408
@with_seed(412298777)
def test_update_ops_mutation_failed_seed():
# The difference was -5.9604645e-08 which was
# lower than then `rtol` of 1e-07
_test_update_ops_mutation_impl()


def test_large_int_rounding():
large_integer = 50000001

Expand Down