This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

workaround for windows
haojin2 committed Nov 2, 2019
1 parent: 727b84d · commit: f6aa9e9
Showing 7 changed files with 224 additions and 94 deletions.
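In brief: the mixed-precision kernels (the mixed_mul struct and the extra integral-operand Map overloads in true_divide/rtrue_divide) apparently do not build cleanly on Windows CI, so this commit fences them behind #ifndef _WIN32. Rather than disabling mixed-precision _npi_multiply on Windows entirely, as the previous workaround did, the Windows build now registers the operator with temp space requested and handles mixed precision by casting the non-float operand to the float operand's type, then reusing the ordinary same-type broadcast kernel (see np_elemwise_broadcast_op.h below).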
6 changes: 6 additions & 0 deletions src/operator/mshadow_op.h
@@ -133,6 +133,7 @@ struct true_divide : public mxnet_op::tunable {
return static_cast<float>(a) / static_cast<float>(b);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -150,6 +151,7 @@ struct true_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) / b;
}
#endif
};

struct rtrue_divide : public mxnet_op::tunable {
@@ -165,6 +167,7 @@ struct rtrue_divide : public mxnet_op::tunable {
return static_cast<float>(b) / static_cast<float>(a);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -182,6 +185,7 @@ struct rtrue_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static double Map(DType a, double b) {
return b / static_cast<double>(a);
}
#endif
};

MXNET_BINARY_MATH_OP_NC(left, a);
@@ -190,13 +194,15 @@ MXNET_BINARY_MATH_OP_NC(right, b);

MXNET_BINARY_MATH_OP_NC(mul, a * b);

#ifndef _WIN32
struct mixed_mul {
template<typename DType,
typename std::enable_if<!std::is_pointer<DType>::value, int>::type = 0>
MSHADOW_XINLINE static DType Map(bool a, DType b) {
return static_cast<DType>(a) * b;
}
};
#endif

MXNET_BINARY_MATH_OP_NC(div, a / b);

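For orientation, here is a minimal standalone sketch (not the MXNet code; the half_t overload and MSHADOW_XINLINE are omitted, and the struct name is invented) of the overload pattern that the #ifndef _WIN32 guards above exclude from Windows builds: std::enable_if-gated Map overloads that promote an integral operand to the floating-point operand's type.

#include <cstdio>
#include <type_traits>

// Sketch only: an integral lhs is promoted to the float rhs's type,
// as in the guarded true_divide overloads above.
struct true_divide_sketch {
  template <typename DType,
            typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static float Map(DType a, float b) {    // integral / float  -> float
    return static_cast<float>(a) / b;
  }

  template <typename DType,
            typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static double Map(DType a, double b) {  // integral / double -> double
    return static_cast<double>(a) / b;
  }
};

int main() {
  std::printf("%f\n", true_divide_sketch::Map(3, 2.0f));  // 1.500000
  std::printf("%f\n", true_divide_sketch::Map(1, 4.0));   // 0.250000
  return 0;
}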
36 changes: 29 additions & 7 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -54,7 +54,6 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")

#ifndef _WIN32
bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
@@ -71,6 +70,28 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
return true;
}

#ifdef _WIN32
#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#else
#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
@@ -97,12 +118,18 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

#ifndef _WIN32
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
#else
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
@@ -119,11 +146,6 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
})
.set_attr<FCompute>("FCompute<cpu>", MixedBinaryBackwardUseIn<cpu, mshadow_op::right,
mshadow_op::left>);
#else
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
#endif

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
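Note the asymmetry between the two macro variants above: the Windows version of MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION additionally registers FResourceRequest with ResourceRequest::kTempSpace. That scratch space is what the Windows fallback path in np_elemwise_broadcast_op.h (below) consumes via ctx.requested[0] to hold the casted copy of the non-float operand. _npi_multiply itself is now registered through the mixed-precision macro on all platforms; only the FCompute kernels differ (mixed_mul off Windows, plain mul on Windows), and the old Windows-only plain-broadcast registration at the bottom of the hunk is removed.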
9 changes: 6 additions & 3 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -41,13 +41,16 @@ NNVM_REGISTER_OP(_npi_multiply)
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_attr<FCompute>("FCompute<gpu>", MixedBinaryBackwardUseIn<gpu, mshadow_op::right,
mshadow_op::left>);
#else
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
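The GPU side mirrors the CPU change: on Windows, _npi_multiply's FCompute<gpu> uses the same-type mul kernel inside MixedBinaryBroadcastCompute instead of mixed_mul, and _backward_npi_broadcast_mul is now registered on all platforms rather than only off Windows.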
47 changes: 36 additions & 11 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -39,7 +39,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// TODO(haojin2): No mixed-precision multiply on windows temporarily due to CI issues.
#ifndef _WIN32
using namespace mshadow;
using namespace mxnet_op;
@@ -71,7 +70,7 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
});
});
#else
LOG(ERROR) << "mixed precision multiply is not supported on windows yet...";
LOG(ERROR) << "windows should not reach here...";
#endif
}

@@ -92,22 +91,18 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,

if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);


if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

// TODO(haojin2): No mixed-precision multiply on windows temporarily due to CI issues.
#ifndef _WIN32
CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
<< "now supports bool with another type only";


#ifndef _WIN32
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
@@ -130,7 +125,37 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
});
}
#else
LOG(ERROR) << "mixed precision multiply is not supported on windows yet...";
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
LOG(ERROR) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
// one is float, the other is bool
CHECK_EQ(out.type_flag_,
common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
<< "This case out type should be same as the float type";
if (common::is_float(lhs.type_flag_)) {
MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
LOG(ERROR) << "not implemented yet...";
}
#endif
}

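The core of the workaround is the new #else branch above: when exactly one operand is floating point, request temp space shaped like the other operand, CastCompute it up to the float type, and reuse the plain same-type BinaryBroadcastCompute (the remaining mixed cases still hit the LOG(ERROR) branches). A hedged, dependency-free sketch of that idea, with invented names, equal shapes, and only the float-lhs/bool-rhs case (not the MXNet implementation):

#include <cstdio>
#include <vector>

// Stand-in for BinaryBroadcastCompute with OP = mul and equal shapes.
static void mul_same_type(const std::vector<float>& a, const std::vector<float>& b,
                          std::vector<float>* out) {
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] * b[i];
}

// Mixed bool * float handled by casting first, as in the #else branch above.
static void mixed_mul_fallback(const std::vector<float>& lhs,
                               const std::vector<bool>& rhs,
                               std::vector<float>* out) {
  std::vector<float> temp(rhs.size());  // plays the role of temp_tblob
  for (std::size_t i = 0; i < rhs.size(); ++i)
    temp[i] = rhs[i] ? 1.0f : 0.0f;     // plays the role of CastCompute
  mul_same_type(lhs, temp, out);        // then the ordinary same-type kernel
}

int main() {
  std::vector<float> lhs{1.5f, 2.0f, 3.0f};
  std::vector<bool> rhs{true, false, true};
  std::vector<float> out(3);
  mixed_mul_fallback(lhs, rhs, &out);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 1.5 0 3
  return 0;
}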
(Diffs for the remaining 3 changed files not loaded.)
