Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy compatible max min
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Aug 30, 2019
1 parent 61f3dbc commit 6a4f7b1
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
}
};

struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
bool keepdims;
dmlc::optional<double> initial;
DMLC_DECLARE_PARAMETER(NumpyReduceAxesNoDTypeParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
"all of the elements of the input array. If axis is negative it counts from the "
"last to the first axis.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
.describe("Starting value for the sum.");
}
};

inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
const dmlc::optional<mxnet::Tuple<int>>& axis,
bool keepdims) {
Expand Down Expand Up @@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(out_attrs->at(0));
}

inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
// check the case where the reduction axis should not be zero
bool is_all_reducded_axes_not_zero = true;
const TShape& ishape = (*in_attrs)[0];
if (param.axis.has_value()) {
const mxnet::Tuple<int>& axes = param.axis.value();
for (int i = 0; i < axes.ndim(); ++i) {
if (ishape[axes[i]] == 0) {
is_all_reducded_axes_not_zero = false;
break;
}
}
} else {
if (ishape.Size() == 0) {
// global reduction should excuted only when input have size more than 0
is_all_reducded_axes_not_zero = false;
}
}
CHECK(is_all_reducded_axes_not_zero)
<< "zero-size array to reduction operation maximum which has no identity";
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
return shape_is_known(out_attrs->at(0));
}

template<bool safe_acc_hint = false>
inline bool NeedSafeAcc(int itype, int otype) {
bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
Expand Down Expand Up @@ -186,6 +237,30 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename reducer, typename OP = op::mshadow_op::identity>
void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return; // zero-size tensor
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
}
ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}


template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -273,6 +348,24 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename OP>
void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
}
ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
66 changes: 66 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
DMLC_REGISTER_PARAMETER(NumpyReduceAxesNoDTypeParam);

inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
Expand Down Expand Up @@ -74,6 +75,71 @@ NNVM_REGISTER_OP(_backward_np_sum)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu>);

inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_np_max)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::maximum>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});

NNVM_REGISTER_OP(_backward_np_max)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_min)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::minimum>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_min"});

NNVM_REGISTER_OP(_backward_np_min)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_prod)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
12 changes: 12 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ NNVM_REGISTER_OP(_np_sum)
NNVM_REGISTER_OP(_backward_np_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);

NNVM_REGISTER_OP(_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(_backward_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_min)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);

NNVM_REGISTER_OP(_backward_np_min)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_prod)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);

Expand Down
113 changes: 113 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,119 @@ def is_int(dtype):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_max_min():
class TestMax(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestMax, self).__init__()
self._axis = axis
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.max(a, axis=self._axis, keepdims=self._keepdims)

class TestMin(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestMin, self).__init__()
self._axis = axis
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.min(a, axis=self._axis, keepdims=self._keepdims)

def is_int(dtype):
return 'int' == dtype

def get_grad(axis, func_name):
index = -1 if func_name == 'max' else 0
if axis == ():
return _np.ones((2,3,4,5))
else:
temp = _np.zeros((2,3,4,5))
if axis == 0:
temp[index,:,:,:] = 1
return temp
elif axis == 1:
temp[:,index,:,:] = 1
return temp
elif axis == 2:
temp[:,:,index,:] = 1
return temp
elif axis == 3:
temp[:,:,:,index] = 1
return temp
elif not axis:
temp[index,index,index,index] = 1
return temp
raise ValueError('axis should be int or None or ()')

def _test_np_exception(func, shape, dim):
x = _np.random.uniform(-1.0, 1.0, shape)
x = mx.nd.array(x).as_np_ndarray()
if func == 'max':
out = mx.np.max(x)
else:
out = mx.np.min(x)
assert out.ndim == dim, 'dimension mismatch, output.ndim={}, dim={}'.format(output.ndim, dim)

in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
for func in ['max', 'min']:
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64', 'int']:
# test gluon
if func == 'max':
test_gluon = TestMax(axis=axis, keepdims=keepdims)
else:
test_gluon = TestMin(axis=axis, keepdims=keepdims)
if hybridize:
test_gluon.hybridize()
if is_int(itype):
x = mx.nd.arange(120).reshape((2, 3, 4, 5))
x = mx.nd.array(x)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x.attach_grad()
if func == 'max':
expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
else:
expected_ret = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
with mx.autograd.record():
y = test_gluon(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if itype == 'float16' else 1e-3,
atol=1e-5 if itype == 'float16' else 1e-5)
y.backward()
# only check the gradient with hardcoded input
if is_int(itype):
assert same(x.grad.asnumpy(), get_grad(axis, func)), \
'x={}\ny={}\nx.grad={}\nnumpy={}'.format(x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis))

# test imperative
if func == 'max':
mx_out = np.max(x, axis=axis, keepdims=keepdims)
np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
else:
mx_out = np.min(x, axis=axis, keepdims=keepdims)
np_out = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

# test zero and zero dim
shapes = [(), (0), (2, 0), (0, 2, 1)]
exceptions = [False, True, True, True]
dims = [0] * len(shapes)
for func in ['max', 'min']:
for shape, exception, dim in zip(shapes, exceptions, dims):
if exception:
assertRaises(MXNetError, _test_np_exception, func, shape, dim)
else:
_test_np_exception(func, shape, dim)


@with_seed()
@use_np
def test_np_linspace():
Expand Down

0 comments on commit 6a4f7b1

Please sign in to comment.