Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Mxnet allclose #14443

Merged
merged 58 commits into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
98962ba
MXNET version of numpy.allclose operation implemented
drivanov Mar 15, 2019
dedd3d1
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
drivanov Mar 15, 2019
f61c1a6
Helper function test_bulking renamed to _test_bulking
drivanov Mar 18, 2019
f679a09
Will use mx.test_utils.assert_allclose and not a numpy version of sim…
drivanov Mar 18, 2019
ad9028a
Trigger CI
drivanov Mar 21, 2019
53173fd
Trigger CI
drivanov Mar 22, 2019
387838b
Will use mx.test_utils.assert_allclose
drivanov Mar 25, 2019
ac56251
Trigger CI
drivanov Mar 26, 2019
052788e
Merge branch 'master' into mxnet_allclose
drivanov Apr 1, 2019
25233aa
Trigger CI
drivanov Apr 2, 2019
cc5668e
Merge branch 'mxnet_allclose' of https://github.com/drivanov/incubato…
drivanov Apr 2, 2019
e05c449
Merge branch 'master' into mxnet_allclose
drivanov Apr 2, 2019
eba2fc7
Problem with missed _test_bulking() function fixed
drivanov Apr 2, 2019
45fbd13
Fixing minor bug in error reporting
drivanov Apr 5, 2019
49a14d9
Trigger CI
drivanov Apr 5, 2019
6fb0129
Trigger CI
drivanov Apr 16, 2019
6209aa8
retrigger CI
wkcn Apr 21, 2019
dcb65a1
Fixing problems in discrepancies printout in assert_almost_equal
drivanov Apr 23, 2019
a5addfa
Merge branch 'mxnet_allclose' of https://github.com/drivanov/incubato…
drivanov Apr 23, 2019
a1fb4be
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
drivanov Apr 23, 2019
eb2b6ce
Trigger CI
drivanov Apr 24, 2019
448bbc5
Trigger CI
drivanov Apr 24, 2019
6dfec82
Improved version of MxNet allclose operator
drivanov Apr 24, 2019
31c674a
Fixing minor problem in attribite definition for allclose operator
drivanov Apr 24, 2019
d07a83c
retrigger CI
wkcn Apr 28, 2019
21e57fa
Minor problem fixed
drivanov Apr 30, 2019
3b1fa3e
Merge branch 'mxnet_allclose' of https://github.com/drivanov/incubato…
drivanov Apr 30, 2019
4ae0d3a
Merge branch 'master' into mxnet_allclose
wkcn May 2, 2019
fa00c6a
Merge branch 'master' into mxnet_allclose
wkcn May 22, 2019
3f2b68e
Trigger CI
drivanov May 22, 2019
fd5209a
Merge branch 'master' into mxnet_allclose
wkcn Jun 9, 2019
255f7a5
try to fix find_max_violation
wkcn Jun 9, 2019
dde0e71
Merge branch 'mxnet_allclose' of github.com:drivanov/incubator-mxnet …
wkcn Jun 9, 2019
f203d9a
Merge branch 'master' of https://github.com/drivanov/incubator-mxnet
drivanov Jun 10, 2019
6b63706
Trigger CI
drivanov Jun 10, 2019
6a95b07
Merge branch 'mxnet_allclose' of https://github.com/drivanov/incubato…
drivanov Jun 10, 2019
391a3fb
Skip 'test_bulking_gluon_gpu()' test
drivanov Jun 19, 2019
f1fe0ad
Merge branch 'master' of https://github.com/apache/incubator-mxnet
drivanov Jul 9, 2019
310f37e
Merge branch 'master' of https://github.com/drivanov/incubator-mxnet …
drivanov Jul 9, 2019
3e65e9c
Fixing bug in reporting MaxErrors for NaN coordinates.
drivanov Jul 9, 2019
1ffcb52
use smaller testcase for test_layer_norm
wkcn Jul 16, 2019
e1c9c72
remove redundant test for test_layer_norm
wkcn Jul 16, 2019
b52b68f
reuse old testcase
wkcn Jul 16, 2019
90b5c15
retrigger CI
wkcn Jul 16, 2019
0e20d5e
Merge branch 'master' into mxnet_allclose
wkcn Jul 25, 2019
f2b3cd3
ci
wkcn Jul 26, 2019
7f4661f
ci
wkcn Jul 28, 2019
302f8c8
Merge branch 'master' into mxnet_allclose
drivanov Aug 15, 2019
36a8db3
Merge problem fixed
drivanov Aug 15, 2019
733ae87
Fixing Python's lint problem
drivanov Aug 15, 2019
5523467
Merge branch 'master' into mxnet_allclose
drivanov Sep 3, 2019
5733e6a
Trigger CI
drivanov Sep 9, 2019
0e6e9c6
Merge branch 'mxnet_allclose' of https://github.com/drivanov/incubato…
drivanov Sep 9, 2019
a329e09
Trigger CI
drivanov Sep 11, 2019
368ae7c
Trigger CI
drivanov Sep 20, 2019
ac725e0
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
drivanov Oct 15, 2019
3fd78b8
Trigger CI
drivanov Oct 15, 2019
f4f6bfe
Trigger CI
drivanov Oct 15, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 135 additions & 34 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import bz2
import zipfile
from contextlib import contextmanager
from collections import OrderedDict
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
Expand Down Expand Up @@ -103,6 +104,16 @@ def random_sample(population, k):
return population_copy[0:k]


