From 362891dadfc59d99daa31051fdeb5553850615e4 Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Tue, 12 Nov 2019 06:38:06 +0000
Subject: [PATCH] mixed precision binary op backward

---
 .../numpy/np_elemwise_broadcast_op.cc         |  17 ++-
 .../numpy/np_elemwise_broadcast_op.cu         |   4 +
 src/operator/numpy/np_elemwise_broadcast_op.h | 103 +++++++++++++++++-
 .../tensor/elemwise_binary_broadcast_op.h     |  26 +++++
 src/operator/tensor/elemwise_unary_op.h       |   4 +-
 tests/python/unittest/test_numpy_op.py        |  20 ++--
 6 files changed, 161 insertions(+), 13 deletions(-)

diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index a76e59d30dc6..acf0395123fc 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -147,7 +147,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
     "FCompute",
     NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
+                                                         mshadow_op::left>);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
 .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index a0a277df211f..d9499625e34d 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -64,6 +64,10 @@ NNVM_REGISTER_OP(_npi_multiply)
     NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
 #endif
 
+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<gpu, mshadow_op::right,
+                                                         mshadow_op::left>);
+
 NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1a4596fba91c..7a59b23fde46 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -381,11 +381,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu, typename LOP, typename ROP>
-void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
+void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
 
@@ -396,7 +398,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
     return;
   }
 
-  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  const TBlob& ograd = inputs[0];
+  const TBlob& lgrad = outputs[0];
+  const TBlob& rgrad = outputs[1];
+
+  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+    // If any of the inputs is a float, it's the same type as the output
+    // So 2 of the 3 tensors have the same data type
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    using namespace broadcast;
+    const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
+                                                     &new_lshape, &new_rshape, &new_oshape) != 0;
+
+    // Prepare all the temporary memory
+    size_t workspace_size_l = 0, workspace_size_r = 0;
+    TBlob temp_tblob;  // The TBlob for casted input data
+    TBlob temp_igrad;  // The TBlob for casted grad results
+    size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
+    Tensor<xpu, 1, char> workspace;
+
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
+      BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
+        workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
+            s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
+        workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
+            s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
+      });
+      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
+      size_t cast_tensor_size = tensor_size * sizeof(OType);
+      // Allocate the temporary memories now
+      Tensor<xpu, 1, char> temp_space =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(
+          Shape1(workspace_size + cast_tensor_size * 2), s);
+      // Tensor for temp_tblob
+      Tensor<xpu, 1, OType> temp_tblob_tensor(
+        reinterpret_cast<OType*>(temp_space.dptr_),
+        Shape1(tensor_size), s);
+      // Tensor for temp_igrad
+      Tensor<xpu, 1, OType> temp_igrad_tensor(
+        reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
+        Shape1(tensor_size), s);
+      temp_tblob =
+        TBlob(temp_tblob_tensor)
+          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
+      temp_igrad =
+        TBlob(temp_igrad_tensor)
+          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
+      if (temp_igrad.Size() != 0) {
+        Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
+      }
+      workspace =
+        Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
+    });
+    // Cast the input that does not have consistent type to temp_tblob
+    CastCompute<xpu>(
+      attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});
+    if (!need_bc) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+          attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
+      } else {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+          attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
+      }
+    } else {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
+              ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
+              workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      } else {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
+              ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
+              workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      }
+    }
+
+    // If both inputs are floating numbers, cast the igrad to the input that has
+    // the different data type
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
+      } else {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
+      }
+    }
+  } else {
+    // Case where both inputs are integer types, should not even do
+    // backward computation for this case.
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  }
 }
 
 }  // namespace op
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index b48ed389ba98..4f44d90886af 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -671,6 +671,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs);
 
+template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
+void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
+                                                   const std::vector<TBlob>& inputs,
+                                                   const std::vector<OpReqType>& req,
+                                                   const std::vector<TBlob>& outputs,
+                                                   mshadow::Tensor<xpu, 1, char>& workspace,
+                                                   const mxnet::TShape& new_lshape,
+                                                   const mxnet::TShape& new_rshape,
+                                                   const mxnet::TShape& new_oshape) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob lgrad = outputs[0].reshape(new_lshape);
+  const TBlob rgrad = outputs[1].reshape(new_rshape);
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob lhs = inputs[1].reshape(new_lshape);
+  const TBlob rhs = inputs[2].reshape(new_rshape);
+  if (ograd.Size() != 0) {
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
+                                                            ograd, lhs, rhs);
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
+                                                            ograd, lhs, rhs);
+  }
+}
+
 template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
 inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
                                              const std::vector<TBlob>& inputs,
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 27013dfb98ae..d976cfb263d8 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -454,8 +454,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
     MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
-      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
-          req[0] != kWriteInplace) {
+      if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
+          req[0] != kWriteInplace) && outputs[0].Size() != 0) {
         Assign(out, req[0], tcast<DstDType>(data));
       }
     });
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 9aabdfd4cabc..f61c61b8628e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1652,7 +1652,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 @with_seed()
 @use_np
 def test_np_mixed_precision_binary_funcs():
-    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
         class TestMixedBinary(HybridBlock):
             def __init__(self, func):
                 super(TestMixedBinary, self).__init__()
@@ -1686,13 +1688,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                        use_broadcast=False, equal_nan=True)
 
     funcs = {
-        'add': (-1.0, 1.0),
-        'subtract': (-1.0, 1.0),
-        'multiply': (-1.0, 1.0),
+        'add': (-1.0, 1.0, None, None),
+        'subtract': (-1.0, 1.0, None, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
+                     lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
     }
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
+                   ((3, 0), (3, 0)),
                    ((3, 1), (3, 0)),
                    ((0, 2), (1, 2)),
                    ((2, 3, 4), (3, 1)),
@@ -1702,16 +1706,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
     for func, func_data in funcs.items():
-        low, high = func_data
+        low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
             for type1, type2 in itertools.product(itypes, ftypes):
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)
 
             for type1, type2 in itertools.product(ftypes, ftypes):
                 if type1 == type2:
                     continue
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
 
 @with_seed()
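
Usage note: a minimal Python-level sketch of the behavior this patch enables, assuming an MXNet build with this change applied; the shapes and values below are illustrative, not taken from the patch. For y = a * b with mixed-precision inputs, the backward pass computes dL/da = ograd * b and dL/db = ograd * a in the promoted (wider) dtype, reduces over any broadcast axes, and casts the gradient of the narrower float input back to its own dtype:

    import mxnet as mx
    from mxnet import autograd, np, npx
    npx.set_np()

    a = np.random.uniform(-1.0, 1.0, (3, 2)).astype('float32')
    b = np.random.uniform(-1.0, 1.0, (3, 1)).astype('float16')
    a.attach_grad()
    b.attach_grad()
    with autograd.record():
        y = np.multiply(a, b)   # forward promotes to float32
    y.backward()                # ograd defaults to ones
    # a.grad == b broadcast to y's shape, kept in float32;
    # b.grad == a summed over the broadcast axis, cast back to float16
    print(a.grad.dtype, b.grad.dtype)  # float32 float16

Before this change, only the forward computation of the mixed-precision numpy binary ops was registered; calling backward on _npi_multiply with differing input dtypes hit the PrintErrorMessage path instead of the new _backward_npi_broadcast_mul kernel.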