This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix np.argmax/argmin output data type #17476

Merged
1 commit merged on Jan 30, 2020
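
In short, the change makes mx.np.argmax and mx.np.argmin return int64 indices, matching NumPy, instead of inheriting the input's dtype. A minimal usage sketch of the intended behavior (assuming a build that includes this fix):

import numpy as _np
from mxnet import np, npx
npx.set_np()

a = np.array([[1., 3., 2.],
              [4., 0., 5.]], dtype='float32')

# With this fix the indices come back as int64 (NumPy-compatible) rather than float32.
idx = np.argmax(a, axis=1)
assert idx.dtype == _np.int64
assert idx.dtype == _np.argmax(a.asnumpy(), axis=1).dtype  # int64 on a 64-bit build of NumPy
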
48 changes: 48 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -357,6 +357,54 @@ void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,
  ReduceAxesComputeBoolImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}

template<typename xpu, typename reducer>
void NumpySearchAxisCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  int axis = inputs[0].ndim();
  TBlob input = inputs[0];
  if (param.axis.has_value()) {
    axis = param.axis.value();
  } else {
    // If global reduction, reshape the input tensor into 2D shape (1, inputs[0].shape_.Size())
    // and search on axis = 1.
    mxnet::TShape shape_2d(2, 1);
    shape_2d[1] = input.shape_.Size();
    input = TBlob(input.dptr_, shape_2d, input.dev_mask(), input.type_flag_, input.dev_id());
    axis = 1;
  }

  axis = CheckAxis(axis, input.shape_.ndim());
  if (inputs[0].shape_.ndim() != 0) {
    if (param.axis.has_value()) {
      // cannot do argmax in an empty dimension
      CHECK_NE(inputs[0].shape_[axis], 0)
        << "searching input tensor of shape " << inputs[0].shape_
        << " along axis = " << axis << " of zero dim-size is not allowed";
    } else {
      // cannot do argmax on an empty array
      CHECK_NE(inputs[0].shape_.Size(), 0U) << "attempt to search an empty sequence";
    }
  }

  if (input.shape_.Size() == 0U) return;  // zero-size tensor
  mxnet::TShape shape = AxisShapeCompact(input.shape_, &axis, false);
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, int64_t> out =
      outputs[0].get_with_shape<xpu, 2, int64_t>(Shape2(shape[0], shape[2]), s);
    Tensor<xpu, 3, DType> in =
      input.get_with_shape<xpu, 3, DType>(shape.get<3>(), s);
    CHECK(req[0] != kAddTo) << "AddTo is not supported";
    ASSIGN_DISPATCH(out, req[0], tcast<int64_t>(reduce_with_axis<reducer, true>(in, 1)));
  });
}

template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
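
For readers who do not want to trace the mshadow expressions, the kernel above boils down to the following NumPy sketch (my own illustration, not code from the PR): the input is viewed as a (before, axis, after) 3-D tensor via AxisShapeCompact, the search runs over the middle axis, and the indices are cast to int64; when no axis is given, the array is first viewed as (1, size).

import numpy as _np

def search_axis_sketch(x, axis=None, reducer=_np.argmax):
    """Rough NumPy analogue of NumpySearchAxisCompute (illustration only)."""
    if axis is None:
        # Global reduction: mirror the shape_2d reshape to (1, x.size) and search axis 1.
        x = x.reshape(1, x.size)
        axis = 1
    axis = axis % x.ndim  # CheckAxis analogue: normalize a negative axis
    if x.shape[axis] == 0:
        raise ValueError("cannot search along an axis of zero dim-size")
    # AxisShapeCompact analogue: collapse everything to (before, axis, after).
    before = int(_np.prod(x.shape[:axis], dtype=_np.int64))
    after = int(_np.prod(x.shape[axis + 1:], dtype=_np.int64))
    x3 = x.reshape(before, x.shape[axis], after)
    # Search the middle axis and store the indices as int64,
    # the same cast tcast<int64_t> performs in the kernel.
    return reducer(x3, axis=1).astype(_np.int64)

# search_axis_sketch(_np.arange(12.).reshape(3, 4), axis=1)
# -> array([[3], [3], [3]]) of dtype int64, i.e. the compact (before, after) view;
#    the real operator then reshapes this to the declared output shape.
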
18 changes: 14 additions & 4 deletions src/operator/numpy/np_broadcast_reduce_op_index.cc
@@ -46,14 +46,24 @@ bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
  return shape_is_known(out_attrs->at(0));
}

bool ArgMinMaxType(const nnvm::NodeAttrs& attrs,
                   std::vector<int> *in_attrs,
                   std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE(in_attrs->at(0), -1);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
  return out_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_npi_argmax)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
 .add_argument("data", "NDArray-or-Symbol", "The input")
-.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
+.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::maximum>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());

@@ -62,9 +72,9 @@ NNVM_REGISTER_OP(_npi_argmin)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
 .add_argument("data", "NDArray-or-Symbol", "The input")
-.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::minimum>)
+.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::minimum>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());

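
The registration swap above is the heart of the fix: ElemwiseType<1, 1> copies the input dtype to the output, while ArgMinMaxType pins the output type to mshadow::kInt64 regardless of the input. A short check of that behavior from Python (a sketch assuming a build with this patch; the dtype list is illustrative):

import numpy as _np
from mxnet import np, npx
npx.set_np()

for dt in ['float16', 'float32', 'float64']:
    a = np.arange(6, dtype=dt).reshape(2, 3)
    for op in (np.argmax, np.argmin):
        # The returned indices are int64 no matter which floating dtype goes in.
        assert op(a, axis=1).dtype == _np.int64
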
4 changes: 2 additions & 2 deletions src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -28,10 +28,10 @@ namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(_npi_argmin)
-.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::minimum>);

} // namespace op
} // namespace mxnet
2 changes: 2 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -3127,6 +3127,7 @@ def hybrid_forward(self, F, x):
            else:
                mx_ret = getattr(np, op_name)(a, axis=axis)
                np_ret = getattr(_np, op_name)(a.asnumpy(), axis=axis)
                assert mx_ret.dtype == np_ret.dtype
                assert same(mx_ret.asnumpy(), np_ret)

        for hybridize in [False, True]:
@@ -3142,6 +3143,7 @@ def hybrid_forward(self, F, x):
                        pass
                else:
                    mx_ret = net(a)
                    assert mx_ret.dtype == np_ret.dtype
                    assert same(mx_ret.asnumpy(), np_ret)

