From 076b2f330c60f05cb939beea28dd04cd571a34c0 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Fri, 19 Jul 2019 10:43:10 -0700 Subject: [PATCH] Softmax with length (#15169) * softmax with length forward * softmax with length backward * new macro to reduce compile-time heap usage and limit length to integers only * address comments --- src/operator/mxnet_op.h | 51 +++ src/operator/nn/softmax-inl.h | 428 +++++++++++++++++++++---- src/operator/nn/softmax.cc | 33 +- tests/python/unittest/test_operator.py | 39 ++- 4 files changed, 487 insertions(+), 64 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index f17b708a7687..52788f697f11 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -363,6 +363,57 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "Unknown type enum " << type; \ } +#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float32"; \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float64"; \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float16"; \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + /*! * \brief assign the val to out according * to request in Kernel::Launch diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index d6113b05dbb9..2c82d839e5ed 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -75,7 +75,7 @@ inline void Softmax(Stream *s, DType *in, OType *out, index_t sa = stride[axis]; #pragma omp parallel for - for (int i = 0; i < static_cast(N); ++i) { + for (index_t i = 0; i < N; ++i) { index_t base = unravel_dot(i, sshape, stride); DType mmax = negate ? -in[base] : in[base]; @@ -113,6 +113,60 @@ inline void Softmax(Stream *s, DType *in, OType *out, } } +template +inline void SoftmaxWithLength(Stream *s, DType *in, OType *out, IType *length, + Shape shape, int axis, const DType temperature) { + index_t M = shape[axis]; + index_t N = shape.Size()/M; + Shape stride = calc_stride(shape); + Shape sshape = shape; + sshape[axis] = 1; + index_t sa = stride[axis]; + + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t len = static_cast(length[i]); + index_t base = unravel_dot(i, sshape, stride); + + DType mmax = negate ? -in[base] : in[base]; + DType val; + for (index_t j = 1; j < len; ++j) { + val = negate ? -in[base + j*sa] : in[base + j*sa]; + if (mmax < val) mmax = val; + } + for (index_t j = len; j < M; ++j) { + out[base + j*sa] = OType(0.0f); + } + + AType sum = AType(0); + DType in_val; + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + if (temperature == 1.0) { + for (index_t j = 0; j < len; ++j) { + in_val = negate ? 
-in[base + j*sa] : in[base + j*sa]; + sum += std::exp(in_val - mmax); + } + + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map(in_val - mmax, sum); + } + } else { + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + sum += std::exp((in_val - mmax)/temperature); + } + + for (index_t j = 0; j < len; ++j) { + in_val = negate ? -in[base + j*sa] : in[base + j*sa]; + out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum); + } + } + } +} + struct softmax_bwd { template @@ -136,7 +190,7 @@ struct log_softmax_bwd { template + typename AType, typename DType, typename OType, int ndim> inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, DType *igrad, Shape shape, int axis, const DType temperature) { @@ -148,7 +202,7 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, index_t sa = stride[axis]; #pragma omp parallel for - for (int i = 0; i < static_cast(N); ++i) { + for (index_t i = 0; i < N; ++i) { index_t base = unravel_dot(i, sshape, stride); AType sum = AType(0); @@ -177,10 +231,55 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, } } +template +inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, + DType *igrad, IType *length, Shape shape, + int axis, const DType temperature) { + index_t M = shape[axis]; + index_t N = shape.Size()/M; + Shape stride = calc_stride(shape); + Shape sshape = shape; + sshape[axis] = 1; + index_t sa = stride[axis]; + + #pragma omp parallel for + for (index_t i = 0; i < N; ++i) { + index_t base = unravel_dot(i, sshape, stride); + index_t len = static_cast(length[i]); + + AType sum = AType(0); + for (index_t j = 0; j < len; ++j) { + sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); + } + + // By default temperature is 1.0. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime + DType final_result; + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? + -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); + final_result = (j < len) ? final_result : DType(0.0f); + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } + } else { + for (index_t j = 0; j < M; ++j) { + final_result = negate ? + -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature : + OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature; + final_result = (j < len) ? final_result : DType(0.0f); + KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result); + } + } + } +} + #ifdef __CUDACC__ template + typename DType, typename OType> __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis, Shape sshape, Shape stride, const double temperature) { @@ -235,9 +334,68 @@ inline void Softmax(Stream *s, DType *in, OType *out, MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); } +template +__global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length, + index_t M, int axis, Shape sshape, + Shape stride, const double temperature) { + const unsigned x_size = 1 << x_bits; + __shared__ AType smem[x_size]; + index_t sa = stride[axis]; + index_t base = unravel_dot(blockIdx.x, sshape, stride); + index_t x = threadIdx.x; + index_t len = static_cast(length[blockIdx.x]); + + red::maximum::SetInitValue(smem[x]); + for (index_t i = x; i < len; i += x_size) { + smem[x] = ::max(smem[x], negate ? 
-in[base + i*sa] : in[base + i*sa]); + } + __syncthreads(); + cuda::Reduce1D(smem); + __syncthreads(); + DType smax = smem[0]; + __syncthreads(); + + red::sum::SetInitValue(smem[x]); + DType val; + for (index_t i = x; i < len; i += x_size) { + val = negate ? -in[base + i*sa]:in[base + i*sa]; + smem[x] += static_cast(expf((val - smax) / static_cast(temperature))); + } + __syncthreads(); + cuda::Reduce1D(smem); + __syncthreads(); + AType ssum = smem[0]; + __syncthreads(); + + for (index_t i = x; i < M; i += x_size) { + val = negate ? -in[base + i*sa] : in[base + i*sa]; + out[base + i*sa] = + (i < len) ? OType(OP::Map((val - smax)/static_cast(temperature), ssum)) : OType(0.0f); + } +} + +template +inline void SoftmaxWithLength(Stream *s, DType *in, OType *out, IType *length, + Shape shape, int axis, const double temperature) { + const int x_bits = 7; + const int x_size = 1 << x_bits; + index_t M = shape[axis]; + index_t N = shape.Size()/M; + Shape stride = calc_stride(shape); + Shape sshape = shape; + sshape[axis] = 1; + + softmax_with_length_kernel + <<::GetStream(s)>>>( + in, out, length, M, axis, sshape, stride, temperature); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); +} + template + typename DType, typename OType> __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad, index_t M, int axis, Shape sshape, Shape stride, const double temperature) { @@ -269,7 +427,7 @@ __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad, template + typename DType, typename OType> inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, DType *igrad, Shape shape, int axis, const double temperature) { @@ -286,6 +444,60 @@ inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, out, ograd, igrad, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel); } + +template +__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad, + IType *length, index_t M, int axis, + Shape sshape, Shape stride, + const double temperature) { + const unsigned x_size = 1 << x_bits; + __shared__ AType smem[x_size]; + index_t sa = stride[axis]; + index_t base = unravel_dot(blockIdx.x, sshape, stride); + index_t x = threadIdx.x; + index_t len = static_cast(length[blockIdx.x]); + + red::sum::SetInitValue(smem[x]); + for (index_t i = x; i < len; i += x_size) { + smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]); + } + __syncthreads(); + cuda::Reduce1D(smem); + __syncthreads(); + AType ssum = smem[0]; + __syncthreads(); + + DType final_result; + for (index_t i = x; i < M; i += x_size) { + final_result = + negate ? + -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) : + OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum); + final_result = (i < len) ? 
final_result : DType(0.0f); + KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast(temperature)); + } +} + + +template +inline void SoftmaxWithLengthGrad(Stream *s, OType *out, OType *ograd, + DType *igrad, IType *length, Shape shape, int axis, + const double temperature) { + const int x_bits = 7; + const int x_size = 1 << x_bits; + index_t M = shape[axis]; + index_t N = shape.Size()/M; + Shape stride = calc_stride(shape); + Shape sshape = shape; + sshape[axis] = 1; + + softmax_with_length_grad_kernel + <<::GetStream(s)>>>( + out, ograd, igrad, length, M, axis, sshape, stride, temperature); + MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel); +} #endif } // namespace mxnet_op @@ -295,6 +507,7 @@ struct SoftmaxParam : public dmlc::Parameter { int axis; dmlc::optional temperature; dmlc::optional dtype; + dmlc::optional use_length; DMLC_DECLARE_PARAMETER(SoftmaxParam) { DMLC_DECLARE_FIELD(axis).set_default(-1) .describe("The axis along which to compute softmax."); @@ -307,6 +520,9 @@ struct SoftmaxParam : public dmlc::Parameter { .set_default(dmlc::optional()) .describe("DType of the output in case this can't be inferred. " "Defaults to the same as input's dtype if not defined (dtype=None)."); + DMLC_DECLARE_FIELD(use_length) + .set_default(dmlc::optional(false)) + .describe("Whether to use the length input as a mask over the data input."); } }; @@ -315,27 +531,71 @@ static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) { return param.dtype.has_value() && param.dtype.value() != -1; } +static inline bool softmax_use_length(const nnvm::NodeAttrs& attrs) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); + return param.use_length.value(); +} + static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); const SoftmaxParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 2U : 1U); if (softmax_has_dtype_override(attrs)) { TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); type_assign(&(*in_attrs)[0], (*out_attrs)[0]); return true; } else { - return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs); + std::vector tmp = {in_attrs->at(0)}; + return ElemwiseType<1, 1>(attrs, &tmp, out_attrs); + } +} + +static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(out_attrs->size(), 1U); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), param.use_length.value() ? 2U : 1U); + + if (param.use_length.value()) { + mxnet::TShape& dshape = in_attrs->at(0); + mxnet::TShape tmp_shape((dshape.ndim() == 1) ? 
1U : dshape.ndim() - 1, 1); + int j = 0; + for (int i = 0; i < dshape.ndim(); ++i) { + if (i != param.axis) { + tmp_shape[j++] = dshape[i]; + } + } + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); } + mxnet::ShapeVector tmp = {in_attrs->at(0)}; + return ElemwiseShape<1, 1>(attrs, &tmp, out_attrs); } static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { - if (softmax_has_dtype_override(attrs)) { - return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs); + if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) { + if (softmax_use_length(attrs)) { + mxnet::ShapeVector ins = {in_attrs->at(0), in_attrs->at(1), in_attrs->at(3)}; + mxnet::ShapeVector dgrad = {out_attrs->at(0)}; + bool res = ElemwiseShape<3, 1>(attrs, &ins, &dgrad); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, ins[0]); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, ins[1]); + SHAPE_ASSIGN_CHECK(*in_attrs, 3, ins[2]); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dgrad[0]); + mxnet::ShapeVector length = {in_attrs->at(2)}; + mxnet::ShapeVector lgrad = {out_attrs->at(1)}; + res = (res && ElemwiseShape<1, 1>(attrs, &length, &lgrad)); + SHAPE_ASSIGN_CHECK(*in_attrs, 2, length[0]); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, lgrad[0]); + return res; + } else { + return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs); + } } else { return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs); } @@ -344,17 +604,21 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs, static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { - CHECK_EQ(out_attrs->size(), 1); - if (softmax_has_dtype_override(attrs)) { - CHECK_EQ(in_attrs->size(), 3); + CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U); + if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) { + CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 4U : 3U); int in_dtype = (*in_attrs)[1]; - int out_dtype = (*in_attrs)[2]; + int out_dtype = (*in_attrs)[softmax_use_length(attrs) ? 3 : 2]; TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); + if (softmax_use_length(attrs)) { + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(2)); + } - return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1; + return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1 && + (*out_attrs)[1] != -1 && (*in_attrs)[1] != -1; } else { - CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(in_attrs->size(), 2U); int out_dtype = (*in_attrs)[1]; TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); @@ -365,20 +629,31 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, static inline std::vector > SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) { - if (softmax_has_dtype_override(attrs)) { - return std::vector >{{0, 0}, {1, 0}, {2, 0}}; + if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) { + if (softmax_use_length(attrs)) { + return std::vector >{{0, 0}, {1, 0}, {2, 1}, {3, 0}}; + } else { + return std::vector >{{0, 0}, {1, 0}, {2, 0}}; + } } else { return std::vector >{{0, 0}, {1, 0}}; } } static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) { - return softmax_has_dtype_override(attrs) ? 3 : 2; + if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) { + return softmax_use_length(attrs) ? 
4 : 3; + } + return 2; } static inline std::vector SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) { - if (softmax_has_dtype_override(attrs)) { - return std::vector{"ograd", "data", "output"}; + if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) { + if (softmax_use_length(attrs)) { + return std::vector{"ograd", "data", "length", "output"}; + } else { + return std::vector{"ograd", "data", "output"}; + } } else { return std::vector{"ograd", "output"}; } @@ -388,7 +663,7 @@ struct SoftmaxFGradient { const char *op_name; std::vector operator()(const nnvm::NodePtr& n, const std::vector& ograds) const { - if (softmax_has_dtype_override(n->attrs)) { + if (softmax_has_dtype_override(n->attrs) || softmax_use_length(n->attrs)) { return ElemwiseGradUseInOut {op_name}(n, ograds); } else { return ElemwiseGradUseOut {op_name}(n, ograds); @@ -419,30 +694,46 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - if (safe_acc) { - if (shape.ndim() == 2) { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); + if (!param.use_length.value()) { + if (safe_acc) { + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(temperature)); + } } else { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(temperature)); + } } } else { - if (shape.ndim() == 2) { - Softmax( + MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, { + if (shape.ndim() == 2) { + SoftmaxWithLength( ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax( + outputs[0].dptr(), inputs[1].dptr(), + shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxWithLength( ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); - } + outputs[0].dptr(), inputs[1].dptr(), + shape.get<3>(), axis, static_cast(temperature)); + } + }); } }); }); @@ -464,35 +755,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1; + out_idx = softmax_use_length(attrs) ? 
3 : out_idx; bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (safe_acc) { - if (shape.ndim() == 2) { - SoftmaxGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); + if (!softmax_use_length(attrs)) { + if (safe_acc) { + if (shape.ndim() == 2) { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<3>(), axis, static_cast(temperature)); + } } else { - SoftmaxGrad( - ctx.get_stream(), inputs[out_idx].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); + if (shape.ndim() == 2) { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxGrad( + ctx.get_stream(), inputs[out_idx].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<3>(), axis, static_cast(temperature)); + } } } else { - if (shape.ndim() == 2) { - SoftmaxGrad( + MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, { + if (req[1] != kNullOp) { + mxnet_op::Kernel::Launch( + ctx.get_stream(), outputs[1].Size(), outputs[1].dptr()); + } + if (shape.ndim() == 2) { + SoftmaxWithLengthGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); - } else { - SoftmaxGrad( + inputs[2].dptr(), shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxWithLengthGrad( ctx.get_stream(), inputs[out_idx].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); - } + inputs[2].dptr(), shape.get<3>(), axis, static_cast(temperature)); + } + }); } }); }); diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index e44bbbb6b8f6..5a581e4ea5ef 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -59,14 +59,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs, DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1); - CHECK_EQ(out_attrs->size(), 1); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U); + CHECK_EQ(out_attrs->size(), 1U); + + if (param.use_length.value()) { + auto& out_stype = out_attrs->at(0); + return storage_type_assign(&out_stype, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); } #endif + + NNVM_REGISTER_OP(softmax) .describe(R"code(Applies the softmax function. @@ -92,6 +101,13 @@ Example:: )code" ADD_FILELINE) .set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs){ + const SoftmaxParam& param = nnvm::get(attrs.parsed); + return (param.use_length.value()) ? 
+ std::vector{"data", "length"} : + std::vector{"data"}; +}) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output"}; @@ -103,20 +119,27 @@ Example:: .set_attr("FInferStorageType", SoftmaxStorageType) #endif .set_attr("FGradient", SoftmaxFGradient{"_backward_softmax"}) +// .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FInferType", SoftmaxOpType) -.set_num_inputs(1) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); + return (param.use_length.value()) ? 2 : 1; + }) .set_num_outputs(1) -.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferShape", SoftmaxOpShape) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) .add_argument("data", "NDArray-or-Symbol", "The input array.") +.add_argument("length", "NDArray-or-Symbol", "The length array.") .add_arguments(SoftmaxParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_softmax) .set_num_inputs(SoftmaxGradOpNumInputs) -.set_num_outputs(1) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + return (softmax_use_length(attrs) ? 2 : 1); + }) .set_attr("FListInputNames", SoftmaxGradOpInputNames) .set_attr("FInferShape", SoftmaxGradOpShape) .set_attr("FInferType", SoftmaxGradOpType) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 749f0f2bed23..fea07f540624 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5196,6 +5196,39 @@ def check_dtypes_almost_equal(op_name, check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, 'float32', 'float64', 'float64') + +@with_seed() +def test_softmax_with_length(): + def np_softmax_with_length(data, length): + res = np.zeros(data.shape) + for i in range(length.shape[0]): + for j in range(length.shape[1]): + leng = int(length[i, j]) + res[i, 0:leng, j] = np_softmax(data[i, 0:leng, j]) + return res + + ndim = 3 + shape = rand_shape_nd(ndim, dim=10) + len_shape = list(shape) + del len_shape[1] + len_shape = tuple(len_shape) + for dtype in [np.float16, np.float32, np.float64]: + mx_data = rand_ndarray(shape, dtype=dtype) + np_data = mx_data.asnumpy() + np_length = np.random.randint(1, shape[1] + 1, len_shape) + mx_length = mx.nd.array(np_length, dtype=np.int32) + np_out = np_softmax_with_length(np_data, np_length) + data = mx.sym.Variable("data") + length = mx.sym.Variable("length") + mx_sym = mx.sym.softmax(data=data, length=length, use_length=True, axis=1) + location = {"data": mx_data, "length": mx_length} + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy") + check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)], + [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy") + + @with_seed() def test_pick(): def test_pick_helper(index_type=np.int32): @@ -8034,7 +8067,11 @@ def get_output_names_callback(name, arr): check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output']) sm_sym = mx.sym.softmax(data, name='softmax') - check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output']) + check_name(sm_sym, ['data', 'softmax_data', 'softmax_output']) + + length = mx.sym.Variable("length", shape=(10, 10, 10)) + sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax') + check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 
'softmax_output']) sa_sym = mx.sym.SoftmaxActivation(data, name='softmax') check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])
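
Usage note (not part of the patch): a minimal sketch of the new use_length
option, distilled from test_softmax_with_length above. The imperative call
below assumes mx.nd.softmax exposes the same data/length inputs and the
axis/use_length parameters that the symbol API exercises in the test; per
SoftmaxOpShape and MXNET_INT_TYPE_SWITCH, length must be an integer array
shaped like data with the softmax axis removed.

    import mxnet as mx
    import numpy as np

    # Reference behaviour (mirrors np_softmax_with_length in the test):
    # softmax over the first `n` entries along axis 1, zeros past the length.
    def softmax_with_length_ref(data, length):
        out = np.zeros_like(data)
        for i in range(data.shape[0]):
            for j in range(data.shape[2]):
                n = int(length[i, j])
                x = data[i, :n, j]
                e = np.exp(x - x.max())
                out[i, :n, j] = e / e.sum()
        return out

    data = np.random.uniform(-1, 1, (2, 5, 3)).astype(np.float32)
    # length has the data shape with the softmax axis (axis=1) removed,
    # and must use an integer dtype (int32 here, as in the test).
    length = np.random.randint(1, 6, size=(2, 3)).astype(np.int32)

    out = mx.nd.softmax(mx.nd.array(data), mx.nd.array(length, dtype=np.int32),
                        axis=1, use_length=True)
    assert np.allclose(out.asnumpy(), softmax_with_length_ref(data, length), atol=1e-5)

As the backward test expects, the gradient written to the length input is
all zeros (the backward kernel only fills the data gradient and zeroes the
length gradient), so only data receives a meaningful gradient.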