From f9140bb5c64bd2eb37ae061cb380e63201f62cd2 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Sun, 3 Nov 2019 10:34:15 +0000 Subject: [PATCH 1/2] support mixed-precision binary operations --- src/common/utils.h | 6 +- src/operator/mshadow_op.h | 94 +++++ src/operator/mxnet_op.h | 14 +- .../numpy/np_elemwise_broadcast_op.cc | 100 +++++- .../numpy/np_elemwise_broadcast_op.cu | 40 ++- src/operator/numpy/np_elemwise_broadcast_op.h | 332 ++++++++++++++++++ src/operator/numpy/np_true_divide.cc | 2 +- src/operator/tensor/elemwise_binary_op.h | 2 + tests/python/unittest/test_numpy_op.py | 63 ++++ 9 files changed, 632 insertions(+), 21 deletions(-) create mode 100644 src/operator/numpy/np_elemwise_broadcast_op.h diff --git a/src/common/utils.h b/src/common/utils.h index b919cb301dff..d1a591aec2ac 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -842,7 +842,7 @@ inline bool is_float(const int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; } -inline int more_precise_type(const int type1, const int type2) { +inline int get_more_precise_type(const int type1, const int type2) { if (type1 == type2) return type1; if (is_float(type1) && is_float(type2)) { if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) { @@ -870,12 +870,12 @@ inline int more_precise_type(const int type1, const int type2) { return mshadow::kInt8; } -inline int np_binary_out_type(const int type1, const int type2) { +inline int np_binary_out_infer_type(const int type1, const int type2) { if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) { return mshadow::kInt32; } - return more_precise_type(type1, type2); + return get_more_precise_type(type1, type2); } } // namespace common diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 1ece97b0efd8..b8db165675a0 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -194,6 +194,100 @@ MXNET_BINARY_MATH_OP_NC(right, b); MXNET_BINARY_MATH_OP_NC(mul, a * b); +#ifndef _WIN32 +struct mixed_plus { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast(a) + b; + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast(a) + b; + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast(a) + b; + } +}; + +struct mixed_minus { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast(a) - b; + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast(a) - b; + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast(a) - b; + } +}; + +struct mixed_rminus { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return b - static_cast(a); + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return b - static_cast(a); + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE 
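// Double-output overload: the lower-precision operand a is cast up to double.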
static double Map(DType a, double b) { + return b - static_cast(a); + } +}; + +struct mixed_mul { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast(a) * b; + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast(a) * b; + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast(a) * b; + } +}; +#endif + MXNET_BINARY_MATH_OP_NC(div, a / b); MXNET_BINARY_MATH_OP_NC(plus, a + b); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 5d297a547c8f..b15117f9f83b 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -859,14 +859,17 @@ struct op_with_req { /*! \brief inputs are two tensors with a float output tensor */ template::value, int>::type = 0> + typename std::enable_if::value || + std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } /*! \brief inputs are two tensors with a double output tensor */ template::value, int>::type = 0> + typename std::enable_if::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } @@ -883,14 +886,17 @@ struct op_with_req { /*! \brief inputs are two tensors with a float output tensor */ template::value, int>::type = 0> + typename std::enable_if::value || + std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } /*! \brief inputs are two tensors with a double output tensor */ template::value, int>::type = 0> + typename std::enable_if::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index c206ad453ba6..c32820d27361 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,8 +23,7 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. 
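 *        When the two input dtypes differ, the registrations below dispatch to
 *        the mixed-precision kernels declared in np_elemwise_broadcast_op.h.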
*/ -#include "../tensor/elemwise_binary_broadcast_op.h" -#include "../tensor/elemwise_binary_scalar_op.h" +#include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { @@ -55,17 +54,102 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, .add_argument("data", "NDArray-or-Symbol", "source input") \ .add_argument("scalar", "float", "scalar input") +bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int ltype = in_attrs->at(0); + const int rtype = in_attrs->at(1); + if (ltype != -1 && rtype != -1 && (ltype != rtype)) { + // Only when both input types are known and not the same, we enter the mixed-precision mode + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype)); + } else { + return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs); + } + return true; +} -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add) -.set_attr("FCompute", BinaryBroadcastCompute) +#ifdef _WIN32 +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryMixedPrecisionType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") +#else +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector{"lhs", "rhs"}; \ + }) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferType", NumpyBinaryMixedPrecisionType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}, {1, 0}}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") +#endif + +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#endif .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract) -.set_attr("FCompute", BinaryBroadcastCompute) +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#endif .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply) -.set_attr("FCompute", BinaryBroadcastCompute) +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute) +#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); 
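For reference, the dtype-promotion rule that NumpyBinaryMixedPrecisionType applies above can be sketched as a standalone program (a minimal sketch with stand-in enum values; the real implementation lives in src/common/utils.h, and its integer-integer ladder is simplified to a placeholder here):

#include <cassert>

// Stand-in dtype flags mirroring mshadow's; the numeric values are illustrative only.
enum DTypeFlag { kFloat32, kFloat64, kFloat16, kUint8, kInt32, kInt8, kInt64 };

inline bool is_float(int t) {
  return t == kFloat32 || t == kFloat64 || t == kFloat16;
}

inline int get_more_precise_type(int t1, int t2) {
  if (t1 == t2) return t1;
  if (is_float(t1) && is_float(t2)) {   // float vs. float: the wider float wins
    if (t1 == kFloat64 || t2 == kFloat64) return kFloat64;
    if (t1 == kFloat32 || t2 == kFloat32) return kFloat32;
    return kFloat16;
  }
  if (is_float(t1)) return t1;          // float vs. integer: the float type wins
  if (is_float(t2)) return t2;
  return kInt64;  // placeholder: utils.h resolves integer-integer pairs by width
}

inline int np_binary_out_infer_type(int t1, int t2) {
  // The one special case: uint8 with int8 promotes to int32.
  if ((t1 == kUint8 && t2 == kInt8) || (t1 == kInt8 && t2 == kUint8)) return kInt32;
  return get_more_precise_type(t1, t2);
}

int main() {
  assert(np_binary_out_infer_type(kFloat16, kFloat32) == kFloat32);
  assert(np_binary_out_infer_type(kInt32, kFloat16) == kFloat16);
  assert(np_binary_out_infer_type(kUint8, kInt8) == kInt32);
  return 0;
}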
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a682ec989ea8..153ffd0048dd 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -22,20 +22,50 @@ * \file np_elemwise_broadcast_op.cu * \brief GPU Implementation of basic functions for elementwise binary broadcast operator. */ -#include "../tensor/elemwise_binary_broadcast_op.h" -#include "../tensor/elemwise_binary_scalar_op.h" + +#include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -.set_attr("FCompute", BinaryBroadcastCompute); +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#endif NNVM_REGISTER_OP(_npi_subtract) -.set_attr("FCompute", BinaryBroadcastCompute); +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#endif NNVM_REGISTER_OP(_npi_multiply) -.set_attr("FCompute", BinaryBroadcastCompute); +#ifndef _WIN32 +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#else +.set_attr( + "FCompute", + MixedBinaryBroadcastCompute); +#endif NNVM_REGISTER_OP(_npi_mod) .set_attr("FCompute", BinaryBroadcastCompute); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h new file mode 100644 index 000000000000..48a64bdf7ee2 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_elemwise_binary_op.h + * \brief + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ + +#include + +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + +#ifndef _WIN32 +template +void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream *s = ctx.get_stream(); + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size == 0) return; + + switch (lhs.type_flag_) { + case mshadow::kFloat32: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel, xpu>::Launch( + s, size, out.dptr(), rhs.dptr(), + lhs.dptr()); + }); + } else { + LOG(ERROR) << "Should not reach here!"; + } + break; + } + case mshadow::kFloat64: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel, xpu>::Launch( + s, size, out.dptr(), rhs.dptr(), + lhs.dptr()); + }); + } else if (rhs.type_flag_ == mshadow::kFloat32) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel, xpu>::Launch( + s, size, out.dptr(), rhs.dptr(), + lhs.dptr()); + }); + } else { + LOG(ERROR) << "Should not reach here!"; + } + break; + } + default: + { + LOG(ERROR) << "Not supported case of ..."; + break; + } + } + }); +} + +template +void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream *s = ctx.get_stream(); + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, FType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size == 0) return; + + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel, xpu>::Launch( + s, size, out.dptr(), rhs.dptr(), + lhs.dptr()); + }); + }); + }); +} + +template +void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedAllRealBinaryElemwiseCompute(ctx, lhs, rhs, out, req[0]); + } else { + MixedAllRealBinaryElemwiseCompute(ctx, rhs, lhs, out, req[0]); + } + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedIntRealBinaryElemwiseCompute(ctx, lhs, rhs, out, req[0]); + } else { + MixedIntRealBinaryElemwiseCompute(ctx, rhs, lhs, out, req[0]); + } + } else { + LOG(ERROR) << "not implemented yet..."; + } +} + +template +void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req, + const int ndim, 
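// the new_* shapes below are the compacted broadcast shapes
// produced by BinaryBroadcastShapeCompact in the caller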
+ const mxnet::TShape& new_oshape, + const mxnet::TShape& new_lshape, + const mxnet::TShape& new_rshape) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream *s = ctx.get_stream(); + + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + switch (lhs.type_flag_) { + case mshadow::kFloat32: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr(), lhs.dptr(), out.dptr()); + } else { + LOG(ERROR) << "Should not reach here!"; + } + break; + } + case mshadow::kFloat64: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr(), lhs.dptr(), out.dptr()); + } else if (rhs.type_flag_ == mshadow::kFloat32) { + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr(), lhs.dptr(), out.dptr()); + } else { + LOG(ERROR) << "Should not reach here!"; + } + break; + } + default: + { + LOG(ERROR) << "Not supported case of ..."; + break; + } + } + }); +} +#endif + +template +void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); + return; + } + +#ifndef _WIN32 + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); + } else { + mshadow::Stream *s = ctx.get_stream(); + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedAllRealBinaryBroadcastCompute( + ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape); + } else { + MixedAllRealBinaryBroadcastCompute( + ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape); + } + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_) + << "One of the input type should be the same as the output"; + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape, + rhs.dptr(), lhs.dptr(), out.dptr()); + }); + }); + } else { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
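// lhs holds the integral operand in this branch; OP::Map casts it up to the float output type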
oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + }); + } + }); + } else { + LOG(ERROR) << "not implemented yet..."; + } + } +#else + mshadow::Stream *s = ctx.get_stream(); + if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + // one is float, the other is bool + CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == rhs.type_flag_)) + << "This case out type should be same as the float type"; + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + } else { + LOG(ERROR) << "not implemented yet..."; + } +#endif +} + +template +void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + + const TBlob& lhs = inputs[1]; + const TBlob& rhs = inputs[2]; + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastBackwardUseIn(attrs, ctx, inputs, req, outputs); + return; + } + + LOG(ERROR) << "Binary operation with mixed input data types does not support backward yet..."; +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index d2135befef42..35df2c467e7b 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -31,7 +31,7 @@ namespace op { int TrueDivideOutType(int ltype, int rtype) { if (common::is_float(ltype) && common::is_float(rtype)) { // If both inputs are float, return the one with the higher precision - return common::more_precise_type(ltype, rtype); + return common::get_more_precise_type(ltype, rtype); } else if (common::is_float(ltype) || common::is_float(rtype)) { // If only one of the inputs is float, return that float type return (common::is_float(ltype)) ? ltype : rtype; diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6f444aed21fe..da088c1dcc39 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -99,11 +99,13 @@ class ElemwiseBinaryOp : public OpBase { return a1.var() == a2.var(); } + public: /*! \brief Minimum of three */ static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) { return a < b ? (a < c ? a : c) : (b < c ? 
b : c); } + private: template static void BackwardUseNone_(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index e6d12da23582..32cd5b10717e 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1650,6 +1650,69 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): check_binary_func(func, lshape, rshape, low, high, lgrads, rgrads, dtypes) +@with_seed() +@use_np +def test_np_mixed_precision_binary_funcs(): + def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype): + class TestMixedBinary(HybridBlock): + def __init__(self, func): + super(TestMixedBinary, self).__init__() + self._func = func + + def hybrid_forward(self, F, a, b, *args, **kwargs): + return getattr(F.np, self._func)(a, b) + + np_func = getattr(_np, func) + mx_func = TestMixedBinary(func) + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype) + np_test_x2 = _np.random.uniform(low, high, rshape).astype(rtype) + mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype) + mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype) + rtol = 1e-2 if ltype is np.float16 or rtype is np.float16 else 1e-3 + atol = 1e-4 if ltype is np.float16 or rtype is np.float16 else 1e-5 + for hybridize in [True, False]: + if hybridize: + mx_func.hybridize() + np_out = np_func(np_test_x1, np_test_x2) + with mx.autograd.record(): + y = mx_func(mx_test_x1, mx_test_x2) + assert y.shape == np_out.shape + assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol, + use_broadcast=False, equal_nan=True) + + np_out = getattr(_np, func)(np_test_x1, np_test_x2) + mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol, + use_broadcast=False, equal_nan=True) + + funcs = { + 'add': (-1.0, 1.0), + 'subtract': (-1.0, 1.0), + 'multiply': (-1.0, 1.0), + } + shape_pairs = [((3, 2), (3, 2)), + ((3, 2), (3, 1)), + ((3, 1), (3, 0)), + ((0, 2), (1, 2)), + ((2, 3, 4), (3, 1)), + ((2, 3), ()), + ((), (2, 3))] + itypes = [np.bool, np.int8, np.int32, np.int64] + ftypes = [np.float16, np.float32, np.float64] + for func, func_data in funcs.items(): + low, high = func_data + for lshape, rshape in shape_pairs: + for type1, type2 in itertools.product(itypes, ftypes): + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2) + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1) + + for type1, type2 in itertools.product(ftypes, ftypes): + if type1 == type2: + continue + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2) + + @with_seed() @use_np def test_npx_relu(): From b2d501f3b90a331c5678a2646506613f788cf192 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 5 Nov 2019 08:29:21 +0000 Subject: [PATCH 2/2] improvement for documentations and error messages --- python/mxnet/ndarray/numpy/_op.py | 40 +++++++++++++++++ python/mxnet/numpy/multiarray.py | 40 +++++++++++++++++ src/common/utils.h | 24 +++++++++++ src/operator/numpy/np_elemwise_broadcast_op.h | 43 +++++++++++-------- 4 files changed, 130 insertions(+), 17 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 9a106083a10e..30b3d2f4d90b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -523,6 +523,14 @@ def add(x1, x2, out=None, **kwargs): ------- 
add : ndarray or scalar The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out) @@ -549,6 +557,14 @@ def subtract(x1, x2, out=None, **kwargs): ------- subtract : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar, _npi.rsubtract_scalar, out) @@ -576,6 +592,14 @@ def multiply(x1, x2, out=None, **kwargs): out : ndarray or scalar The multiplication of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out) @@ -603,6 +627,14 @@ def divide(x1, x2, out=None, **kwargs): ------- out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. """ return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, _npi.rtrue_divide_scalar, out) @@ -633,6 +665,14 @@ def true_divide(x1, x2, out=None): ------- out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. 
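Examples
--------
>>> x = np.arange(5)
>>> np.true_divide(x, 4)
array([0.  , 0.25, 0.5 , 0.75, 1.  ])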
""" return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, _npi.rtrue_divide_scalar, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 0580e139df31..9439e751f1be 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -2402,6 +2402,14 @@ def add(x1, x2, out=None, **kwargs): add : ndarray or scalar The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.add(1.0, 4.0) @@ -2440,6 +2448,14 @@ def subtract(x1, x2, out=None, **kwargs): subtract : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.subtract(1.0, 4.0) @@ -2476,6 +2492,14 @@ def multiply(x1, x2, out=None, **kwargs): out : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.multiply(2.0, 4.0) @@ -2514,6 +2538,14 @@ def divide(x1, x2, out=None, **kwargs): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. + Examples -------- >>> np.true_divide(x, 4) @@ -2548,6 +2580,14 @@ def true_divide(x1, x2, out=None): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. 
+ Examples -------- >>> x = np.arange(5) diff --git a/src/common/utils.h b/src/common/utils.h index d1a591aec2ac..0e3e35430652 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector& ndstypes, return false; } +inline std::string dtype_string(const int dtype) { + switch (dtype) { + case mshadow::kFloat32: + return "float"; + case mshadow::kFloat64: + return "double"; + case mshadow::kFloat16: + return "half"; + case mshadow::kUint8: + return "unsigned char"; + case mshadow::kInt8: + return "char"; + case mshadow::kInt32: + return "int"; + case mshadow::kInt64: + return "long long"; + case mshadow::kBool: + return "bool"; + default: + LOG(FATAL) << "Unknown type enum " << dtype; + } + return "unknown"; +} + /*! \brief get string representation of dispatch_mode */ inline std::string dispatch_mode_string(const DispatchMode x) { switch (x) { diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index 48a64bdf7ee2..1d36c6ff881e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -20,12 +20,13 @@ /*! * Copyright (c) 2019 by Contributors * \file np_elemwise_binary_op.h - * \brief + * \brief Function definition of elemwise and broadcast operators */ #ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ #include +#include #include "../tensor/elemwise_binary_broadcast_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -33,9 +34,16 @@ namespace mxnet { namespace op { +inline void PrintErrorMessage(const std::string& name, const int dtype1, const int dtype2) { + LOG(FATAL) << "Operator " << name << " does not support combination of " + << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2) + << " yet..."; +} + #ifndef _WIN32 template -void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx, +void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, + const OpContext& ctx, const TBlob& lhs, const TBlob& rhs, const TBlob& out, @@ -61,7 +69,7 @@ void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx, lhs.dptr()); }); } else { - LOG(ERROR) << "Should not reach here!"; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); } break; } @@ -80,13 +88,13 @@ void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx, lhs.dptr()); }); } else { - LOG(ERROR) << "Should not reach here!"; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); } break; } default: { - LOG(ERROR) << "Not supported case of ..."; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); break; } } @@ -137,9 +145,9 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { - MixedAllRealBinaryElemwiseCompute(ctx, lhs, rhs, out, req[0]); + MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, lhs, rhs, out, req[0]); } else { - MixedAllRealBinaryElemwiseCompute(ctx, rhs, lhs, out, req[0]); + MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, rhs, lhs, out, req[0]); } } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { @@ -148,12 +156,13 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, MixedIntRealBinaryElemwiseCompute(ctx, rhs, lhs, out, req[0]); } } else { - LOG(ERROR) << "not implemented yet..."; + 
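// neither operand is a floating-point type: integer-integer mixed-precision ops are not supported yet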
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } template -void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx, +void MixedAllRealBinaryBroadcastCompute(const std::string& op_name, + const OpContext& ctx, const TBlob& lhs, const TBlob& rhs, const TBlob& out, @@ -180,7 +189,7 @@ void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx, template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, rhs.dptr(), lhs.dptr(), out.dptr()); } else { - LOG(ERROR) << "Should not reach here!"; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); } break; } @@ -195,13 +204,13 @@ void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx, template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, rhs.dptr(), lhs.dptr(), out.dptr()); } else { - LOG(ERROR) << "Should not reach here!"; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); } break; } default: { - LOG(ERROR) << "Not supported case of ..."; + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); break; } } @@ -242,10 +251,10 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { MixedAllRealBinaryBroadcastCompute( - ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape); + attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape); } else { MixedAllRealBinaryBroadcastCompute( - ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape); + attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape); } } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_) @@ -273,7 +282,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } }); } else { - LOG(ERROR) << "not implemented yet..."; + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } #else @@ -303,7 +312,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); } } else { - LOG(ERROR) << "not implemented yet..."; + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } #endif } @@ -324,7 +333,7 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, return; } - LOG(ERROR) << "Binary operation with mixed input data types does not support backward yet..."; + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } // namespace op
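Taken together, the two patches enable forward computation like the following (a minimal, forward-only sketch distilled from test_np_mixed_precision_binary_funcs above; the dtype comments follow the promotion rules documented in the Notes sections, and backward for mixed inputs remains unsupported, as PrintErrorMessage in MixedBinaryBackwardUseIn indicates):

from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible semantics

a = np.ones((3, 2), dtype='float16')
b = np.ones((3, 1), dtype='float32')  # broadcasts against a
c = np.ones((3, 2), dtype='int32')

print((a + b).dtype)  # float32: with two floats, the more precise type wins
print((a * c).dtype)  # float16: float vs. integer, the float type wins
print((a - b).shape)  # (3, 2): ordinary broadcasting still applies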