
Commit a590c9e: "change space"
Yiyan66 committed Dec 16, 2019 (1 parent: ee4a08d)

Showing 2 changed files with 73 additions and 73 deletions.
84 changes: 42 additions & 42 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -846,48 +846,6 @@ void NumpyBroadcastToForward(const nnvm::NodeAttrs& attrs,
      req, outputs, expanded_ishape);
}

template<typename xpu>
void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  if (inputs[0].shape_.Size() == 0U) return;  // zero-size ograd
  TShape expanded_igrad_shape(inputs[0].shape_.ndim(), 1);
  const TShape& igrad_shape = outputs[0].shape_;
  CHECK_LE(igrad_shape.ndim(), expanded_igrad_shape.ndim())
      << "output ndim cannot be less than input ndim";
  const int ndim_delta = expanded_igrad_shape.ndim() - igrad_shape.ndim();
  for (int i = 0; i < igrad_shape.ndim(); ++i) {
    expanded_igrad_shape[i + ndim_delta] = igrad_shape[i];
  }
  if (NeedSafeAcc<true>(inputs[0].type_flag_, outputs[0].type_flag_)) {
    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
  } else {
    ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(
        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
  }
}

template<typename xpu, typename OP>
void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<TBlob>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
  TShape small;
  if (param.keepdims) {
    small = inputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
  }
  ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
}

struct NumpyMedianParam : public dmlc::Parameter<NumpyMedianParam> {
  dmlc::optional<mxnet::Tuple<int>> axis;
  bool keepdims;
@@ -1078,6 +1036,48 @@ inline bool NumpyMedianShape(const nnvm::NodeAttrs& attrs,
  return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
}

template<typename xpu>
void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  if (inputs[0].shape_.Size() == 0U) return;  // zero-size ograd
  TShape expanded_igrad_shape(inputs[0].shape_.ndim(), 1);
  const TShape& igrad_shape = outputs[0].shape_;
  CHECK_LE(igrad_shape.ndim(), expanded_igrad_shape.ndim())
      << "output ndim cannot be less than input ndim";
  const int ndim_delta = expanded_igrad_shape.ndim() - igrad_shape.ndim();
  for (int i = 0; i < igrad_shape.ndim(); ++i) {
    expanded_igrad_shape[i + ndim_delta] = igrad_shape[i];
  }
  if (NeedSafeAcc<true>(inputs[0].type_flag_, outputs[0].type_flag_)) {
    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
  } else {
    ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(
        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
  }
}

template<typename xpu, typename OP>
void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<TBlob>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
  TShape small;
  if (param.keepdims) {
    small = inputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
  }
  ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
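A note on the reduction at the heart of `NumpyBroadcastToBackward` above: the gradient of `broadcast_to` is obtained by right-aligning the input-gradient shape against the output-gradient shape (the `ndim_delta` loop pads leading 1s) and then sum-reducing the incoming gradient over every axis whose padded extent is 1. Below is a minimal self-contained C++ sketch of that shape logic; it is an illustration only, with `std::vector` standing in for MXNet's `TShape`/`TBlob` and hypothetical helpers (`ExpandIgradShape`, `SumToShape`) standing in for `ReduceAxesComputeImpl`.

```cpp
// Standalone illustration of the shape logic in NumpyBroadcastToBackward.
// Plain C++17; nothing here is MXNet API.
#include <cassert>
#include <cstdio>
#include <vector>

// Right-align igrad_shape against ograd_shape by padding leading 1s,
// mirroring the ndim_delta loop above.
std::vector<int> ExpandIgradShape(const std::vector<int>& ograd_shape,
                                  const std::vector<int>& igrad_shape) {
  assert(igrad_shape.size() <= ograd_shape.size());
  std::vector<int> expanded(ograd_shape.size(), 1);
  const size_t ndim_delta = ograd_shape.size() - igrad_shape.size();
  for (size_t i = 0; i < igrad_shape.size(); ++i)
    expanded[i + ndim_delta] = igrad_shape[i];
  return expanded;
}

// Naive sum-reduction of ograd onto the expanded igrad shape -- the job
// that ReduceAxesComputeImpl<xpu, mshadow_op::sum> performs efficiently.
std::vector<float> SumToShape(const std::vector<float>& ograd,
                              const std::vector<int>& ograd_shape,
                              const std::vector<int>& expanded) {
  const int ndim = static_cast<int>(ograd_shape.size());
  size_t out_size = 1;
  for (int d : expanded) out_size *= d;
  std::vector<float> igrad(out_size, 0.0f);
  std::vector<int> coord(ndim, 0);  // multi-index into ograd
  for (size_t flat = 0; flat < ograd.size(); ++flat) {
    // Collapse coordinates on broadcast (size-1) axes to 0, then accumulate.
    size_t out_idx = 0;
    for (int d = 0; d < ndim; ++d)
      out_idx = out_idx * expanded[d] + (expanded[d] == 1 ? 0 : coord[d]);
    igrad[out_idx] += ograd[flat];
    for (int d = ndim - 1; d >= 0; --d) {  // odometer-style increment
      if (++coord[d] < ograd_shape[d]) break;
      coord[d] = 0;
    }
  }
  return igrad;
}

int main() {
  // ograd of shape (2, 3) flowing back to an input of shape (3,):
  // axis 0 was created by broadcasting, so it is summed away.
  const std::vector<float> ograd = {1, 2, 3, 4, 5, 6};
  const auto expanded = ExpandIgradShape({2, 3}, {3});  // -> (1, 3)
  for (float g : SumToShape(ograd, {2, 3}, expanded))
    std::printf("%g ", g);  // prints: 5 7 9
  std::printf("\n");
  return 0;
}
```

`NumpyReduceAxesNoDTypeBackward` does the mirror-image bookkeeping: with `keepdims` the incoming gradient already carries the reduced axes as size-1 dimensions, otherwise `NumpyReduceAxesShapeImpl` recomputes that `small` shape before the gradient is broadcast back out to the input shape.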
62 changes: 31 additions & 31 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -321,6 +321,37 @@ NNVM_REGISTER_OP(_backward_np_average)
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  });

inline bool NumpyMedianType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_npi_median)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyMedianParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMedianShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMedianType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"a"};
  })
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyMedianParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyMedianForward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace, ResourceRequest::kTempSpace};
  })
// .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
@@ -509,36 +540,5 @@ NNVM_REGISTER_OP(_backward_np_broadcast_to)
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  });

inline bool NumpyMedianType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_npi_median)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyMedianParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMedianShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMedianType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"a"};
  })
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyMedianParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyMedianForward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace, ResourceRequest::kTempSpace};
  })
// .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
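The `TYPE_ASSIGN_CHECK` pair in `NumpyMedianType` propagates the dtype in both directions: a dtype known on either the input or the output side is copied to the other, and inference succeeds once neither side is still the sentinel `-1`. Below is a minimal standalone sketch of that pattern; `AssignOrCheck` and `MedianTypeLike` are hypothetical names, and MXNet's actual macro hard-fails on a conflict rather than returning false.

```cpp
// Standalone sketch of the two-way dtype propagation in NumpyMedianType.
// Plain C++; nothing here is MXNet API.
#include <cstdio>
#include <vector>

constexpr int kUnknown = -1;  // MXNet marks an uninferred dtype with -1

// Copy src into *dst if *dst is unknown; otherwise require agreement.
bool AssignOrCheck(int* dst, int src) {
  if (src == kUnknown) return true;                   // nothing to propagate
  if (*dst == kUnknown) { *dst = src; return true; }  // fill in the blank
  return *dst == src;                                 // both known: must match
}

// Mirrors NumpyMedianType: median's output dtype equals its input dtype,
// so a dtype known on either side pins down the other.
bool MedianTypeLike(std::vector<int>* in_attrs, std::vector<int>* out_attrs) {
  if (!AssignOrCheck(&(*in_attrs)[0], (*out_attrs)[0])) return false;
  if (!AssignOrCheck(&(*out_attrs)[0], (*in_attrs)[0])) return false;
  // Inference succeeds only once both sides are known.
  return (*in_attrs)[0] != kUnknown && (*out_attrs)[0] != kUnknown;
}

int main() {
  std::vector<int> in = {kUnknown}, out = {0};  // 0 standing for float32
  const bool ok = MedianTypeLike(&in, &out);
  std::printf("inferred=%d in=%d out=%d\n", ok, in[0], out[0]);
  // prints: inferred=1 in=0 out=0
  return 0;
}
```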
