fix np.argmax/argmin output data type
haojin2 committed Jan 30, 2020
1 parent 821b6e5 commit b5a974a
Showing 4 changed files with 66 additions and 6 deletions.
48 changes: 48 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -357,6 +357,54 @@ void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,
  ReduceAxesComputeBoolImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}

template<typename xpu, typename reducer>
void NumpySearchAxisCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  int axis = inputs[0].ndim();
  TBlob input = inputs[0];
  if (param.axis.has_value()) {
    axis = param.axis.value();
  } else {
    // If global reduction, reshape the input tensor into 2D shape (1, inputs[0].shape_.Size())
    // and search on axis = 1.
    mxnet::TShape shape_2d(2, 1);
    shape_2d[1] = input.shape_.Size();
    input = TBlob(input.dptr_, shape_2d, input.dev_mask(), input.type_flag_, input.dev_id());
    axis = 1;
  }

  axis = CheckAxis(axis, input.shape_.ndim());
  if (inputs[0].shape_.ndim() != 0) {
    if (param.axis.has_value()) {
      // cannot do argmax in an empty dimension
      CHECK_NE(inputs[0].shape_[axis], 0)
          << "searching input tensor of shape " << inputs[0].shape_
          << " along axis = " << axis << " of zero dim-size is not allowed";
    } else {
      // cannot do argmax on an empty array
      CHECK_NE(inputs[0].shape_.Size(), 0U) << "attempt to search an empty sequence";
    }
  }

  if (input.shape_.Size() == 0U) return;  // zero-size tensor
  mxnet::TShape shape = AxisShapeCompact(input.shape_, &axis, false);
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, int64_t> out =
        outputs[0].get_with_shape<xpu, 2, int64_t>(Shape2(shape[0], shape[2]), s);
    Tensor<xpu, 3, DType> in =
        input.get_with_shape<xpu, 3, DType>(shape.get<3>(), s);
    CHECK(req[0] != kAddTo) << "AddTo is not supported";
    ASSIGN_DISPATCH(out, req[0], tcast<int64_t>(reduce_with_axis<reducer, true>(in, 1)));
  });
}

template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
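The core of the new kernel is the axis handling: a global argmax/argmin is rewritten as a search along axis 1 of a (1, size) view of the input, and the reduced indices are cast to int64 via tcast<int64_t>. The snippet below is a rough NumPy sketch of that strategy, not the actual mshadow kernel; the helper name search_axis is illustrative only.

import numpy as np

def search_axis(a, axis=None, op=np.argmax):
    # Hypothetical helper, for illustration only: mimic the axis handling above.
    if axis is None:
        # Global reduction: view the data as shape (1, a.size) and search along
        # axis 1, exactly the 2D reshape the kernel performs when no axis is given.
        if a.size == 0:
            raise ValueError("attempt to search an empty sequence")
        return np.int64(op(a.reshape(1, a.size), axis=1)[0])
    if a.ndim != 0 and a.shape[axis] == 0:
        raise ValueError("searching along an axis of zero dim-size is not allowed")
    # The kernel casts the result to int64 (tcast<int64_t>); NumPy's argmax/argmin
    # already return a 64-bit integer index type on 64-bit platforms.
    return op(a, axis=axis).astype(np.int64)

x = np.random.uniform(size=(3, 4)).astype(np.float32)
print(search_axis(x), search_axis(x, axis=1))              # flat index, per-row indices
print(search_axis(x).dtype, search_axis(x, axis=1).dtype)  # int64 int64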
18 changes: 14 additions & 4 deletions src/operator/numpy/np_broadcast_reduce_op_index.cc
@@ -46,14 +46,24 @@ bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
  return shape_is_known(out_attrs->at(0));
}

bool ArgMinMaxType(const nnvm::NodeAttrs& attrs,
                   std::vector<int> *in_attrs,
                   std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE(in_attrs->at(0), -1);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
  return out_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_npi_argmax)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
.add_argument("data", "NDArray-or-Symbol", "The input")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::maximum>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());

@@ -62,9 +72,9 @@ NNVM_REGISTER_OP(_npi_argmin)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
.add_argument("data", "NDArray-or-Symbol", "The input")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::minimum>)
.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::minimum>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());

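ArgMinMaxType replaces ElemwiseType<1, 1> so that the index output is always inferred as int64 instead of mirroring the input dtype. Below is a quick front-end check of the resulting behavior, assuming an MXNet build that contains this commit and exposes the mxnet.numpy interface (1.6 or later).

import numpy as onp
from mxnet import np, npx
npx.set_np()

for dtype in ['float32', 'float64', 'int32', 'int64']:
    a = np.array([[1, 3, 2], [4, 0, 5]], dtype=dtype)
    for op in (np.argmax, np.argmin):
        # The index dtype is now int64 for every input dtype; with the old
        # ElemwiseType<1, 1> inference it mirrored the input dtype instead.
        assert op(a, axis=1).dtype == onp.int64
        assert op(a).dtype == onp.int64  # global reduction as well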
4 changes: 2 additions & 2 deletions src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -28,10 +28,10 @@ namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_argmax)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);
.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(_npi_argmin)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::minimum>);
.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::minimum>);

} // namespace op
} // namespace mxnet
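The GPU registrations route through the same NumpySearchAxisCompute, so results on a GPU context should come back as int64 as well. A minimal sanity check, assuming a CUDA-enabled build with at least one visible GPU:

import mxnet as mx
import numpy as onp
from mxnet import np, npx
npx.set_np()

# Hypothetical single-GPU check; adjust the device index to your setup.
a = np.array([[1.0, 5.0, 2.0], [7.0, 0.0, 3.0]], ctx=mx.gpu(0))
assert np.argmax(a, axis=1).dtype == onp.int64
assert np.argmin(a).dtype == onp.int64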
2 changes: 2 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -3127,6 +3127,7 @@ def hybrid_forward(self, F, x):
            else:
                mx_ret = getattr(np, op_name)(a, axis=axis)
                np_ret = getattr(_np, op_name)(a.asnumpy(), axis=axis)
                assert mx_ret.dtype == np_ret.dtype
                assert same(mx_ret.asnumpy(), np_ret)

            for hybridize in [False, True]:
@@ -3142,6 +3143,7 @@ def hybrid_forward(self, F, x):
                    pass
                else:
                    mx_ret = net(a)
                    assert mx_ret.dtype == np_ret.dtype
                    assert same(mx_ret.asnumpy(), np_ret)


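The two added assertions pin the dtype of the MXNet result to whatever NumPy returns (int64). The script below is a standalone distillation of that check; the helper name check_arg_extreme is illustrative rather than taken from the test file, and numpy.array_equal stands in for the test suite's same helper.

import numpy as _np
from mxnet import np, npx
npx.set_np()

def check_arg_extreme(op_name, shape, axis):
    # Compare MXNet against NumPy in both values and dtype, which is exactly
    # what the new assertions add to the existing value comparison.
    a = np.random.uniform(low=0, high=100, size=shape)
    mx_ret = getattr(np, op_name)(a, axis=axis)
    np_ret = getattr(_np, op_name)(a.asnumpy(), axis=axis)
    assert mx_ret.dtype == np_ret.dtype
    assert _np.array_equal(mx_ret.asnumpy(), np_ret)

for op_name in ('argmax', 'argmin'):
    for shape, axis in [((5, 3), None), ((5, 3), 0), ((5, 3), 1), ((4,), None)]:
        check_arg_extreme(op_name, shape, axis)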
