diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 6cab1990858b..d8fc5031e4ff 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -249,6 +249,48 @@ inline int get_num_threads(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;           \
   }

+#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      typedef double AType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      typedef double AType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    {                                                      \
+      typedef mshadow::half::half_t DType;                 \
+      typedef float AType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    LOG(FATAL) << "This operation only supports "          \
+                  "floating point types, not uint8";       \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    LOG(FATAL) << "This operation only supports "          \
+                  "floating point types, not int8";        \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    LOG(FATAL) << "This operation only supports "          \
+                  "floating point types, not int32";       \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    LOG(FATAL) << "This operation only supports "          \
+                  "floating point types, not int64";       \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }

 /*!
  * \brief assign the val to out according
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index c063e385f63a..b663e7de4698 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_NN_SOFTMAX_INL_H_

 #include <vector>
+#include
 #include "../mxnet_op.h"
 #include "../operator_common.h"
@@ -36,23 +37,33 @@ namespace op {
 namespace mxnet_op {

 struct softmax_fwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return DType(expf(a)/b);
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(float a, AType b) {
+    return AType(expf(a)/b);
+  }
+
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(double a, AType b) {
+    return AType(exp(a)/b);
   }
 };


 struct log_softmax_fwd {
   template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return DType(a - logf(b));
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return a - logf(b);
+  }
+
+  template<typename DType>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return a - log(b);
   }
 };


-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
@@ -72,10 +83,9 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
       if (mmax < val) mmax = val;
     }

-    DType sum = DType(0);
+    AType sum = AType(0);
     DType in_val;
-    // By default temperature is 1.0, and only in reinforcement training
-    // users would set it to other values.
+    // By default temperature is 1.0.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
@@ -103,23 +113,29 @@


 struct softmax_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
-    return DType(out * (ograd - sum));
+  template<typename DType, typename AType>
+  MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
+    return AType(out * (ograd - sum));
   }
 };


 struct log_softmax_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
-    return DType(ograd - expf(out)*sum);
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(float ograd, float out, AType sum) {
+    return AType(ograd - expf(out)*sum);
+  }
+
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(double ograd, double out, AType sum) {
+    return AType(ograd - exp(out)*sum);
   }
 };

-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate, typename AType,
+         typename DType, typename OType, int ndim>
+inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
   index_t M = shape[axis];
@@ -133,13 +149,12 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,

   for (int i = 0; i < static_cast<int>(N); ++i) {
     index_t base = unravel_dot(i, sshape, stride);
-    DType sum = DType(0);
+    AType sum = AType(0);
     for (index_t j = 0; j < M; ++j) {
       sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
     }

-    // By default temperature is 1.0, and only in reinforcement training
-    // users would set it to other values.
+    // By default temperature is 1.0.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     DType final_result;
     if (temperature == 1.0) {
@@ -162,19 +177,20 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,


 #ifdef __CUDACC__

-template<int x_bits, typename OP, bool negate, typename DType, int ndim>
-__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
+template<int x_bits, typename OP, bool negate, typename AType, int ndim,
+         typename DType, typename OType>
+__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
   const unsigned x_size = 1 << x_bits;
-  __shared__ DType smem[x_size];
+  __shared__ AType smem[x_size];
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+    smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -186,13 +202,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   DType val;
   for (index_t i = x; i < M; i += x_size) {
     val = negate ? -in[base + i*sa]:in[base + i*sa];
-    red::sum::Reduce(
-        smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
+    smem[x] += static_cast<AType>(expf((val - smax) / static_cast<DType>(temperature)));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
   __syncthreads();
-  DType ssum = smem[0];
+  AType ssum = smem[0];
   __syncthreads();

   for (index_t i = x; i < M; i += x_size) {
@@ -201,8 +216,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   }
 }

-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const double temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
@@ -212,31 +227,32 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;

-  softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
+  softmax_compute_kernel<x_bits, OP, negate, AType, ndim, DType, OType>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }

-template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate,
+         typename AType, int ndim, typename DType, typename OType>
+__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
   const unsigned x_size = 1 << x_bits;
-  __shared__ DType smem[x_size];
+  __shared__ AType smem[x_size];
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], OP1::Map(ograd[base + i*sa], out[base + i*sa]));
+    smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
   __syncthreads();
-  DType ssum = smem[0];
+  AType ssum = smem[0];
   __syncthreads();

   DType final_result;
@@ -250,8 +266,9 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
 }


-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate, typename AType,
+         typename DType, typename OType, int ndim>
+inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
   const int x_bits = 7;
@@ -262,7 +279,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;

-  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
+  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim, DType, OType>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -275,14 +292,70 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
+  dmlc::optional<int> dtype;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
-    .describe("The axis along which to compute softmax.");
+      .describe("The axis along which to compute softmax.");
     DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
-    .describe("Temperature parameter in softmax");
+      .describe("Temperature parameter in softmax");
+    DMLC_DECLARE_FIELD(dtype)
+      .add_enum("float16", mshadow::kFloat16)
+      .add_enum("float32", mshadow::kFloat32)
+      .add_enum("float64", mshadow::kFloat64)
+      .set_default(dmlc::optional<int>())
+      .describe("DType of the output in case this can't be inferred. "
" + "Defaults to the same as input's dtype if not defined (dtype=None)."); } }; +inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + + int arg_dtype = param.dtype.has_value()?param.dtype.value():-1, + in_dtype = (*in_attrs)[0], + out_dtype = (*out_attrs)[0]; + + if (out_dtype != -1 && in_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + return true; + } else if (in_dtype != -1) { + if (arg_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); + } + return true; + } else if (out_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + return true; + } else { + if (arg_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + } + return false; + } +} + +inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3); + CHECK_EQ(out_attrs->size(), 1); + + int in_dtype = (*in_attrs)[1], + out_dtype = (*in_attrs)[2]; + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); + + return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1; +} + template void SoftmaxCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -297,16 +370,20 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, const double temperature = param.temperature.has_value() ? param.temperature.value() : 1.0; TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - if (shape.ndim() == 2) { - Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); - } + MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(temperature)); + } + }); }); } @@ -324,17 +401,21 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, const double temperature = param.temperature.has_value() ? 
                              param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      if (shape.ndim() == 2) {
-        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                           shape.get<2>(), axis, static_cast<DType>(temperature));
-      } else {
-        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                           shape.get<3>(), axis, static_cast<DType>(temperature));
-      }
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[2].type_flag_, OType, AType, {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        if (shape.ndim() == 2) {
+          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+              shape.get<2>(), axis, static_cast<DType>(temperature));
+        } else {
+          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+              shape.get<3>(), axis, static_cast<DType>(temperature));
+        }
+      });
     });
   });
 }
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 81e775cac526..1d6cef58263c 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -67,7 +67,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
 }
 #endif

-MXNET_OPERATOR_REGISTER_UNARY(softmax)
+NNVM_REGISTER_OP(softmax)
 .describe(R"code(Applies the softmax function.

 The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
@@ -102,15 +102,39 @@ Example::
 .set_attr<FComputeEx>("FComputeEx", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());

-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
+NNVM_REGISTER_OP(_backward_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"ograd", "data", "output"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  })
+.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
+.add_argument("data", "NDArray-or-Symbol", "input")
+.add_argument("output", "NDArray-or-Symbol", "output")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                    mxnet_op::softmax_bwd>);

-MXNET_OPERATOR_REGISTER_UNARY(softmin)
+NNVM_REGISTER_OP(softmin)
 .describe(R"code(Applies the softmin function.

 The resulting array contains elements in the range (0,1) and the elements along the given axis sum
@@ -141,15 +165,39 @@ Example::
     return std::vector<std::string>{"output"};
   })
 .set_attr<FCompute>("FCompute", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmin"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());

-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
+NNVM_REGISTER_OP(_backward_softmin)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"ograd", "data", "output"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  })
+.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
+.add_argument("data", "NDArray-or-Symbol", "input")
+.add_argument("output", "NDArray-or-Symbol", "output")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                    mxnet_op::softmax_bwd, true>);

-MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
+NNVM_REGISTER_OP(log_softmax)
 .describe(R"code(Computes the log softmax of the input.
 This is equivalent to computing softmax followed by log.

@@ -168,10 +216,34 @@ Examples::

 )code")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_log_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());

-MXNET_OPERATOR_REGISTER_BINARY(_backward_log_softmax)
+NNVM_REGISTER_OP(_backward_log_softmax)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"ograd", "data", "output"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  })
+.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
+.add_argument("data", "NDArray-or-Symbol", "input")
+.add_argument("output", "NDArray-or-Symbol", "output")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, mshadow_op::left,
                                                    mxnet_op::log_softmax_bwd>);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7b5b9ebf3be4..4cf0a970e15e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4515,6 +4515,47 @@ def softmax_forward(input_data, true_output):
     softmax_forward(mx.nd.array([[[[-3.4e38,-3.4e38]]]]), np.array([1.0,1.0]))
     softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0]))

+@with_seed()
+def test_softmax_dtype():
+    def check_dtypes_almost_equal(op_name,
+                                  atol, rtol,
+                                  grad_atol, grad_rtol,
+                                  idtype, ref_dtype, odtype=None):
+        op = getattr(mx.nd, op_name)
+        input_data = mx.random.uniform(shape=(100, 500))
+        dtype_input = input_data.astype(idtype)
+        ref_input = input_data.astype(ref_dtype)
+        dtype_input.attach_grad()
+        ref_input.attach_grad()
+        with mx.autograd.record():
+            dtype_softmax = op(dtype_input, axis=-1, dtype=odtype)
+            ref_softmax = op(ref_input, axis=-1, dtype=odtype)
+        dtype_softmax_np = dtype_softmax.asnumpy()
+        ref_softmax_np = ref_softmax.asnumpy()
+        assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol)
+        dtype_softmax.backward()
+        ref_softmax.backward()
+        dtype_grad_np = dtype_input.grad.asnumpy()
+        ref_grad_np = ref_input.grad.asnumpy()
+        assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)
+
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64', 'float64')
+
 @with_seed()
 def test_pick():
     def test_pick_helper(index_type=np.int32):
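Usage sketch (not part of the patch): the snippet below assumes the diff above is applied and relies only on the public `mx.nd.softmax` call plus the new `dtype` argument. It shows the combination the new `test_softmax_dtype` exercises: a float16 input whose reductions run in a wider accumulation type, with the output dtype optionally widened to float32; the variable names are illustrative only.

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(4, 10)).astype('float16')
    # Without dtype=, the output dtype follows the input (float16 here),
    # but the row-wise sums are still accumulated in float32 internally.
    y_default = mx.nd.softmax(x, axis=-1)
    # With dtype='float32', the output itself is returned in float32.
    y_fp32 = mx.nd.softmax(x, axis=-1, dtype='float32')
    print(y_default.dtype, y_fp32.dtype)   # float16 and float32 respectively
    print(y_fp32.sum(axis=-1))             # each row sums to 1 within float32 tolerance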