From 492be4585c9a9c27114a1b03f0f873b20463dc53 Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Wed, 22 May 2019 11:52:17 -0700
Subject: [PATCH] use MXNET_SAFE_ACCUMULATION for softmax accumulator (#15037)

---
 src/operator/nn/softmax-inl.h          | 67 +++++++++++++++++++-------
 tests/python/unittest/test_operator.py | 47 ++++++++++--------
 2 files changed, 75 insertions(+), 39 deletions(-)

diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 096d87416081..1910ff49eb11 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -410,18 +410,34 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
                              param.temperature.value() : 1.0;
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      if (shape.ndim() == 2) {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<2>(), axis,
-            static_cast<DType>(temperature));
+      if (safe_acc) {
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
       } else {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<3>(), axis,
-            static_cast<DType>(temperature));
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
       }
     });
   });
@@ -443,20 +459,35 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
 
   int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
 
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (shape.ndim() == 2) {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<2>(), axis, static_cast<DType>(temperature));
+        if (safe_acc) {
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
         } else {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<3>(), axis, static_cast<DType>(temperature));
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
         }
       });
     });
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b5aa06964b29..324e8a3d7ed9 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3434,7 +3434,7 @@ def npy_layer_norm_grad(data, gamma, out_grad, axis, eps):
         assert_almost_equal(exe.grad_dict['data'].asnumpy(), gt_data_grad, backward_check_eps, backward_check_eps)
         assert_almost_equal(exe.grad_dict['gamma'].asnumpy(), gt_gamma_grad, backward_check_eps, backward_check_eps)
         assert_almost_equal(exe.grad_dict['beta'].asnumpy(), gt_beta_grad, backward_check_eps, backward_check_eps)
-    
+
     # Test for grad_req = add
     out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
     init_data_grad = np.random.normal(0, 1, in_shape).astype(dtype)
@@ -4926,22 +4926,27 @@ def check_dtypes_almost_equal(op_name,
         ref_grad_np = ref_input.grad.asnumpy()
         assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)
 
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
-    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                              'float16', 'float32')
-    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                              'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                              'float32', 'float64')
-    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                              'float32', 'float64', 'float64')
+    import sys
+    is_windows = sys.platform.startswith('win')
+    enforce_safe_acc = os.environ.get("MXNET_SAFE_ACCUMULATION", "0")
+    if not is_windows or enforce_safe_acc == "1":
os.environ["MXNET_SAFE_ACCUMULATION"] = "1" + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64', 'float64') @with_seed() def test_pick(): @@ -6457,18 +6462,18 @@ def test_laop_5(): for n in range(1, 10): # test batched and non-batched processing for b in range(3): - shape = (n, n) if b == 0 else (b, n, n) + shape = (n, n) if b == 0 else (b, n, n) data_in = np.random.uniform(1, 10, shape) # test all legal offsets of the diagonal - for offs in range(1-n, n): - # test extraction of diagonal + for offs in range(1-n, n): + # test extraction of diagonal test_diag = mx.sym.linalg.extractdiag(data, offset=offs) res_diag = np.diagonal(data_in, offset=offs) if b==0 else np.diagonal(data_in, axis1=1, axis2=2, offset=offs) check_symbolic_forward(test_diag, [data_in], [res_diag]) check_numeric_gradient(test_diag, [data_in]) # test generation of diagonal matrix test_diag2 = mx.sym.linalg.makediag(data, offset=offs) - res_diag2 = None + res_diag2 = None if b == 0: res_diag2 = np.diagflat(res_diag, k=offs) else: