Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add to amp list
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Jun 16, 2019
1 parent f250b6a commit e91b389
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/mxnet/contrib/amp/lists/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@
'log_softmax',
'InstanceNorm',
'LayerNorm',
'GroupNorm',
'L2Normalization',
'LRN',
'SoftmaxActivation',
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,8 +1882,8 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=True)
check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std],
rtol=1e-2 if dtype == np.float16 else 1e-4,
atol=4e-3 if dtype == np.float16 else 1e-6, dtype=dtype)
rtol=1e-2 if dtype == np.float16 else 1e-3,
atol=5e-3 if dtype == np.float16 else 1e-5, dtype=dtype)
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=False)
np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
Expand All @@ -1896,7 +1896,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)],
[np_data_grad, np_gamma_grad, np_beta_grad],
rtol=1e-2 if dtype == np.float16 else 1e-3,
atol=2e-2 if dtype == np.float16 else 1e-5, dtype=dtype)
atol=5e-2 if dtype == np.float16 else 1e-5, dtype=dtype)


@with_seed()
Expand Down

0 comments on commit e91b389

Please sign in to comment.