
mixed precision binary op backward
haojin2 committed Nov 12, 2019
1 parent 02f4f05 commit 362891d
Showing 6 changed files with 161 additions and 13 deletions.
17 changes: 16 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -147,7 +147,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
"FCompute<cpu>",
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
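For context, switching the FGradient from _backward_broadcast_mul to the newly registered _backward_npi_broadcast_mul is what gives np.multiply a gradient when one operand is floating point and the other is integral or boolean. A minimal usage sketch of the behavior this enables, assuming the standard mxnet.numpy / autograd front end (not part of this diff):

import mxnet as mx
from mxnet import autograd, np, npx
npx.set_np()

a = np.ones((3, 2), dtype='float32')   # float operand: the only one that receives a gradient
b = np.ones((3, 2), dtype='int32')     # integer operand: no gradient is defined for it
a.attach_grad()
with autograd.record():
    y = np.multiply(a, b)              # forward takes the mixed-precision _npi_multiply path
y.backward()                           # backward runs _backward_npi_broadcast_mul
print(a.grad)                          # d(a*b)/da == b, cast to float32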
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -64,6 +64,10 @@ NNVM_REGISTER_OP(_npi_multiply)
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::right,
mshadow_op::left>);

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);

103 changes: 101 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -381,11 +381,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
}

template<typename xpu, typename LOP, typename ROP>
void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);

@@ -396,7 +398,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
return;
}

PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
const TBlob& ograd = inputs[0];
const TBlob& lgrad = outputs[0];
const TBlob& rgrad = outputs[1];

if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// If any of the inputs is a float, it's the same type as the output
// So 2 of the 3 tensors have the same data type
Stream<xpu> *s = ctx.get_stream<xpu>();
mxnet::TShape new_lshape, new_rshape, new_oshape;
using namespace broadcast;
const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
&new_lshape, &new_rshape, &new_oshape) != 0;

// Prepare all the temporary memory
size_t workspace_size_l = 0, workspace_size_r = 0;
TBlob temp_tblob; // The TBlob for casted input data
TBlob temp_igrad; // The TBlob for casted grad results
size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
Tensor<xpu, 1, char> workspace;

MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
});
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
size_t cast_tensor_size = tensor_size * sizeof(OType);
// Allocate the temporary memories now
Tensor<xpu, 1, char> temp_space =
ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(workspace_size + cast_tensor_size * 2), s);
// Tensor for temp_tblob
Tensor<xpu, 1, OType> temp_tblob_tensor(
reinterpret_cast<OType*>(temp_space.dptr_),
Shape1(tensor_size), s);
// Tensor for temp_igrad
Tensor<xpu, 1, OType> temp_igrad_tensor(
reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
Shape1(tensor_size), s);
temp_tblob =
TBlob(temp_tblob_tensor)
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
temp_igrad =
TBlob(temp_igrad_tensor)
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
if (temp_igrad.Size() != 0) {
Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
}
workspace =
Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
});
// Cast the input that does not have consistent type to temp_tblob
CastCompute<xpu>(
attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});
if (!need_bc) {
if (lhs.type_flag_ != ograd.type_flag_) {
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
} else {
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
}
} else {
if (lhs.type_flag_ != ograd.type_flag_) {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
workspace, new_lshape, new_rshape, new_oshape);
});
});
} else {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
workspace, new_lshape, new_rshape, new_oshape);
});
});
}
}

// If both inputs are floating numbers, cast the igrad to the input that has
// the different data type
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
if (lhs.type_flag_ != ograd.type_flag_) {
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
} else {
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
}
}
} else {
// Case where both inputs are integer types, should not even do
// backward computation for this case.
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
}

} // namespace op
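A rough plain-NumPy reference for what NumpyBinaryBackwardUseIn does on the float path of multiply (LOP = mshadow_op::right, ROP = mshadow_op::left): cast the mismatched operand to the output dtype, form the elementwise gradients, sum them back over any broadcast axes, and keep gradients only for floating-point operands. An illustrative sketch, not the kernel itself:

import numpy as np

def reduce_to_shape(grad, shape):
    # Sum over the axes that were broadcast so the gradient matches the operand's shape,
    # standing in for the workspace-backed Reduce calls above.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(shape):
        if size == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def mixed_mul_backward(ograd, lhs, rhs):
    # The kernel casts only the mismatched operand into temp_tblob; casting both here is harmless.
    lhs_f = lhs.astype(ograd.dtype)
    rhs_f = rhs.astype(ograd.dtype)
    lgrad = reduce_to_shape(ograd * rhs_f, lhs.shape)   # LOP picks the right operand
    rgrad = reduce_to_shape(ograd * lhs_f, rhs.shape)   # ROP picks the left operand
    # Integer operands get no gradient (the temp_igrad result is simply dropped).
    return (lgrad if np.issubdtype(lhs.dtype, np.floating) else None,
            rgrad if np.issubdtype(rhs.dtype, np.floating) else None)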
26 changes: 26 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -671,6 +671,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
mshadow::Tensor<xpu, 1, char>& workspace,
const mxnet::TShape& new_lshape,
const mxnet::TShape& new_rshape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
if (ograd.Size() != 0) {
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
ograd, lhs, rhs);
}
}

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
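Concretely, each Reduce<red::sum, ..., op::mshadow_op::mul, LOP/ROP> call is a sum of an elementwise product over the axes that were broadcast in the forward pass. A small worked example in plain NumPy for the multiply backward, with hypothetical shapes (sketch only):

import numpy as np

ograd = np.ones((2, 3, 4), dtype=np.float32)                # gradient of the broadcast output
lhs = np.full((3, 1), 2.0, dtype=np.float32)
rhs = np.arange(4, dtype=np.float32).reshape(1, 4)

# lgrad = sum of ograd * rhs over the axes lhs was broadcast along (axes 0 and 2 here)
lgrad = (ograd * rhs).sum(axis=(0, 2)).reshape(lhs.shape)   # shape (3, 1)
# rgrad = sum of ograd * lhs over the axes rhs was broadcast along (axes 0 and 1 here)
rgrad = (ograd * lhs).sum(axis=(0, 1)).reshape(rhs.shape)   # shape (1, 4)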
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_unary_op.h
@@ -454,8 +454,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
req[0] != kWriteInplace) {
if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
req[0] != kWriteInplace) && outputs[0].Size() != 0) {
Assign(out, req[0], tcast<DstDType>(data));
}
});
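The added outputs[0].Size() != 0 guard matters for the zero-size shape pairs introduced in the test below ((3, 0), (0, 2), ...): CastCompute now skips the assignment entirely for empty tensors instead of launching a zero-element cast. A quick front-end way to exercise that path, assuming the mxnet.numpy interface (illustrative only):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

# Mixed-dtype multiply on zero-size operands: the cast into the temporary
# buffer is skipped and the result is an empty (3, 0) array.
out = np.multiply(np.ones((3, 0), dtype='float32'), np.ones((3, 0), dtype='int32'))
print(out.shape, out.size)   # (3, 0) 0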
20 changes: 12 additions & 8 deletions tests/python/unittest/test_numpy_op.py
@@ -1652,7 +1652,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
@with_seed()
@use_np
def test_np_mixed_precision_binary_funcs():
def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
itypes = [np.bool, np.int8, np.int32, np.int64]
ftypes = [np.float16, np.float32, np.float64]
def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
class TestMixedBinary(HybridBlock):
def __init__(self, func):
super(TestMixedBinary, self).__init__()
@@ -1686,13 +1688,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

funcs = {
'add': (-1.0, 1.0),
'subtract': (-1.0, 1.0),
'multiply': (-1.0, 1.0),
'add': (-1.0, 1.0, None, None),
'subtract': (-1.0, 1.0, None, None),
'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
}

shape_pairs = [((3, 2), (3, 2)),
((3, 2), (3, 1)),
((3, 0), (3, 0)),
((3, 1), (3, 0)),
((0, 2), (1, 2)),
((2, 3, 4), (3, 1)),
@@ -1702,16 +1706,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
itypes = [np.bool, np.int8, np.int32, np.int64]
ftypes = [np.float16, np.float32, np.float64]
for func, func_data in funcs.items():
low, high = func_data
low, high, lgrad, rgrad = func_data
for lshape, rshape in shape_pairs:
for type1, type2 in itertools.product(itypes, ftypes):
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)

for type1, type2 in itertools.product(ftypes, ftypes):
if type1 == type2:
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)


@with_seed()
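The new lgrad/rgrad entries for 'multiply' encode the analytic per-element gradients of y = x1 * x2, namely dy/dx1 = x2 and dy/dx2 = x1, each broadcast to y's shape. In plain numpy terms, with hypothetical operands (this mirrors the lambdas, not the test helper itself):

import numpy as _np

x1 = _np.array([[1.5], [2.5], [3.5]], dtype=_np.float32)   # shape (3, 1), float operand
x2 = _np.arange(6, dtype=_np.int32).reshape(3, 2)          # shape (3, 2), integer operand
y = x1 * x2                                                # broadcast product, shape (3, 2)

dy_dx1 = _np.broadcast_to(x2, y.shape)   # what the lgrad lambda returns
dy_dx2 = _np.broadcast_to(x1, y.shape)   # what the rgrad lambda returns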
