From 0901ca2c7bc39af312a5f813a94ea571fee7b348 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Thu, 11 Apr 2019 23:14:42 -0700 Subject: [PATCH 1/6] upcast Softmax Accumulator type only when output dtype is specified --- src/operator/mxnet_op.h | 59 ----------------------------------- src/operator/nn/softmax-inl.h | 12 +++++-- 2 files changed, 9 insertions(+), 62 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index e331255c2e50..d7227da8d40b 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -249,65 +249,6 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } -#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - typedef double AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - typedef double AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half_t DType; \ - typedef float AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - typedef uint8_t AType; \ - LOG(FATAL) << "This operation only support " \ - "floating point types not uint8"; \ - } \ - break; \ - case mshadow::kInt8: \ - { \ - typedef int8_t DType; \ - typedef int8_t AType; \ - LOG(FATAL) << "This operation only support " \ - "floating point types not int8"; \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - typedef int32_t AType; \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32"; \ - } \ - break; \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - typedef int64_t AType; \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int64"; \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - #define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ switch (type) { \ case mshadow::kFloat32: \ diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 096d87416081..034a592117a8 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -64,7 +64,8 @@ struct log_softmax_fwd { }; -template +template inline void Softmax(Stream *s, DType *in, OType *out, Shape shape, int axis, const DType temperature) { index_t M = shape[axis]; @@ -410,7 +411,10 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, const double temperature = param.temperature.has_value() ? param.temperature.value() : 1.0; mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); - MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { + + bool upcast_atype = softmax_has_dtype_override(attrs); + + MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, upcast_atype, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { if (shape.ndim() == 2) { Softmax( @@ -444,7 +448,9 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, int out_idx = softmax_has_dtype_override(attrs) ? 
2 : 1; - MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, { + bool upcast_atype = softmax_has_dtype_override(attrs); + + MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, upcast_atype, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { if (shape.ndim() == 2) { From 4754dc7b98801e13d80cdf58867eb41d5e94253a Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sat, 13 Apr 2019 05:52:28 -0700 Subject: [PATCH 2/6] set softmax Accumulator type to the DType passed --- src/operator/mxnet_op.h | 51 ------------------- src/operator/nn/softmax-inl.h | 92 +++++++++++++++++++---------------- 2 files changed, 49 insertions(+), 94 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d7227da8d40b..ed15e1e331ed 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -249,57 +249,6 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } -#define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - typedef double AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - typedef double AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half_t DType; \ - typedef float AType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - typedef uint32_t AType; \ - } \ - break; \ - case mshadow::kInt8: \ - { \ - typedef int8_t DType; \ - typedef int32_t AType; \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - typedef int64_t AType; \ - } \ - break; \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - typedef int64_t AType; \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - /*! * \brief assign the val to out according * to request in Kernel::Launch diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 034a592117a8..8098af23f2e9 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -311,9 +311,9 @@ struct SoftmaxParam : public dmlc::Parameter { } }; -static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) { +static inline int sofmtax_dtype_param(const nnvm::NodeAttrs &attrs) { const SoftmaxParam& param = nnvm::get(attrs.parsed); - return param.dtype.has_value() && param.dtype.value() != -1; + return param.dtype.has_value() ? 
param.dtype.value(): -1; } static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, @@ -323,7 +323,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); const SoftmaxParam& param = nnvm::get(attrs.parsed); - if (softmax_has_dtype_override(attrs)) { + if (sofmtax_dtype_param(attrs) != -1) { TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); type_assign(&(*in_attrs)[0], (*out_attrs)[0]); return true; @@ -335,7 +335,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { - if (softmax_has_dtype_override(attrs)) { + if (sofmtax_dtype_param(attrs) != -1) { return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs); } else { return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs); @@ -346,7 +346,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { CHECK_EQ(out_attrs->size(), 1); - if (softmax_has_dtype_override(attrs)) { + if (sofmtax_dtype_param(attrs) != -1) { CHECK_EQ(in_attrs->size(), 3); int in_dtype = (*in_attrs)[1]; int out_dtype = (*in_attrs)[2]; @@ -366,7 +366,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, static inline std::vector > SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) { - if (softmax_has_dtype_override(attrs)) { + if (sofmtax_dtype_param(attrs) != -1) { return std::vector >{{0, 0}, {1, 0}, {2, 0}}; } else { return std::vector >{{0, 0}, {1, 0}}; @@ -374,11 +374,11 @@ SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) { } static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) { - return softmax_has_dtype_override(attrs) ? 3 : 2; + return sofmtax_dtype_param(attrs) != -1 ? 3 : 2; } static inline std::vector SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) { - if (softmax_has_dtype_override(attrs)) { + if (sofmtax_dtype_param(attrs) != -1) { return std::vector{"ograd", "data", "output"}; } else { return std::vector{"ograd", "output"}; @@ -389,7 +389,7 @@ struct SoftmaxFGradient { const char *op_name; std::vector operator()(const nnvm::NodePtr& n, const std::vector& ograds) const { - if (softmax_has_dtype_override(n->attrs)) { + if (sofmtax_dtype_param(n->attrs) != -1) { return ElemwiseGradUseInOut {op_name}(n, ograds); } else { return ElemwiseGradUseOut {op_name}(n, ograds); @@ -412,21 +412,24 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, param.temperature.value() : 1.0; mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); - bool upcast_atype = softmax_has_dtype_override(attrs); - - MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, upcast_atype, { - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - if (shape.ndim() == 2) { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); - } + int atype_flag_ = sofmtax_dtype_param(attrs); + atype_flag_ = atype_flag_ != -1 ? 
atype_flag_ : inputs[0].type_flag_; + + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(temperature)); + } + }); }); }); } @@ -446,25 +449,28 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, param.temperature.value() : 1.0; mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); - int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1; - - bool upcast_atype = softmax_has_dtype_override(attrs); - - MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, upcast_atype, { - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (shape.ndim() == 2) { - SoftmaxGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); - } else { - SoftmaxGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); - } - }); + int out_idx = sofmtax_dtype_param(attrs) != -1 ? 2 : 1; + + int atype_flag_ = sofmtax_dtype_param(attrs); + atype_flag_ = atype_flag_ != -1 ? atype_flag_ : inputs[0].type_flag_; + + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, OType, { + MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + if (shape.ndim() == 2) { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<3>(), axis, static_cast(temperature)); + } + }); + }); }); }); } From 2f667457fbe974e4289ff5aabbd674fa3c36ab9f Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Tue, 16 Apr 2019 15:37:34 -0700 Subject: [PATCH 3/6] Update test_softmax_dtype: use AType in np_softmax, change tolerance values based on AType --- src/operator/mxnet_op.h | 51 ++++++++++++++++++++++++++ tests/python/unittest/test_operator.py | 25 ++++++++----- 2 files changed, 67 insertions(+), 9 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index ed15e1e331ed..d7227da8d40b 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -249,6 +249,57 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + typedef float AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + typedef uint32_t AType; \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + typedef int32_t AType; \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + typedef int64_t AType; \ + } \ + break; \ + case mshadow::kInt64: \ + { 
\ + typedef int64_t DType; \ + typedef int64_t AType; \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + /*! * \brief assign the val to out according * to request in Kernel::Launch diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f2d8a1b2524f..cdafc9a16bac 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -270,10 +270,11 @@ def test_rnnrelu_dropout(): out = exe.forward(is_train=True) out[0].wait_to_read() -def np_softmax(x, axis=-1, temperature=1.0): +def np_softmax(x, axis=-1, temperature=1.0, odtype=None): + x = x.astype(odtype) x = x - np.max(x, axis=axis, keepdims=True) x = np.exp(x/temperature) - x /= np.sum(x, axis=axis, keepdims=True) + x /= np.sum(x, axis=axis, keepdims=True, dtype=odtype) return x @@ -4748,28 +4749,34 @@ def check_dtypes_almost_equal(op_name, op = getattr(mx.nd, op_name) input_data = mx.random.uniform(shape=(100, 500)) dtype_input = input_data.astype(idtype) + np_op = {'softmax': np_softmax(dtype_input.asnumpy(), odtype=odtype), + 'softmin': np_softmax(-1 * dtype_input.asnumpy(), odtype=odtype), + 'log_softmax': np.log(np_softmax(dtype_input.asnumpy(), + odtype=odtype)+1e-20) + } ref_input = input_data.astype(ref_dtype) dtype_input.attach_grad() ref_input.attach_grad() with mx.autograd.record(): dtype_softmax = op(dtype_input, axis=-1, dtype=odtype) ref_softmax = op(ref_input, axis=-1, dtype=odtype) - dtype_softmax_np = dtype_softmax.asnumpy() + dtype_mx_softmax = dtype_softmax.asnumpy() + dtype_np_softmax = np_op[op_name] ref_softmax_np = ref_softmax.asnumpy() - assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol) + assert_almost_equal(dtype_mx_softmax, dtype_np_softmax, rtol=rtol, atol=atol) dtype_softmax.backward() ref_softmax.backward() dtype_grad_np = dtype_input.grad.asnumpy() ref_grad_np = ref_input.grad.asnumpy() assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol) - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32', 'float32') check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') - check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') - check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') - check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float32', 'float64') check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, 'float16', 'float32') From b3dc30ded201194e2f01f7db4a0e37f1855d1d52 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sat, 20 Apr 2019 01:21:59 -0700 Subject: [PATCH 4/6] changes to check_numeric_grad to support odtype and test_softmax_dtype --- python/mxnet/test_utils.py | 18 ++++-- 
tests/python/unittest/test_operator.py | 89 ++++++++++++++------------ 2 files changed, 59 insertions(+), 48 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index bbb12dd5d7af..5c2dc4053a1b 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -800,7 +800,7 @@ def as_stype(var, stype, dtype): def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, atol=None, grad_nodes=None, use_forward_train=True, ctx=None, - grad_stype_dict=None, dtype=default_dtype()): + grad_stype_dict=None, dtype=default_dtype(), odtype=None): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -841,6 +841,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py """ assert dtype in (np.float16, np.float32, np.float64) + odtype = dtype if odtype is None else odtype + # cannot use finite differences with small eps without high precision if dtype in (np.float32, np.float16): assert numeric_eps >= 1e-5 @@ -887,12 +889,12 @@ def random_projection(shape): location = dict(list(location.items()) + [("__random_proj", mx.nd.array(random_projection(out_shape[0]), - ctx=ctx, dtype=dtype))]) - args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape)) + ctx=ctx, dtype=odtype))]) + args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape).astype(odtype)) for k in grad_nodes] + [("__random_proj", np.random.normal(0, 0.01, size=out_shape[0]))]) - args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in args_grad_npy.items()} + args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=odtype) for k, v in args_grad_npy.items()} if grad_stype_dict is not None: assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" for k, v in grad_stype_dict.items(): @@ -1016,7 +1018,7 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, - equal_nan=False, dtype=default_dtype()): + equal_nan=False, dtype=default_dtype(), odtype=None): """Compares a symbol's backward results with the expected ones. Prints error messages if the backward results are not the same as the expected results. 
@@ -1076,6 +1078,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected]) """ assert dtype in (np.float16, np.float32, np.float64) + odtype = dtype if odtype is None else odtype + if ctx is None: ctx = default_context() @@ -1085,10 +1089,10 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} - args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()} + args_grad_npy = {k:np.random.normal(size=v.shape).astype(odtype) for k, v in expected.items()} args_grad_data = {} for k, v in args_grad_npy.items(): - nd = mx.nd.array(v, ctx=ctx, dtype=dtype) + nd = mx.nd.array(v, ctx=ctx, dtype=odtype) if grad_stypes is not None and k in grad_stypes: stype = grad_stypes[k] if stype is not None and stype != 'default': diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index cdafc9a16bac..7ab0488645cd 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -271,11 +271,10 @@ def test_rnnrelu_dropout(): out[0].wait_to_read() def np_softmax(x, axis=-1, temperature=1.0, odtype=None): - x = x.astype(odtype) x = x - np.max(x, axis=axis, keepdims=True) x = np.exp(x/temperature) - x /= np.sum(x, axis=axis, keepdims=True, dtype=odtype) - return x + y = x /np.sum(x, axis=axis, keepdims=True, dtype=odtype) + return y def check_elementwise_sum_with_shape(shape, n): @@ -4686,11 +4685,11 @@ def test_softmin(): @with_seed() -def test_new_softmax(): +def test_new_softmax(idtype=None): for ndim in range(1, 5): shape = np.random.randint(1, 5, size=ndim) axis = np.random.randint(-ndim, ndim) - data = np.random.uniform(-2, 2, size=shape) + data = np.random.uniform(-2, 2, size=shape).astype(idtype) sym = mx.sym.softmax(axis=axis) expected_fwd = np_softmax(data, axis=axis) expected_bwd = np.zeros(shape) @@ -4745,47 +4744,55 @@ def test_softmax_dtype(): def check_dtypes_almost_equal(op_name, atol, rtol, grad_atol, grad_rtol, - idtype, ref_dtype, odtype=None): + idtype, odtype=None): op = getattr(mx.nd, op_name) - input_data = mx.random.uniform(shape=(100, 500)) - dtype_input = input_data.astype(idtype) - np_op = {'softmax': np_softmax(dtype_input.asnumpy(), odtype=odtype), - 'softmin': np_softmax(-1 * dtype_input.asnumpy(), odtype=odtype), - 'log_softmax': np.log(np_softmax(dtype_input.asnumpy(), - odtype=odtype)+1e-20) + + input_data = mx.random.uniform(shape=(3, 4)).astype(idtype) + + np_op = {'softmax': np_softmax(input_data.asnumpy(), odtype=odtype if odtype else idtype), + 'softmin': np_softmax(-1 * input_data.asnumpy(), odtype=odtype if odtype else idtype), + 'log_softmax': np.log(np_softmax(input_data.asnumpy(), + odtype=odtype if odtype else idtype)+1e-20) } - ref_input = input_data.astype(ref_dtype) - dtype_input.attach_grad() - ref_input.attach_grad() + + input_data.attach_grad() with mx.autograd.record(): - dtype_softmax = op(dtype_input, axis=-1, dtype=odtype) - ref_softmax = op(ref_input, axis=-1, dtype=odtype) + dtype_softmax = op(input_data, axis=-1, dtype=odtype) + dtype_mx_softmax = dtype_softmax.asnumpy() dtype_np_softmax = np_op[op_name] - ref_softmax_np = ref_softmax.asnumpy() - assert_almost_equal(dtype_mx_softmax, dtype_np_softmax, rtol=rtol, atol=atol) - dtype_softmax.backward() - ref_softmax.backward() - dtype_grad_np = dtype_input.grad.asnumpy() 
- ref_grad_np = ref_input.grad.asnumpy() - assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol) - - check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32') - check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32', 'float32') - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') - check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32') - check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float16', 'float32', 'float32') - check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-3, 1e-5, 'float32', 'float64') - check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') - check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, - 'float16', 'float32') - check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, - 'float16', 'float32', 'float32') - check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, - 'float32', 'float64') - check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, - 'float32', 'float64', 'float64') + + assert_almost_equal(dtype_mx_softmax, dtype_np_softmax, rtol=rtol, atol=atol, + names=("mx_softmax", "np_softmax")) + + data_var = mx.sym.Variable('data') + op_sym = getattr(mx.sym, op_name) + + sym = op_sym(data=data_var) if not odtype else op_sym(data=data_var, dtype=odtype) + + expected_fwd = np_op[op_name] + expected_bwd = np.zeros(input_data.shape) + check_symbolic_forward(sym, [input_data], [expected_fwd], dtype='asnumpy', atol=atol, + rtol=rtol) + + if op_name is not 'log_softmax': + check_symbolic_backward(sym, [input_data], [np.ones(expected_fwd.shape)], [expected_bwd], + rtol=grad_rtol, atol=grad_atol, dtype=idtype) + if idtype is not np.float16: + check_numeric_gradient(sym, location=[input_data], rtol=grad_rtol, atol=grad_atol, + dtype=idtype, + odtype=odtype) + + check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-2, 1e-5, np.float16) + check_dtypes_almost_equal('softmax', 1e-4, 1e-5, 1e-2, 1e-5, np.float16, np.float32) + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-3, 1e-5, np.float32) + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-3, 1e-5, np.float32, np.float64) + check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-2, 1e-5, np.float16, np.float32) + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-3, 1e-5, np.float32, np.float64) + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-3, 1e-5, np.float16) + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-3, 1e-5, np.float16, np.float32) + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-5, 1e-3, 1e-5, np.float32) + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-5, 1e-3, 1e-5, np.float32, np.float64) @with_seed() def test_pick(): From feda00611c01c42314a774c7c06e019cf67d8364 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sun, 21 Apr 2019 23:30:31 -0700 Subject: [PATCH 5/6] revert tests back to earlier since verifying against numpy on such a large tensor takes a long time and shows no change between this change and previous change --- python/mxnet/test_utils.py | 18 ++---- tests/python/unittest/test_operator.py | 90 +++++++++++--------------- 2 files changed, 45 insertions(+), 63 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 5c2dc4053a1b..bbb12dd5d7af 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -800,7 
+800,7 @@ def as_stype(var, stype, dtype): def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, atol=None, grad_nodes=None, use_forward_train=True, ctx=None, - grad_stype_dict=None, dtype=default_dtype(), odtype=None): + grad_stype_dict=None, dtype=default_dtype()): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -841,8 +841,6 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py """ assert dtype in (np.float16, np.float32, np.float64) - odtype = dtype if odtype is None else odtype - # cannot use finite differences with small eps without high precision if dtype in (np.float32, np.float16): assert numeric_eps >= 1e-5 @@ -889,12 +887,12 @@ def random_projection(shape): location = dict(list(location.items()) + [("__random_proj", mx.nd.array(random_projection(out_shape[0]), - ctx=ctx, dtype=odtype))]) - args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape).astype(odtype)) + ctx=ctx, dtype=dtype))]) + args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape)) for k in grad_nodes] + [("__random_proj", np.random.normal(0, 0.01, size=out_shape[0]))]) - args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=odtype) for k, v in args_grad_npy.items()} + args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in args_grad_npy.items()} if grad_stype_dict is not None: assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" for k, v in grad_stype_dict.items(): @@ -1018,7 +1016,7 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None, aux_states=None, grad_req='write', ctx=None, grad_stypes=None, - equal_nan=False, dtype=default_dtype(), odtype=None): + equal_nan=False, dtype=default_dtype()): """Compares a symbol's backward results with the expected ones. Prints error messages if the backward results are not the same as the expected results. 
@@ -1078,8 +1076,6 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected]) """ assert dtype in (np.float16, np.float32, np.float64) - odtype = dtype if odtype is None else odtype - if ctx is None: ctx = default_context() @@ -1089,10 +1085,10 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} - args_grad_npy = {k:np.random.normal(size=v.shape).astype(odtype) for k, v in expected.items()} + args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()} args_grad_data = {} for k, v in args_grad_npy.items(): - nd = mx.nd.array(v, ctx=ctx, dtype=odtype) + nd = mx.nd.array(v, ctx=ctx, dtype=dtype) if grad_stypes is not None and k in grad_stypes: stype = grad_stypes[k] if stype is not None and stype != 'default': diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7ab0488645cd..f2d8a1b2524f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -270,11 +270,11 @@ def test_rnnrelu_dropout(): out = exe.forward(is_train=True) out[0].wait_to_read() -def np_softmax(x, axis=-1, temperature=1.0, odtype=None): +def np_softmax(x, axis=-1, temperature=1.0): x = x - np.max(x, axis=axis, keepdims=True) x = np.exp(x/temperature) - y = x /np.sum(x, axis=axis, keepdims=True, dtype=odtype) - return y + x /= np.sum(x, axis=axis, keepdims=True) + return x def check_elementwise_sum_with_shape(shape, n): @@ -4685,11 +4685,11 @@ def test_softmin(): @with_seed() -def test_new_softmax(idtype=None): +def test_new_softmax(): for ndim in range(1, 5): shape = np.random.randint(1, 5, size=ndim) axis = np.random.randint(-ndim, ndim) - data = np.random.uniform(-2, 2, size=shape).astype(idtype) + data = np.random.uniform(-2, 2, size=shape) sym = mx.sym.softmax(axis=axis) expected_fwd = np_softmax(data, axis=axis) expected_bwd = np.zeros(shape) @@ -4744,55 +4744,41 @@ def test_softmax_dtype(): def check_dtypes_almost_equal(op_name, atol, rtol, grad_atol, grad_rtol, - idtype, odtype=None): + idtype, ref_dtype, odtype=None): op = getattr(mx.nd, op_name) - - input_data = mx.random.uniform(shape=(3, 4)).astype(idtype) - - np_op = {'softmax': np_softmax(input_data.asnumpy(), odtype=odtype if odtype else idtype), - 'softmin': np_softmax(-1 * input_data.asnumpy(), odtype=odtype if odtype else idtype), - 'log_softmax': np.log(np_softmax(input_data.asnumpy(), - odtype=odtype if odtype else idtype)+1e-20) - } - - input_data.attach_grad() + input_data = mx.random.uniform(shape=(100, 500)) + dtype_input = input_data.astype(idtype) + ref_input = input_data.astype(ref_dtype) + dtype_input.attach_grad() + ref_input.attach_grad() with mx.autograd.record(): - dtype_softmax = op(input_data, axis=-1, dtype=odtype) - - dtype_mx_softmax = dtype_softmax.asnumpy() - dtype_np_softmax = np_op[op_name] - - assert_almost_equal(dtype_mx_softmax, dtype_np_softmax, rtol=rtol, atol=atol, - names=("mx_softmax", "np_softmax")) - - data_var = mx.sym.Variable('data') - op_sym = getattr(mx.sym, op_name) - - sym = op_sym(data=data_var) if not odtype else op_sym(data=data_var, dtype=odtype) - - expected_fwd = np_op[op_name] - expected_bwd = np.zeros(input_data.shape) - check_symbolic_forward(sym, [input_data], [expected_fwd], dtype='asnumpy', atol=atol, - rtol=rtol) - - if op_name is not 'log_softmax': - 
check_symbolic_backward(sym, [input_data], [np.ones(expected_fwd.shape)], [expected_bwd], - rtol=grad_rtol, atol=grad_atol, dtype=idtype) - if idtype is not np.float16: - check_numeric_gradient(sym, location=[input_data], rtol=grad_rtol, atol=grad_atol, - dtype=idtype, - odtype=odtype) - - check_dtypes_almost_equal('softmax', 1e-3, 1e-5, 1e-2, 1e-5, np.float16) - check_dtypes_almost_equal('softmax', 1e-4, 1e-5, 1e-2, 1e-5, np.float16, np.float32) - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-3, 1e-5, np.float32) - check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-3, 1e-5, np.float32, np.float64) - check_dtypes_almost_equal('softmin', 1e-3, 1e-5, 1e-2, 1e-5, np.float16, np.float32) - check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-3, 1e-5, np.float32, np.float64) - check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-3, 1e-5, np.float16) - check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-3, 1e-5, np.float16, np.float32) - check_dtypes_almost_equal('log_softmax', 1e-3, 1e-5, 1e-3, 1e-5, np.float32) - check_dtypes_almost_equal('log_softmax', 1e-3, 1e-5, 1e-3, 1e-5, np.float32, np.float64) + dtype_softmax = op(dtype_input, axis=-1, dtype=odtype) + ref_softmax = op(ref_input, axis=-1, dtype=odtype) + dtype_softmax_np = dtype_softmax.asnumpy() + ref_softmax_np = ref_softmax.asnumpy() + assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol) + dtype_softmax.backward() + ref_softmax.backward() + dtype_grad_np = dtype_input.grad.asnumpy() + ref_grad_np = ref_input.grad.asnumpy() + assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol) + + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64', 'float64') @with_seed() def test_pick(): From c47b02458ec59aa515a25d73ff42d43984e32f04 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sun, 21 Apr 2019 23:50:54 -0700 Subject: [PATCH 6/6] don't remove MXNET_REAL_ACC_TYPE_SWITCH, since its used by other operators --- src/operator/mxnet_op.h | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d7227da8d40b..e331255c2e50 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -249,6 +249,65 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + typedef double AType; \ + {__VA_ARGS__} \ 
+ } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + typedef float AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + typedef uint8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + typedef int8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not int8"; \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + typedef int32_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32"; \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + typedef int64_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int64"; \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + #define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ switch (type) { \ case mshadow::kFloat32: \
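
The common thread in this series is the accumulation type used by the softmax kernels: Softmax and SoftmaxGrad are parameterized by an input type (DType), an output type (OType), and an accumulation type (AType) in which the per-axis max and normalizing sum are reduced. Patch 1 upcasts AType only when an explicit output dtype is given; patch 2 instead derives AType from the dtype parameter (falling back to the input type) through a third nested MSHADOW_REAL_TYPE_SWITCH; patch 6 restores the MXNET_REAL_ACC_TYPE_SWITCH macro because other operators still use it. The snippet below is a minimal, self-contained sketch of that DType/AType/OType split only — it is not MXNet's kernel (no Stream, Shape, axis handling, temperature validation, or the type-switch macros), and the helper name softmax_1d and its signature are invented here purely for illustration.

// Simplified sketch of softmax with a separate accumulation type.
// DType: element type of the input buffer.
// AType: type used for the running max and sum (the "accumulator").
// OType: element type of the output buffer.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

template <typename AType, typename DType, typename OType>
void softmax_1d(const DType* in, OType* out, int n, AType temperature) {
  // Reduce the maximum in AType rather than DType.
  AType mmax = static_cast<AType>(in[0]);
  for (int i = 1; i < n; ++i) {
    mmax = std::max(mmax, static_cast<AType>(in[i]));
  }
  // Accumulate the normalizer in AType as well.
  AType sum = AType(0);
  for (int i = 0; i < n; ++i) {
    sum += std::exp((static_cast<AType>(in[i]) - mmax) / temperature);
  }
  // Write results back in the requested output type.
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<OType>(
        std::exp((static_cast<AType>(in[i]) - mmax) / temperature) / sum);
  }
}

int main() {
  std::vector<float> x = {-2.0f, -1.0f, 0.5f, 2.0f};
  std::vector<float> y(x.size());
  // float data and output, double accumulator: the kind of pairing the
  // dtype override enables in the patched SoftmaxCompute.
  softmax_1d<double>(x.data(), y.data(), static_cast<int>(x.size()), 1.0);
  for (float v : y) std::printf("%f\n", v);
  return 0;
}

In the patched SoftmaxCompute this choice happens at runtime: when the softmax dtype parameter is set, it selects AType through the added MSHADOW_REAL_TYPE_SWITCH, so a narrow input such as float16 can be reduced in a wider type before being cast to the requested output dtype — which is also why the float16 cases in the updated tests are compared against a NumPy reference computed with the output dtype and looser tolerances.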