Mxnet allclose (apache#14443)
* MXNET version of numpy.allclose operation implemented

* Helper function test_bulking renamed to _test_bulking

* Will use mx.test_utils.assert_allclose and not a numpy version of similar function

* Trigger CI

* Trigger CI

* Will use mx.test_utils.assert_allclose

* Trigger CI

* Trigger CI

* Problem with missed _test_bulking() function fixed

* Fixing minor bug in error reporting

* Trigger CI

* Trigger CI

* retrigger CI

* Fixing problems in discrepancies printout in assert_almost_equal

* Trigger CI

* Trigger CI

* Improved version of MxNet allclose operator

* Fixing minor problem in attribute definition for allclose operator

* retrigger CI

* Minor problem fixed

* Trigger CI

* try to fix find_max_violation

* Trigger CI

* Skip 'test_bulking_gluon_gpu()' test

* Fixing bug in reporting MaxErrors for NaN coordinates.

* use smaller testcase for test_layer_norm

* remove redundant test for test_layer_norm

* reuse old testcase

* retrigger CI

* ci

* ci

* Merge problem fixed

* Fixing Python's lint problem

* Trigger CI

* Trigger CI

* Trigger CI

* Trigger CI

* Trigger CI
drivanov authored and aaronmarkham committed Oct 16, 2019
1 parent 9c90d60 commit cc89586
Showing 16 changed files with 938 additions and 546 deletions.
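The headline change is a native allclose operator (first commit message above), exposed as mx.nd.contrib.allclose and called from assert_almost_equal in the diff below. A minimal usage sketch, assuming the argument order shown at that call site; the input arrays are illustrative only:

import mxnet as mx

# Returns a single-element NDArray: 1 when every pair satisfies
# |a - b| <= atol + rtol * |b| (NaNs compare equal when equal_nan=True), else 0.
a = mx.nd.array([1.0, 2.0, float('nan')])
b = mx.nd.array([1.0, 2.0 + 1e-9, float('nan')])
ok = mx.nd.contrib.allclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=True)
print(ok.asnumpy())  # [1.]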
188 changes: 150 additions & 38 deletions python/mxnet/test_utils.py
@@ -31,6 +31,7 @@
import zipfile
import json
from contextlib import contextmanager
from collections import OrderedDict
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
@@ -109,6 +110,16 @@ def random_sample(population, k):
return population_copy[0:k]


def _sorted_items(d):
"""Return (key, value) pairs of dict 'd' in a deterministic order (sorted by key)."""
return sorted(d.items(), key=lambda t: t[0])


def _sorted_dict(d):
"""Return ordered dictionary containing items ordered by their keys."""
return OrderedDict(_sorted_items(d))
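A tiny illustration of the new ordering helpers with a made-up dict; the point is a deterministic, key-sorted iteration order, so code that draws random numbers per key (see the check_symbolic_backward change below) does so in a reproducible sequence:

from mxnet.test_utils import _sorted_items, _sorted_dict

d = {'weight': 1, 'bias': 2, 'gamma': 3}
print(_sorted_items(d))  # [('bias', 2), ('gamma', 3), ('weight', 1)]
print(_sorted_dict(d))   # OrderedDict([('bias', 2), ('gamma', 3), ('weight', 1)])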


def _validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform"):
"""Validates inputs for csr generation helper functions
@@ -482,9 +493,10 @@ def find_max_violation(a, b, rtol=None, atol=None):
"""Finds and returns the location of maximum violation."""
rtol = get_rtol(rtol)
atol = get_atol(atol)
diff = np.abs(a-b)
# 'smart' absdiff that considers inf's as equals (to match np.allclose)
absdiff = np.where(np.equal(a, b), 0, np.abs(a-b))
tol = atol + rtol*np.abs(b)
violation = diff/(tol+1e-20)
violation = absdiff/(tol+1e-20)
loc = np.argmax(violation)
idx = np.unravel_index(loc, violation.shape)
return idx, np.max(violation)
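The rewritten difference above matters because subtracting two equal infinities yields NaN, so the old diff-based check flagged equal infs as the maximum violation even though np.allclose treats them as equal. A small NumPy illustration (not part of the patch):

import numpy as np

a = np.array([1.0, np.inf])
b = np.array([1.0, np.inf])
with np.errstate(invalid='ignore'):                      # inf - inf warns about an invalid value
    plain = np.abs(a - b)                                # [0., nan]
    smart = np.where(np.equal(a, b), 0, np.abs(a - b))   # [0., 0.]: equal infs count as equal
print(plain, smart, np.allclose(a, b))                   # [0. nan] [0. 0.] True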
@@ -500,40 +512,122 @@ def same(a, b):
"""
return np.array_equal(a, b)

def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True):
"""Test if two numpy arrays are almost equal."""
# pylint: disable=unexpected-keyword-arg
if (not use_broadcast) and a.shape != b.shape:

def checkShapes(a, b):
if a.shape != b.shape:
msg = npt.build_err_msg([a, b],
err_msg="a.shape = {} and b.shape = {} are not equal"
.format(str(a.shape), str(b.shape)))
raise AssertionError(msg)


def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True):
"""Test if two numpy arrays are almost equal."""
# pylint: disable=unexpected-keyword-arg
if not use_broadcast:
checkShapes(a, b)

return np.allclose(a, b, rtol=get_rtol(rtol), atol=get_atol(atol), equal_nan=equal_nan)
# pylint: enable=unexpected-keyword-arg
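A quick sketch of the use_broadcast switch with made-up arrays: by default the comparison broadcasts shapes as np.allclose does, while use_broadcast=False now rejects mismatched shapes up front through checkShapes:

import numpy as np
from mxnet.test_utils import almost_equal

x = np.zeros((2, 3))
y = np.zeros((1, 3))
print(almost_equal(x, y))                    # True: (1, 3) broadcasts against (2, 3)
try:
    almost_equal(x, y, use_broadcast=False)
except AssertionError as err:                # a.shape = (2, 3) and b.shape = (1, 3) are not equal
    print(err)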

def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False, use_broadcast=True):

def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False,
use_broadcast=True, mismatches=(10, 10)):
"""Test that two numpy arrays are almost equal. Raise exception message if not.
Parameters
----------
a : np.ndarray
b : np.ndarray
threshold : None or float
The checking threshold. Default threshold will be used if set to ``None``.
a : np.ndarray or mx.nd.array
b : np.ndarray or mx.nd.array
rtol : None or float
The relative threshold. Default threshold will be used if set to ``None``.
atol : None or float
The absolute threshold. Default threshold will be used if set to ``None``.
names : tuple of names, optional
The names used in error message when an exception occurs
equal_nan : boolean, optional
The flag determining how to treat NAN values in comparison
mismatches : tuple of mismatches
Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1])
"""
if not use_broadcast:
checkShapes(a, b)

rtol = get_rtol(rtol)
atol = get_atol(atol)
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan, use_broadcast=use_broadcast):
return
use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
if not use_np_allclose:
if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context == b.context and a.dtype == b.dtype):
use_np_allclose = True
if isinstance(a, mx.nd.NDArray):
a = a.asnumpy()
if isinstance(b, mx.nd.NDArray):
b = b.asnumpy()

if use_np_allclose:
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
if output.asnumpy() == 1:
return

a = a.asnumpy()
b = b.asnumpy()

def locationError(a, b, index, names, maxError=False):
"""Create element mismatch comment
Parameters
----------
a, b : compared np.ndarray's
index : tuple of coordinate arrays
Location of violation
names : tuple of names
The names of compared arrays.
maxError: boolean, optional
Flag indicating that maximum error is reporting.
"""
maximum = "maximum " if maxError else ""
return "Location of %serror: %s, %s=%.8f, %s=%.8f" \
% (maximum, str(index), names[0], a[index], names[1], b[index])

index, rel = find_max_violation(a, b, rtol, atol)
indexErr = index
relErr = rel

print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol))
aTmp = a.copy()
bTmp = b.copy()
i = 1
while i <= a.size:
if i <= mismatches[0]:
print("%3d: Error %f %s" %(i, rel, locationError(a, b, index, names)))

aTmp[index] = bTmp[index] = 0
if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
break

i += 1
if i <= mismatches[1] or mismatches[1] <= 0:
index, rel = find_max_violation(aTmp, bTmp, rtol, atol)
else:
break

mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \
(relErr, rtol, atol, mismatchDegree, 100*i/a.size, \
locationError(a, b, indexErr, names, maxError=True))
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. "
" Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, str(index), a[index], b[index]),
names=names)
msg = npt.build_err_msg([a, b], err_msg=errMsg)

raise AssertionError(msg)
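A usage sketch of the extended assert_almost_equal with hypothetical inputs: the new mismatches tuple caps how many offending locations are printed (first element) and how many are located at all before the report falls back to an "at least" percentage (second element):

import numpy as np
from mxnet.test_utils import assert_almost_equal

a = np.arange(6, dtype=np.float64)
b = a.copy()
b[2] += 1.0                                  # introduce a single large discrepancy
try:
    assert_almost_equal(a, b, rtol=1e-5, atol=1e-8, names=('x', 'y'),
                        mismatches=(5, 20))
except AssertionError as err:
    print(err)                               # reports the maximum error at index (2,)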


def assert_allclose(a, b, rtol=1e-07, atol=0, equal_nan=True):
assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False):
"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.
@@ -554,7 +648,6 @@ def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('
equals = np.isclose(a, b, rtol=rtol, atol=atol)
err = 1 - np.count_nonzero(equals) / equals.size
if err > etol:
#if True:
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
Expand Down Expand Up @@ -684,7 +777,7 @@ def simple_forward(sym, ctx=None, is_train=False, **inputs):


def _parse_location(sym, location, ctx, dtype=default_dtype()):
"""Parses the given location to a dictionary.
"""Parses the given location to a ordered dictionary.
Arguments of the provided op `sym` are used as dictionary keys
and elements of `location` are used as values.
@@ -740,7 +833,7 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()):
location = {k: v for k, v in zip(sym.list_arguments(), location)}
location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
if isinstance(v, np.ndarray) else v for k, v in location.items()}
return location
return _sorted_dict(location)


def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
@@ -1177,7 +1270,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
if isinstance(expected, (list, tuple)):
expected = {k:v for k, v in zip(sym.list_arguments(), expected)}

args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()}
# Dirty the output buffer deterministically, for reproducibility.
args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in _sorted_items(expected)}
args_grad_data = {}
for k, v in args_grad_npy.items():
nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype)
@@ -1313,6 +1407,15 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole",
else:
raise ValueError('typ can only be "whole" or "forward".')


def get_tolerance(rtol, ctx):
if 'atol' in ctx:
return ctx['atol']
if 'atol_mult' in ctx:
return ctx['atol_mult'] * rtol
return rtol
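The new get_tolerance helper lets each entry of check_consistency's ctx_list carry its own absolute tolerance, either directly ('atol') or as a multiple of the per-dtype rtol ('atol_mult'). A hypothetical sketch of such entries (keys other than 'atol'/'atol_mult' are only illustrative):

import numpy as np
import mxnet as mx
from mxnet.test_utils import get_tolerance

rtol = 1e-2
ctx_fixed = {'ctx': mx.cpu(0), 'type_dict': {'data': np.float32}, 'atol': 1e-3}
ctx_scaled = {'ctx': mx.gpu(0), 'type_dict': {'data': np.float16}, 'atol_mult': 2.0}
ctx_plain = {'ctx': mx.cpu(0), 'type_dict': {'data': np.float64}}

print(get_tolerance(rtol, ctx_fixed))   # 0.001 -> explicit per-context atol
print(get_tolerance(rtol, ctx_scaled))  # 0.02  -> atol_mult * rtol
print(get_tolerance(rtol, ctx_plain))   # 0.01  -> falls back to rtol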


def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
arg_params=None, aux_params=None, tol=None,
raise_on_err=True, ground_truth=None, equal_nan=False,
@@ -1431,12 +1534,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
for i, exe in enumerate(exe_list):
if i == max_idx:
continue

rtol = tol[dtypes[i]]
atol = get_tolerance(rtol, ctx_list[i])
for name, arr in zip(output_names, exe.outputs):
gtarr = gt[name].astype(dtypes[i]).asnumpy()
arr = arr.asnumpy()
# Previously, the cast was to dtypes[i], but symbol may be mixed-precision,
# so casting the ground truth to the actual output type seems more correct.
gtarr = gt[name].astype(arr.dtype)
try:
assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]],
equal_nan=equal_nan)
assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan)
except AssertionError as e:
print('Predict Err: ctx %d vs ctx %d at %s'%(i, max_idx, name))
traceback.print_exc()
@@ -1454,16 +1560,20 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
for i, exe in enumerate(exe_list):
if i == max_idx:
continue

rtol = tol[dtypes[i]]
atol = get_tolerance(rtol, ctx_list[i])
curr = zip(output_names + arg_names, exe.outputs + exe.grad_arrays)
for name, arr in curr:
if gt[name] is None:
assert arr is None
continue
gtarr = gt[name].astype(dtypes[i]).asnumpy()
arr = arr.asnumpy()

# Previous cast was to dtypes[i], but symbol may be mixed-precision,
# so casting the ground truth to the actual output type seems more correct.
gtarr = gt[name].astype(arr.dtype)
try:
assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]],
equal_nan=equal_nan)
assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan)
except AssertionError as e:
print('Train Err: ctx %d vs ctx %d at %s'%(i, max_idx, name))
traceback.print_exc()
Expand Down Expand Up @@ -1694,7 +1804,7 @@ def get_mnist_iterator(batch_size, input_shape, num_parts=1, part_index=0):
"""

get_mnist_ubyte()
flat = not bool(len(input_shape) == 3)
flat = len(input_shape) != 3

train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
@@ -2134,12 +2244,14 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
"""Compare ndarray tuple."""
if t1 is not None and t2 is not None:
if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)
if t1 is None or t2 is None:
return

if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1, t2, rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
Expand Down Expand Up @@ -2171,7 +2283,7 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
assert_almost_equal(w1, w2, rtol=rtol, atol=atol)


def same_symbol_structure(sym1, sym2):
(Diffs for the remaining 15 changed files are not shown.)
