This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

use MXNET_SAFE_ACCUMULATION for softmax accumulator #15037

Merged
merged 1 commit on May 22, 2019
67 changes: 49 additions & 18 deletions src/operator/nn/softmax-inl.h
@@ -410,18 +410,34 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
  const double temperature = param.temperature.has_value() ?
                             param.temperature.value() : 1.0;
  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
+
  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      if (shape.ndim() == 2) {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<2>(), axis,
-            static_cast<DType>(temperature));
+      if (safe_acc) {
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
      } else {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<3>(), axis,
-            static_cast<DType>(temperature));
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
      }
    });
  });
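
With this change, SoftmaxCompute reads MXNET_SAFE_ACCUMULATION at call time and only widens the accumulator to AType (e.g. float32 for float16 inputs) when the flag is set; otherwise the reduction accumulates in the input type DType. A minimal usage sketch, not part of this PR (the shape, tolerances, and random data are illustrative only):

# Sketch: enable MXNET_SAFE_ACCUMULATION so a float16 softmax accumulates its
# row sums in float32, then compare against a float64 NumPy reference.
import os
import numpy as np
import mxnet as mx

os.environ["MXNET_SAFE_ACCUMULATION"] = "1"   # read via dmlc::GetEnv on each call

x = mx.nd.array(np.random.normal(0, 5, (8, 4096)), dtype='float16')
y = mx.nd.softmax(x, axis=-1)

x64 = x.asnumpy().astype('float64')
ref = np.exp(x64 - x64.max(axis=-1, keepdims=True))
ref /= ref.sum(axis=-1, keepdims=True)
np.testing.assert_allclose(y.asnumpy(), ref, rtol=1e-2, atol=1e-3)
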
@@ -443,20 +459,35 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);

  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);

  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (shape.ndim() == 2) {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<2>(), axis, static_cast<DType>(temperature));
+        if (safe_acc) {
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
        } else {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<3>(), axis, static_cast<DType>(temperature));
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
        }
      });
    });
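
SoftmaxGradCompute is gated the same way, so the backward reduction also accumulates in AType only when the flag is enabled. A sketch of exercising the gradient path under the flag, again not taken from this PR:

# Sketch: backward through softmax with safe accumulation enabled. The gradient
# keeps the float16 dtype; only the internal reduction uses the wider type.
import os
import mxnet as mx
from mxnet import autograd

os.environ["MXNET_SAFE_ACCUMULATION"] = "1"

x = mx.nd.random.normal(0, 5, shape=(8, 4096)).astype('float16')
x.attach_grad()
with autograd.record():
    y = mx.nd.softmax(x, axis=-1)
y.backward(mx.nd.ones_like(y))
print(x.grad.dtype)   # float16
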
47 changes: 26 additions & 21 deletions tests/python/unittest/test_operator.py
@@ -3434,7 +3434,7 @@ def npy_layer_norm_grad(data, gamma, out_grad, axis, eps):
assert_almost_equal(exe.grad_dict['data'].asnumpy(), gt_data_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['gamma'].asnumpy(), gt_gamma_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['beta'].asnumpy(), gt_beta_grad, backward_check_eps, backward_check_eps)

# Test for grad_req = add
out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
init_data_grad = np.random.normal(0, 1, in_shape).astype(dtype)
@@ -4926,22 +4926,27 @@ def check_dtypes_almost_equal(op_name,
        ref_grad_np = ref_input.grad.asnumpy()
        assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)

-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
-    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
-    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
-    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                              'float16', 'float32')
-    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                              'float16', 'float32', 'float32')
-    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                              'float32', 'float64')
-    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                              'float32', 'float64', 'float64')
+    import sys
+    is_windows = sys.platform.startswith('win')
+    enforce_safe_acc = os.environ.get("MXNET_SAFE_ACCUMULATION", "0")
+    if not is_windows or enforce_safe_acc == "1":
+        os.environ["MXNET_SAFE_ACCUMULATION"] = "1"
+        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+        check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                                  'float16', 'float32')
+        check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                                  'float16', 'float32', 'float32')
+        check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                                  'float32', 'float64')
+        check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                                  'float32', 'float64', 'float64')
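
The dtype-consistency checks above now run with MXNET_SAFE_ACCUMULATION=1, but only when the setting can actually reach the native library: on Windows the block is skipped unless the variable was already exported before the process started, presumably because an os.environ change made after libmxnet is loaded is not guaranteed to be visible to its C runtime. A hypothetical way to exercise the safe-accumulation path on any platform is to export the flag before the interpreter starts, for example:

# Hypothetical helper, not part of the PR: run a snippet in a fresh interpreter
# whose environment already contains the flag, so dmlc::GetEnv sees it everywhere.
import os
import subprocess
import sys

env = dict(os.environ, MXNET_SAFE_ACCUMULATION="1")
snippet = "import mxnet as mx; print(mx.nd.softmax(mx.nd.ones((2, 3), dtype='float16')))"
subprocess.run([sys.executable, "-c", snippet], env=env, check=True)
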

@with_seed()
def test_pick():
@@ -6457,18 +6462,18 @@ def test_laop_5():
    for n in range(1, 10):
        # test batched and non-batched processing
        for b in range(3):
-            shape = (n, n) if b == 0 else (b, n, n)
+            shape = (n, n) if b == 0 else (b, n, n)
            data_in = np.random.uniform(1, 10, shape)
            # test all legal offsets of the diagonal
-            for offs in range(1-n, n):
-                # test extraction of diagonal
+            for offs in range(1-n, n):
+                # test extraction of diagonal
                test_diag = mx.sym.linalg.extractdiag(data, offset=offs)
                res_diag = np.diagonal(data_in, offset=offs) if b==0 else np.diagonal(data_in, axis1=1, axis2=2, offset=offs)
                check_symbolic_forward(test_diag, [data_in], [res_diag])
                check_numeric_gradient(test_diag, [data_in])
                # test generation of diagonal matrix
                test_diag2 = mx.sym.linalg.makediag(data, offset=offs)
-                res_diag2 = None
+                res_diag2 = None
                if b == 0:
                    res_diag2 = np.diagflat(res_diag, k=offs)
                else: