diff --git a/src/common/utils.h b/src/common/utils.h
index 2b4b821a1835..b919cb301dff 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
 }
 
+inline int more_precise_type(const int type1, const int type2) {
+  if (type1 == type2) return type1;
+  if (is_float(type1) && is_float(type2)) {
+    if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
+      return mshadow::kFloat64;
+    }
+    if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
+      return mshadow::kFloat32;
+    }
+    return mshadow::kFloat16;
+  } else if (is_float(type1) || is_float(type2)) {
+    return is_float(type1) ? type1 : type2;
+  }
+  if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
+    return mshadow::kInt64;
+  }
+  if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
+    return mshadow::kInt32;
+  }
+  CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
+          (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
+    << "1 is UInt8 and 1 is Int8 should not get here";
+  if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
+    return mshadow::kUint8;
+  }
+  return mshadow::kInt8;
+}
+
+inline int np_binary_out_type(const int type1, const int type2) {
+  if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
+      (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
+    return mshadow::kInt32;
+  }
+  return more_precise_type(type1, type2);
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index d73fa1be54a4..3d81cfc0d967 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -134,8 +134,7 @@ class LeakyReLUOp : public Operator {
         mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
         mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
         mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
-                                                           mshadow_op::xelu>, xpu>::
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::xelu>, xpu>::
         template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
                           in_data[leakyrelu::kData].dptr<DType>(),
                           in_data[leakyrelu::kGamma].dptr<DType>(),
                           out_data[leakyrelu::kOut].dptr<DType>());
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index c5a2b1308c73..1ece97b0efd8 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -132,6 +132,26 @@ struct true_divide : public mxnet_op::tunable {
   MSHADOW_XINLINE static float Map(DType a, DType b) {
     return static_cast<float>(a) / static_cast<float>(b);
   }
+
+#ifndef _WIN32
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) / b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) / b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) / b;
+  }
+#endif
 };
 
 struct rtrue_divide : public mxnet_op::tunable {
@@ -146,6 +166,26 @@ struct rtrue_divide : public mxnet_op::tunable {
   MSHADOW_XINLINE static float Map(DType a, DType b) {
     return static_cast<float>(b) / static_cast<float>(a);
   }
+
+#ifndef _WIN32
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return b / static_cast<mshadow::half::half_t>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return b / static_cast<float>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return b / static_cast<double>(a);
+  }
+#endif
 };
 
 MXNET_BINARY_MATH_OP_NC(left, a);
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 91478660a123..5d297a547c8f 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -471,6 +471,69 @@ struct AccType {
       {__VA_ARGS__}                                \
     }                                              \
     break;                                         \
+  case mshadow::kBool:                             \
+    {                                              \
+      typedef bool DType;                          \
+      {__VA_ARGS__}                                \
+    }                                              \
+    break;                                         \
   default:                                         \
     LOG(FATAL) << "Unknown type enum " << type;    \
   }
+
+#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\
+  switch (type) {                                  \
+  case mshadow::kFloat32:                          \
+    {                                              \
+      typedef float DType;                         \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float32";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kFloat64:                          \
+    {                                              \
+      typedef double DType;                        \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float64";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kFloat16:                          \
+    {                                              \
+      typedef mshadow::half::half_t DType;         \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float16";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kUint8:                            \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not uint8";    \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt8:                             \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not int8";     \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt32:                            \
+    {                                              \
+      typedef int32_t DType;                       \
+      {__VA_ARGS__}                                \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt64:                            \
+    {                                              \
+      typedef int64_t DType;                       \
+      {__VA_ARGS__}                                \
+    }                                              \
+    break;                                         \
+  case mshadow::kBool:                             \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not bool";     \
+    }                                              \
+    break;                                         \
+  default:                                         \
+    LOG(FATAL) << "Unknown type enum " << type;    \
+  }
@@ -783,6 +846,56 @@ struct op_with_req {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
   }
 
+#ifndef _WIN32
+  /*! \brief inputs are two tensors with a half_t output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  mshadow::half::half_t *out,
+                                  const DType *lhs,
+                                  const mshadow::half::half_t *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
+  /*! \brief inputs are two tensors with a float output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
+  /*! \brief inputs are two tensors with a double output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
+  /*! \brief inputs are a tensor and a scalar value with a half_t output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  mshadow::half::half_t *out,
+                                  const DType *lhs,
+                                  const mshadow::half::half_t value) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
+  }
+
+  /*! \brief inputs are a tensor and a scalar value with a float output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
+  }
+
+  /*! \brief inputs are a tensor and a scalar value with a double output tensor */
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
+  }
+#endif
+
   /*! \brief inputs are two tensors with a float output tensor */
   template<typename DType,
            typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 61239d33800c..1eff5cd8591d 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -394,8 +394,7 @@ class DropoutOp {
             mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
             mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
             mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
-                                                               mshadow_op::mul>, xpu>::
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
             template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
                               lstride, rstride, oshape, in.dptr<DType>(),
                               mask.dptr<DType>(), out.dptr<DType>());
@@ -463,8 +462,7 @@ class DropoutOp {
             mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
             mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
             mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
-                                                               mshadow_op::mul>, xpu>::
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
             template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
                               grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
           });
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 601a0526650c..89da570c133b 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
         << "Mask needs to be provided when using softmax with use_length=True.";
     type = inputs[1].type_flag_;
   }
-  MXNET_INT_TYPE_SWITCH(type, IType, {
+  MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
     IType* mask_ptr = nullptr;
     if (param.use_length.value()) {
       mask_ptr = inputs[1].dptr<IType>();
@@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
                         const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
   if (softmax_use_length(attrs)) {
-    MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+    MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, {
       if (req[1] != kNullOp) {
         mxnet_op::Kernel<set_zero, xpu>::Launch(
           ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
@@ -856,7 +856,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MXNET_INT_TYPE_SWITCH(itype, IType, {
+        MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, {
           IType * length_ptr = nullptr;
           if (softmax_use_length(attrs)) {
             length_ptr = inputs[2].dptr<IType>();
diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h
index cc74e19aef8f..0bc60a08803e 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -43,30 +43,42 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(outputs.size(), 1U);
   if (req[0] == kNullOp || outputs[0].Size() == 0U) return;
   using namespace mshadow;
+  using namespace mxnet_op;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const double alpha = nnvm::get<double>(attrs.parsed);
-  if (common::is_float(inputs[0].type_flag_)) {
+  const TBlob& data = inputs[0];
+  const TBlob& out = outputs[0];
+  if (out.type_flag_ == data.type_flag_) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
-            s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
+        Kernel<op_with_req<OP, Req>, xpu>::Launch(
+            s, data.Size(), out.dptr<DType>(), data.dptr<DType>(), DType(alpha));
       });
     });
   } else {
+#ifndef _WIN32
     CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output "
                                                  "when input's dtype is "
                                               << type_string(inputs[0].type_flag_);
     MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
-            s, inputs[0].Size(), outputs[0].dptr<float>(), inputs[0].dptr<DType>(), DType(alpha));
+        Kernel<op_with_req<OP, Req>, xpu>::Launch(
+            s, data.Size(), out.dptr<float>(), data.dptr<DType>(),
+            static_cast<float>(alpha));
       });
     });
+#else
+    Tensor<xpu, 1, float> temp_tensor =
+        ctx.requested[0].get_space_typed<xpu, 1, float>(mshadow::Shape1(data.Size()), s);
+    TBlob temp_tblob(temp_tensor);
+    CastCompute<xpu>(attrs, ctx, {data}, {kWriteTo}, {temp_tblob});
+    TrueDivideScalarCompute<xpu, OP>(attrs, ctx, {temp_tblob}, req, outputs);
+#endif
   }
 }
 
-template<typename xpu, typename OP>
+template<typename xpu>
 void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const std::vector<TBlob> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<TBlob> &outputs) {
@@ -77,66 +89,254 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
-  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-    if (common::is_float(inputs[0].type_flag_)) {
-      MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-        Kernel<op_with_req<OP, Req>, xpu>::Launch(s, outputs[0].Size(),
-                                                  outputs[0].dptr<DType>(),
-                                                  inputs[0].dptr<DType>(),
-                                                  inputs[1].dptr<DType>());
-      });
-    } else {
-      CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output "
-                                                   "when input's dtype is "
-                                                << type_string(inputs[0].type_flag_);
-      MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-        Kernel<op_with_req<OP, Req>, xpu>::Launch(s, outputs[0].Size(),
-                                                  outputs[0].dptr<float>(),
-                                                  inputs[0].dptr<DType>(),
-                                                  inputs[1].dptr<DType>());
-      });
-    }
-  });
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    // Case when types of the 2 input tensors are the same
+    if (common::is_float(lhs.type_flag_)) {
+      // If both are the same floats, normal launch
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, {
+          Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
+              s, out.Size(), out.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>());
+        });
+      });
+    } else {
+      // If both are the same integers, output is float32
+      CHECK_EQ(out.type_flag_, kFloat32) << "true_divide only supports float32 output "
+                                            "when input's dtype is "
+                                         << type_string(lhs.type_flag_);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, {
+          Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
+              s, out.Size(), out.dptr<float>(), lhs.dptr<DType>(), rhs.dptr<DType>());
+        });
+      });
+    }
+  } else {
+#ifndef _WIN32
+    // Non-windows case: no usage of temporary space
+    // Case when types of the 2 input tensors are different
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      // both lhs and rhs are float types, output type is the more precise one
+      LOG(ERROR) << "not implemented yet...";
+    } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+      // one is float type, the other is integer type, the output type should be the same as float
+      CHECK_EQ(out.type_flag_,
+               common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
+        << "This case out type should be same as the float type";
+      if (common::is_float(lhs.type_flag_)) {
+        // lhs is the float one
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+            MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
+              Kernel<op_with_req<mshadow_op::rtrue_divide, Req>, xpu>::Launch(
+                  s, out.Size(), out.dptr<LType>(), rhs.dptr<RType>(), lhs.dptr<LType>());
+            });
+          });
+        });
+      } else {
+        // rhs is the float one
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
+            MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+              Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
+                  s, out.Size(), out.dptr<RType>(), lhs.dptr<LType>(), rhs.dptr<RType>());
+            });
+          });
+        });
+      }
+    } else {
+      // lhs is integer type, rhs is integer type, output type should be float
+      LOG(ERROR) << "not implemented yet...";
+    }
+#else
+    // Windows case: using temp space for casting the type
+    // Case when types of the 2 input tensors are different
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      // both lhs and rhs are float types, output type is the more precise one
+      LOG(ERROR) << "not implemented yet...";
+    } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+      // lhs is float type, rhs is integer type, the output type should be the same as lhs
+      CHECK_EQ(out.type_flag_,
+               common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
+        << "This case out type should be same as the float type";
+      TBlob temp_tblob;
+      if (common::is_float(lhs.type_flag_)) {
+        // lhs is the float one
+        MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+          Tensor<xpu, 1, LType> temp_tensor =
+              ctx.requested[0].get_space_typed<xpu, 1, LType>(mshadow::Shape1(rhs.Size()), s);
+          temp_tblob = TBlob(temp_tensor);
+        });
+        CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
+        TrueDivideElemwiseCompute<xpu>(
+            attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
+      } else {
+        // rhs is the float one
+        MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+          Tensor<xpu, 1, RType> temp_tensor =
+              ctx.requested[0].get_space_typed<xpu, 1, RType>(mshadow::Shape1(lhs.Size()), s);
+          temp_tblob = TBlob(temp_tensor);
+        });
+        CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
+        TrueDivideElemwiseCompute<xpu>(
+            attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
+      }
+    } else {
+      // lhs is integer type, rhs is integer type, output type should be float
+      LOG(ERROR) << "not implemented yet...";
+    }
+#endif
+  }
 }
 
-template<typename xpu, typename OP>
+template<typename xpu>
 void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
   if (outputs[0].shape_.Size() == 0U) return;
+  CHECK_EQ(inputs.size(), 2U);
   mxnet::TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
   if (!ndim) {
-    TrueDivideElemwiseCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    TrueDivideElemwiseCompute<xpu>(attrs, ctx, inputs, req, outputs);
   } else {
     if (req[0] == kNullOp) return;
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    const TBlob& lhs = inputs[0];
+    const TBlob& rhs = inputs[1];
+    const TBlob& out = outputs[0];
+#ifndef _WIN32
     BROADCAST_NDIM_SWITCH(ndim, NDim, {
       mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
-      mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
-      mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-      if (common::is_float(inputs[0].type_flag_)) {
-        MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType, OP>, xpu>::
-          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-                            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-                            outputs[0].dptr<DType>());
-        });
-      } else {
-        CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32)
-          << "true_divide only supports float32 output when input's dtype is "
-          << type_string(inputs[0].type_flag_);
-        MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, float, OP>, xpu>::
-          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-                            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-                            outputs[0].dptr<float>());
-        });
-      }
+      mshadow::Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+      mshadow::Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
+      if (lhs.type_flag_ == rhs.type_flag_) {
+        // When the both inputs have the same data types
+        if (common::is_float(lhs.type_flag_)) {
+          // If both inputs are the same float types, output is the same float type
+          MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, {
+            Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                              lhs.dptr<DType>(), rhs.dptr<DType>(), out.dptr<DType>());
+          });
+        } else {
+          CHECK_EQ(out.type_flag_, mshadow::kFloat32)
+            << "true_divide only supports float32 output when input's dtype is "
+            << type_string(lhs.type_flag_);
+          MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, {
+            // If both inputs are the same integer types, output is float type
+            Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                              lhs.dptr<DType>(), rhs.dptr<DType>(), out.dptr<float>());
+          });
+        }
+      } else {
+        if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+          // lhs and rhs have different float types, the output is the more precise one
+          LOG(ERROR) << "not implemented yet...";
+        } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+          // one of lhs and rhs is float, the output is the same type as the float one
+          if (common::is_float(lhs.type_flag_)) {
+            // lhs is float type, output will be the same float type
+            CHECK_EQ(lhs.type_flag_, out.type_flag_)
+              << "lhs should have the same type as out, infer type broken?";
+            MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+              MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
+                Kernel<binary_broadcast_kernel<NDim, mshadow_op::rtrue_divide>, xpu>::
+                template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape,
+                                  rhs.dptr<RType>(), lhs.dptr<LType>(), out.dptr<LType>());
+              });
+            });
+          } else {
+            // rhs is float type, output will be the same float type
+            CHECK_EQ(rhs.type_flag_, out.type_flag_)
+              << "rhs should have the same type as out, infer type broken?";
+            MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
+              MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+                Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>, xpu>::
+                template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                                  lhs.dptr<LType>(), rhs.dptr<RType>(), out.dptr<RType>());
+              });
+            });
+          }
+        } else {
+          // lhs and rhs have different integer types, the output is float type
+          LOG(ERROR) << "not implemented yet...";
+        }
+      }
     });
+#else
+    if (lhs.type_flag_ == rhs.type_flag_) {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
+        // When the both inputs have the same data types
+        if (common::is_float(lhs.type_flag_)) {
+          // If both inputs are the same float types, output is the same float type
+          MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, {
+            Kernel<binary_broadcast_kernel<NDim, mshadow_op::true_divide>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                              lhs.dptr<DType>(), rhs.dptr<DType>(), out.dptr<DType>());
+          });
+        } else {
+          CHECK_EQ(out.type_flag_, mshadow::kFloat32)
" + << type_string(lhs.type_flag_); + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { + // If both inputs are the same integer types, output is float type + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + } + }); + } else { + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + // lhs and rhs have different float types, the output is the more precise one + LOG(ERROR) << "not implemented yet..."; + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // one of lhs and rhs is float, the output is the same type as the float one + TBlob temp_tblob; + if (common::is_float(lhs.type_flag_)) { + // lhs is float type, output will be the same float type + CHECK_EQ(lhs.type_flag_, out.type_flag_) + << "lhs should have the same type as out, infer type broken?"; + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + TrueDivideBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + // rhs is float type, output will be the same float type + CHECK_EQ(rhs.type_flag_, out.type_flag_) + << "rhs should have the same type as out, infer type broken?"; + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + TrueDivideBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + } else { + // lhs and rhs have different integer types, the output is float type + LOG(ERROR) << "not implemented yet..."; + } + } +#endif } } diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 5a4634c3ff8c..d2135befef42 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -28,26 +28,35 @@ namespace mxnet { namespace op { +int TrueDivideOutType(int ltype, int rtype) { + if (common::is_float(ltype) && common::is_float(rtype)) { + // If both inputs are float, return the one with the higher precision + return common::more_precise_type(ltype, rtype); + } else if (common::is_float(ltype) || common::is_float(rtype)) { + // If only one of the inputs is float, return that float type + return (common::is_float(ltype)) ? ltype : rtype; + } + // If neither of the inputs is float, return the default float32 type + return mshadow::kFloat32; +} + template bool TrueDivideType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(num_inputs)); + CHECK_GT(in_attrs->size(), 0U); CHECK_EQ(out_attrs->size(), 1U); + for (const int dtype : *in_attrs) { if (dtype == -1) return false; } - if (num_inputs == 2) { - const int lhs_dtype = in_attrs->at(0); - const int rhs_dtype = in_attrs->at(1); - CHECK_EQ(lhs_dtype, rhs_dtype) - << "true_divide currently only supports same dtype for dividend and divisor"; - } - if (common::is_float(in_attrs->at(0))) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - } else { - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); - } + + const int lhs_dtype = in_attrs->at(0); + const int rhs_dtype = (num_inputs == 2) ? + in_attrs->at(1) : + (common::is_float(lhs_dtype) ? 
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, TrueDivideOutType(lhs_dtype, rhs_dtype));
   return true;
 }
 
@@ -64,7 +73,13 @@ NNVM_REGISTER_OP(_npi_true_divide)
     [](const NodeAttrs& attrs){
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
     })
-.set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<cpu, mshadow_op::true_divide>)
+#ifdef _WIN32
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+#endif
+.set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
 .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
 .add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
@@ -81,6 +96,12 @@ NNVM_REGISTER_OP(_npi_true_divide_scalar)
     [](const NodeAttrs& attrs) {
      return std::vector<std::pair<int, int> >{{0, 0}};
    })
+#ifdef _WIN32
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+#endif
 .set_attr<FCompute>("FCompute", TrueDivideScalarCompute<cpu, mshadow_op::true_divide>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"})
 .add_argument("data", "NDArray-or-Symbol", "source input")
@@ -98,6 +119,12 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
     [](const NodeAttrs& attrs) {
      return std::vector<std::pair<int, int> >{{0, 0}};
    })
+#ifdef _WIN32
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+#endif
 .set_attr<FCompute>("FCompute", TrueDivideScalarCompute<cpu, mshadow_op::rtrue_divide>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"})
 .add_argument("data", "NDArray-or-Symbol", "source input")
diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu
index c026d689233d..7211f4a0a006 100644
--- a/src/operator/numpy/np_true_divide.cu
+++ b/src/operator/numpy/np_true_divide.cu
@@ -29,7 +29,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_true_divide)
-.set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<gpu, mshadow_op::true_divide>);
+.set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<gpu>);
 
 NNVM_REGISTER_OP(_npi_true_divide_scalar)
 .set_attr<FCompute>("FCompute", TrueDivideScalarCompute<gpu, mshadow_op::true_divide>);
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 3d3bcfacbd05..ad06df8d92be 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -187,9 +187,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet::TShape& rshape,
 }
 
 namespace mxnet_op {
-template<int ndim, typename IType, typename DType, typename OP>
+template<int ndim, typename OP>
 struct binary_broadcast_kernel {
   /*! \brief Map function for binary_broadcast_kernel */
+  template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape<ndim> &lstride, const Shape<ndim> &rstride,
                                   const Shape<ndim> &oshape, IType *lhs, IType *rhs,
                                   DType *out) {
@@ -208,6 +209,7 @@ struct binary_broadcast_kernel {
   }
 
   /*! \brief Map function for binary_broadcast_kernel */
+  template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape<ndim> &lstride, const Shape<ndim> &rstride,
                                   const Shape<ndim> &oshape, IType lhs, IType *rhs,
                                   DType *out) {
@@ -224,6 +226,49 @@ struct binary_broadcast_kernel {
       KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx]));
     }
   }
+
+#ifndef _WIN32
+  /*! \brief Map function for binary_broadcast_kernel */
+  template<typename IType, typename DType,
+           typename std::enable_if<!std::is_same<IType, DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
+                                  const Shape<ndim> &lstride, const Shape<ndim> &rstride,
+                                  const Shape<ndim> &oshape, IType *lhs, DType *rhs,
+                                  DType *out) {
+    Shape<ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (index_t i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // When tuning, don't actually run the op, since it's not going to be tuned against
+      // the actual op we'll eventually be using
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+
+  /*! \brief Map function for binary_broadcast_kernel */
+  template<typename IType, typename DType,
+           typename std::enable_if<!std::is_same<IType, DType>::value &&
+                                   !std::is_pointer<IType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
+                                  const Shape<ndim> &lstride, const Shape<ndim> &rstride,
+                                  const Shape<ndim> &oshape, IType lhs, DType *rhs,
+                                  DType *out) {
+    Shape<ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (index_t i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // When tuning, don't actually run the op, since it's not going to be tuned against
+      // the actual op we'll eventually be using
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx]));
+    }
+  }
+#endif
 };
 
 template<typename OP>
@@ -307,7 +352,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
       mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
       mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
       mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-      mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType, OP>, xpu>::
+      mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
       template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
                         inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>());
     });
@@ -336,7 +381,7 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
       mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
       mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
       mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
-      mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, bool, OP>, xpu>::
+      mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
      template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
                        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
                        outputs[0].dptr<bool>());
@@ -444,11 +489,11 @@ void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx,
       Shape<NDim> lstride = calc_stride(new_csrshape.get<NDim>());
       Shape<NDim> rstride = calc_stride(new_dnsshape.get<NDim>());
       if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
-        Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType, mshadow_op::plus>, xpu>::
+        Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::plus>, xpu>::
         template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
                           DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
       } else {
-        Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType, OP>, xpu>::
+        Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
         template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
                           DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
       }
@@ -658,7 +703,7 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs) {                                       \
       return std::vector<std::string>{"lhs", "rhs"};                   \
     })                                                                 \
-  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)   \
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)    \
   .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)        \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption",                    \
     [](const NodeAttrs& attrs){                                        \
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index 02b005eed995..834bbdbfc3d1 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -256,7 +256,7 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
-    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
           s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3b4b4b6491fc..e6d12da23582 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1940,7 +1940,7 @@ def get_new_shape(shape, axis):
 
         with mx.autograd.record():
             y = test_concat(a, b, c, d)
-        
+
         assert y.shape == expected_ret.shape
         assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
 
@@ -2933,7 +2933,7 @@ def check_cholesky(L, data_np):
         test_cholesky = TestCholesky()
         if hybridize:
             test_cholesky.hybridize()
-        
+
         # Numerical issue:
         # When backpropagating through Cholesky decomposition, we need to compute the inverse
         # of L according to dA = 0.5 * L**(-T) * copyLTU(L**T * dL) * L**(-1) where A = LL^T.
@@ -3847,12 +3847,14 @@ def test_np_true_divide():
         [(2, 3, 1), (1, 4)],
         [(2, 1, 4, 1), (3, 1, 5)],
     ]
-    dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
+    dtypes = [np.bool, np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
+    itypes = [np.bool, np.int8, np.uint8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
     for shape_pair, dtype in itertools.product(shapes, dtypes):
         a = np.random.uniform(3, 50, size=shape_pair[0]).astype(dtype)
         b = np.random.uniform(3, 50, size=shape_pair[-1]).astype(dtype)
         out_mx = a / b
-        if _np.issubdtype(dtype, _np.integer):
+        if _np.issubdtype(dtype, _np.integer) or (dtype is np.bool):
             assert out_mx.dtype == np.float32
         else:
             assert out_mx.dtype == dtype
@@ -3868,6 +3870,20 @@ def test_np_true_divide():
             out_np = _np.true_divide(val, a.asnumpy())
             assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
 
+    for shape_pair, itype, ftype in itertools.product(shapes, itypes, ftypes):
+        i_ = np.random.uniform(3, 50, size=shape_pair[0]).astype(itype)
+        f_ = np.random.uniform(3, 50, size=shape_pair[-1]).astype(ftype)
+
+        out_mx = i_ / f_
+        assert out_mx.dtype == ftype
+        out_np = _np.true_divide(i_.asnumpy(), f_.asnumpy())
+        assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
+
+        out_mx = f_ / i_
+        assert out_mx.dtype == ftype
+        out_np = _np.true_divide(f_.asnumpy(), i_.asnumpy())
+        assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
+
 
 @with_seed()
 @use_np
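Editor's note, not part of the patch: a minimal usage sketch of the dtype promotion this change implements (integer/integer division promotes to float32, integer/float division follows the float operand), assuming an MXNet build that includes it; array values and shapes below are arbitrary illustrations.

from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible semantics

a = np.array([1, 2, 3], dtype=np.int32)          # integer operand
b = np.array([4.0, 5.0, 6.0], dtype=np.float64)  # float operand

print((a / a).dtype)  # int32 / int32   -> float32 (default float type)
print((a / b).dtype)  # int32 / float64 -> float64 (dtype of the float operand)
print((a / 2).dtype)  # int32 / scalar  -> float32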