From 794e919033c2211913eced06e6cfe598cbc77a30 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 21 May 2019 22:48:07 +0000 Subject: [PATCH] add to amp list --- python/mxnet/contrib/amp/lists/symbol.py | 1 + tests/python/unittest/test_operator.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index 9c99340ab75d..c1333e01bbf9 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -471,6 +471,7 @@ 'log_softmax', 'InstanceNorm', 'LayerNorm', + 'GroupNorm', 'L2Normalization', 'LRN', 'SoftmaxActivation', diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c4911913a3af..204d2977cb97 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1882,8 +1882,8 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, num_groups=num_groups, eps=eps, output_mean_var=True) check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], - rtol=1e-2 if dtype == np.float16 else 1e-4, - atol=4e-3 if dtype == np.float16 else 1e-6, dtype=dtype) + rtol=1e-2 if dtype == np.float16 else 1e-3, + atol=5e-3 if dtype == np.float16 else 1e-5, dtype=dtype) mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, num_groups=num_groups, eps=eps, output_mean_var=False) np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype) @@ -1896,7 +1896,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)], [np_data_grad, np_gamma_grad, np_beta_grad], rtol=1e-2 if dtype == np.float16 else 1e-3, - atol=2e-2 if dtype == np.float16 else 1e-5, dtype=dtype) + atol=5e-2 if dtype == np.float16 else 1e-5, dtype=dtype) @with_seed()