Skip to content

Commit

Permalink
use MXNET_SAFE_ACCUMULATION for softmax accumulator (apache#15037)
Browse files Browse the repository at this point in the history
  • Loading branch information
szha authored and Rohit Kumar Srivastava committed May 22, 2019
1 parent 01bf170 commit 89a3852
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 39 deletions.
67 changes: 49 additions & 18 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,18 +410,34 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);

MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
if (shape.ndim() == 2) {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
if (safe_acc) {
if (shape.ndim() == 2) {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
} else {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
}
} else {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
if (shape.ndim() == 2) {
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
} else {
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
}
}
});
});
Expand All @@ -443,20 +459,35 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);

int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);

MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
if (safe_acc) {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
}
} else {
SoftmaxGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, DType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2, Req, negate, DType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
}
}
});
});
Expand Down
47 changes: 26 additions & 21 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3434,7 +3434,7 @@ def npy_layer_norm_grad(data, gamma, out_grad, axis, eps):
assert_almost_equal(exe.grad_dict['data'].asnumpy(), gt_data_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['gamma'].asnumpy(), gt_gamma_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['beta'].asnumpy(), gt_beta_grad, backward_check_eps, backward_check_eps)

# Test for grad_req = add
out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
init_data_grad = np.random.normal(0, 1, in_shape).astype(dtype)
Expand Down Expand Up @@ -4926,22 +4926,27 @@ def check_dtypes_almost_equal(op_name,
ref_grad_np = ref_input.grad.asnumpy()
assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)

check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
'float16', 'float32')
check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
'float16', 'float32', 'float32')
check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
'float32', 'float64')
check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
'float32', 'float64', 'float64')
import sys
is_windows = sys.platform.startswith('win')
enforce_safe_acc = os.environ.get("MXNET_SAFE_ACCUMULATION", "0")
if not is_windows or enforce_safe_acc == "1":
os.environ["MXNET_SAFE_ACCUMULATION"] = "1"
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
'float16', 'float32')
check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
'float16', 'float32', 'float32')
check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
'float32', 'float64')
check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
'float32', 'float64', 'float64')

@with_seed()
def test_pick():
Expand Down Expand Up @@ -6457,18 +6462,18 @@ def test_laop_5():
for n in range(1, 10):
# test batched and non-batched processing
for b in range(3):
shape = (n, n) if b == 0 else (b, n, n)
shape = (n, n) if b == 0 else (b, n, n)
data_in = np.random.uniform(1, 10, shape)
# test all legal offsets of the diagonal
for offs in range(1-n, n):
# test extraction of diagonal
for offs in range(1-n, n):
# test extraction of diagonal
test_diag = mx.sym.linalg.extractdiag(data, offset=offs)
res_diag = np.diagonal(data_in, offset=offs) if b==0 else np.diagonal(data_in, axis1=1, axis2=2, offset=offs)
check_symbolic_forward(test_diag, [data_in], [res_diag])
check_numeric_gradient(test_diag, [data_in])
# test generation of diagonal matrix
test_diag2 = mx.sym.linalg.makediag(data, offset=offs)
res_diag2 = None
res_diag2 = None
if b == 0:
res_diag2 = np.diagflat(res_diag, k=offs)
else:
Expand Down

0 comments on commit 89a3852

Please sign in to comment.