def _sorted_items(d):
"""Return (key, value) pairs of dict 'd' in a deterministic order (sorted by key)."""
return sorted(d.items(), key=lambda t: t[0])


def _sorted_dict(d):
"""Return ordered dictionary containing items ordered by their keys."""
return OrderedDict(_sorted_items(d))


def _validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform"):
"""Validates inputs for csr generation helper functions
Expand Down Expand Up @@ -461,9 +472,10 @@ def find_max_violation(a, b, rtol=None, atol=None):
"""Finds and returns the location of maximum violation."""
rtol = get_rtol(rtol)
atol = get_atol(atol)
diff = np.abs(a-b)
# 'smart' absdiff that considers inf's as equals (to match np.allclose)
absdiff = np.where(np.equal(a, b), 0, np.abs(a-b))
tol = atol + rtol*np.abs(b)
violation = diff/(tol+1e-20)
violation = absdiff/(tol+1e-20)
loc = np.argmax(violation)
idx = np.unravel_index(loc, violation.shape)
return idx, np.max(violation)
Expand All @@ -479,35 +491,106 @@ def same(a, b):
"""
return np.array_equal(a, b)


def almost_equal(a, b, rtol=None, atol=None, equal_nan=False):
"""Test if two numpy arrays are almost equal."""
# pylint: disable=unexpected-keyword-arg
return np.allclose(a, b, rtol=get_rtol(rtol), atol=get_atol(atol), equal_nan=equal_nan)
# pylint: enable=unexpected-keyword-arg

def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False):

def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False, mismatches=(10, 10)):
"""Test that two numpy arrays are almost equal. Raise exception message if not.

Parameters
----------
a : np.ndarray
b : np.ndarray
threshold : None or float
The checking threshold. Default threshold will be used if set to ``None``.
a : np.ndarray or mx.nd.array
b : np.ndarray or mx.nd.array
rtol : None or float
The relative threshold. Default threshold will be used if set to ``None``.
atol : None or float
The absolute threshold. Default threshold will be used if set to ``None``.
names : tuple of names, optional
The names used in error message when an exception occurs
equal_nan : boolean, optional
The flag determining how to treat NAN values in comparison
mismatches : tuple of mismatches
Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1])
"""
rtol = get_rtol(rtol)
atol = get_atol(atol)
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
if not use_np_allclose:
if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context == b.context and a.dtype == b.dtype):
use_np_allclose = True
if isinstance(a, mx.nd.NDArray):
a = a.asnumpy()
if isinstance(b, mx.nd.NDArray):
b = b.asnumpy()

if use_np_allclose:
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
if output.asnumpy() == 1:
wkcn marked this conversation as resolved.
Show resolved Hide resolved
return

a = a.asnumpy()
b = b.asnumpy()

def locationError(a, b, index, names, maxError=False):
"""Create element mismatch comment

Parameters
----------
a, b : compared np.ndarray's
index : tuple of coordinate arrays
Location of violation
names : tuple of names
The names of compared arrays.
maxError: boolean, optional
Flag indicating that maximum error is reporting.
"""
maximum = "maximum " if maxError else ""
return "Location of %serror: %s, %s=%.8f, %s=%.8f" \
% (maximum, str(index), names[0], a[index], names[1], b[index])

index, rel = find_max_violation(a, b, rtol, atol)
indexErr = index
relErr = rel
actual = a.copy()

print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol))
i = 1
while i < a.size:
if i <= mismatches[0]:
print("%3d: Error %f %s" %(i, rel, locationError(a, b, index, names)))

a[index] = b[index]
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
break

i += 1
if i <= mismatches[1] or mismatches[1] <= 0:
index, rel = find_max_violation(a, b, rtol, atol)
else:
break

mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \
(relErr, rtol, atol, mismatchDegree, 100*i/a.size, \
locationError(actual, b, indexErr, names, maxError=True))
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. "
" Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, str(index), a[index], b[index]),
names=names)
msg = npt.build_err_msg([actual, b], err_msg=errMsg)

raise AssertionError(msg)


def assert_allclose(a, b, rtol=1e-07, atol=0, equal_nan=True):
assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False):
"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.

Expand All @@ -528,7 +611,6 @@ def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('
equals = np.isclose(a, b, rtol=rtol, atol=atol)
err = 1 - np.count_nonzero(equals) / equals.size
if err > etol:
#if True:
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
Expand Down Expand Up @@ -658,7 +740,7 @@ def simple_forward(sym, ctx=None, is_train=False, **inputs):


def _parse_location(sym, location, ctx, dtype=default_dtype()):
"""Parses the given location to a dictionary.
"""Parses the given location to a ordered dictionary.

Arguments of the provided op `sym` are used as dictionary keys
and elements of `location` are used as values.
Expand Down Expand Up @@ -714,7 +796,7 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()):
location = {k: v for k, v in zip(sym.list_arguments(), location)}
location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
if isinstance(v, np.ndarray) else v for k, v in location.items()}
return location
return _sorted_dict(location)


def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
Expand Down Expand Up @@ -1143,7 +1225,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
if isinstance(expected, (list, tuple)):
expected = {k:v for k, v in zip(sym.list_arguments(), expected)}

args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()}
# Dirty the output buffer deterministically, for reproducibility.
args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in _sorted_items(expected)}
args_grad_data = {}
for k, v in args_grad_npy.items():
nd = mx.nd.array(v, ctx=ctx, dtype=dtype)
Expand Down Expand Up @@ -1279,6 +1362,15 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole",
else:
raise ValueError('typ can only be "whole" or "forward".')


def get_tolerance(rtol, ctx):
if 'atol' in ctx:
return ctx['atol']
if 'atol_mult' in ctx:
return ctx['atol_mult'] * rtol
return rtol


def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
arg_params=None, aux_params=None, tol=None,
raise_on_err=True, ground_truth=None, equal_nan=False,
Expand Down Expand Up @@ -1397,12 +1489,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
for i, exe in enumerate(exe_list):
if i == max_idx:
continue

rtol = tol[dtypes[i]]
atol = get_tolerance(rtol, ctx_list[i])
for name, arr in zip(output_names, exe.outputs):
gtarr = gt[name].astype(dtypes[i]).asnumpy()
arr = arr.asnumpy()
# Previously, the cast was to dtypes[i], but symbol may be mixed-precision,
# so casting the ground truth to the actual output type seems more correct.
gtarr = gt[name].astype(arr.dtype)
try:
assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]],
equal_nan=equal_nan)
assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan)
except AssertionError as e:
print('Predict Err: ctx %d vs ctx %d at %s'%(i, max_idx, name))
traceback.print_exc()
Expand All @@ -1420,16 +1515,20 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
for i, exe in enumerate(exe_list):
if i == max_idx:
continue

rtol = tol[dtypes[i]]
atol = get_tolerance(rtol, ctx_list[i])
curr = zip(output_names + arg_names, exe.outputs + exe.grad_arrays)
for name, arr in curr:
if gt[name] is None:
assert arr is None
continue
gtarr = gt[name].astype(dtypes[i]).asnumpy()
arr = arr.asnumpy()

# Previous cast was to dtypes[i], but symbol may be mixed-precision,
# so casting the ground truth to the actual output type seems more correct.
gtarr = gt[name].astype(arr.dtype)
try:
assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]],
equal_nan=equal_nan)
assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan)
except AssertionError as e:
print('Train Err: ctx %d vs ctx %d at %s'%(i, max_idx, name))
traceback.print_exc()
Expand Down Expand Up @@ -1594,7 +1693,7 @@ def get_mnist_iterator(batch_size, input_shape, num_parts=1, part_index=0):
"""

get_mnist_ubyte()
flat = False if len(input_shape) == 3 else True # pylint: disable=simplifiable-if-expression
flat = len(input_shape) != 3

train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
Expand Down Expand Up @@ -2034,12 +2133,14 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
"""Compare ndarray tuple."""
if t1 is not None and t2 is not None:
if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)
if t1 is None or t2 is None:
return

if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1, t2, rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
Expand Down Expand Up @@ -2071,7 +2172,7 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
assert_almost_equal(w1, w2, rtol=rtol, atol=atol)

class EnvManager(object):
"""Environment variable setter and unsetter via with idiom"""
Expand Down
Loading