Skip to content

Commit

Permalink
Reduce numerical error on numerical gradient calculations
Browse files Browse the repository at this point in the history
Fixes apache#11720
Overall will reduce flakiness of tests using numerical gradients
  • Loading branch information
larroy authored and anirudhacharya committed Oct 25, 2020
1 parent 3dda8e9 commit 693362b
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,18 +1088,6 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=None, rto
if ctx is None:
ctx = default_context()

def random_projection(shape):
"""Get a random weight matrix with not too small elements
Parameters
----------
shape : list or tuple
"""
# random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient
plain = np.random.rand(*shape) + 0.1
return plain

location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
location_npy = {k:v.asnumpy() for k, v in location.items()}
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
Expand All @@ -1126,14 +1114,13 @@ def random_projection(shape):
is_np_sym = bool(isinstance(sym, np_symbol))
if is_np_sym: # convert to np symbol for using element-wise multiplication
proj = proj.as_np_ndarray()
out = sym * proj
out = sym + proj
if is_np_sym: # convert to classic symbol so that make_loss can be used
out = out.as_nd_ndarray()
out = mx.sym.make_loss(out)

location = dict(list(location.items()) +
[("__random_proj", mx.nd.array(random_projection(out_shape[0]),
ctx=ctx, dtype=dtype))])
location = dict(list(location.items()) + [("__random_proj",
mx.nd.random.uniform(shape=out_shape[0], ctx=ctx, dtype=dtype))])
args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape))
for k in grad_nodes]
+ [("__random_proj", np.random.normal(0, 0.01, size=out_shape[0]))])
Expand Down

0 comments on commit 693362b

Please sign in to comment.