
Commit

any/all
Tommliu committed Dec 17, 2019
1 parent 52c9a45 commit d95f6bb
Showing 14 changed files with 476 additions and 2 deletions.
101 changes: 101 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -20,6 +20,107 @@
"""Doc placeholder for numpy ops with prefix _np."""


def _np_all(a, axis=None, keepdims=False, out=None):
"""
Test whether all array elements along a given axis evaluate to True.
Parameters
----------
a : array_like
Input array or object that can be converted to an array.
axis : None or int or tuple of ints, optional
Axis or axes along which a logical AND reduction is performed.
The default (axis = None) is to perform a logical AND over
all the dimensions of the input array.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option,
the result will broadcast correctly against the input array.
out : ndarray, optional
Alternate output array in which to place the result. It must have
the same shape as the expected output and its type is preserved.
Returns
-------
all : ndarray, bool
A new boolean or array is returned unless out is specified,
in which case a reference to out is returned.
Examples
--------
>>> np.all([[True,False],[True,True]])
False
>>> np.all([[True,False],[True,True]], axis=0)
array([ True, False])
>>> np.all([-1, 4, 5])
True
>>> np.all([1.0, np.nan])
True
>>> o=np.array(False)
>>> z=np.all([-1, 4, 5], out=o)
>>> id(z), id(o), z
(28293632, 28293632, array(True)) # may vary
"""
pass

def _np_any(a, axis=None, keepdims=False, out=None):
"""
Test whether any array element along a given axis evaluates to True.
Returns a single boolean value unless axis is not None.
Parameters
----------
a : array_like
Input array or object that can be converted to an array.
axis : None or int or tuple of ints, optional
Axis or axes along which a logical OR reduction is performed.
The default (axis = None) is to perform a logical OR over
all the dimensions of the input array.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option,
the result will broadcast correctly against the input array.
out : ndarray, optional
Alternate output array in which to place the result. It must have
the same shape as the expected output and its type is preserved.
Returns
-------
any : bool or ndarray
A new boolean or ndarray is returned unless out is specified,
in which case a reference to out is returned.
Examples
--------
>>> np.any([[True, False], [True, True]])
True
>>> np.any([[True, False], [False, False]], axis=0)
array([ True, False])
>>> np.any([-1, 0, 5])
True
>>> np.any(np.nan)
True
>>> o=np.array(False)
>>> z=np.any([-1, 4, 5], out=o)
>>> z, o
(array(True), array(True))
>>> # Check now that z is a reference to o
>>> z is o
True
>>> id(z), id(o) # identity of z and o # doctest: +SKIP
(191614240, 191614240)
"""
pass


def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
Return the cumulative sum of the elements along a given axis.
Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -83,6 +83,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):


_NUMPY_ARRAY_FUNCTION_LIST = [
'all',
'any',
'argmin',
'argmax',
'around',
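Registering 'all' and 'any' in _NUMPY_ARRAY_FUNCTION_LIST opts them into NumPy's __array_function__ dispatch protocol (NEP 18), so the official NumPy functions forward to MXNet's implementations when called on MXNet ndarrays. A minimal usage sketch, assuming an MXNet build that includes this commit and NumPy >= 1.17:

import numpy as onp           # official NumPy
from mxnet import np, npx     # MXNet's NumPy-compatible front end

npx.set_np()                  # enable NumPy semantics in MXNet

a = np.array([[1, 0], [1, 1]])
print(onp.all(a, axis=0))     # dispatches to MXNet's all via __array_function__
print(onp.any(a))             # likewise for any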
8 changes: 8 additions & 0 deletions src/operator/mshadow_op.h
@@ -1089,6 +1089,14 @@ struct minimum : public mxnet_op::tunable {
}
};

/*! \brief boolean any/all kernel that determines whether an element is non-zero */
struct NonZero {
template<typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return (a != DType(0));
}
};
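As the registrations later in this commit show, the reducer pairing is: any = sum reducer over the NonZero map, all = product reducer over it. A hedged NumPy reference of that semantics:

import numpy as np

def ref_any(a):
    # sum over NonZero(a): nonzero as soon as any element is nonzero
    return bool(np.sum(np.asarray(a) != 0))

def ref_all(a):
    # product over NonZero(a): stays 1 only if every element is nonzero
    return bool(np.prod(np.asarray(a) != 0))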

/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
23 changes: 23 additions & 0 deletions src/operator/mxnet_op.h
@@ -1148,6 +1148,29 @@ struct set_to_int : public tunable {
*/
using set_zero = set_to_int<0>;
using set_one = set_to_int<1>;

/*!
* \brief Set to immediate boolean value kernel
* \tparam val Boolean immediate
*/
template<bool val>
struct set_to_bool : public tunable {
// mxnet_op version (when used directly with Kernel<>::Launch())
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType *out) {
out[i] = DType(val);
}
// mshadow_op version (when used with op_with_req<>)
MSHADOW_XINLINE static int Map() {
return val;
}
};

/*!
* \brief Special-case kernel shortcut for setting to true and false
*/
using set_true = set_to_bool<true>;
using set_false = set_to_bool<false>;
} // namespace mxnet_op

} // namespace op
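For intuition, a hedged Python model of the fill these kernels perform when launched over an n-element output buffer; Kernel<set_false, xpu>::Launch(s, n, out) is the parallel analogue, and launch_fill is a hypothetical name:

def launch_fill(out, val=False):
    # sequential model of the parallel Map(i, out) calls
    for i in range(len(out)):
        out[i] = val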
55 changes: 54 additions & 1 deletion src/operator/numpy/np_broadcast_reduce_op.h
@@ -86,11 +86,26 @@ struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTy
}
};

struct NumpyReduceAxesBoolParam : public dmlc::Parameter<NumpyReduceAxesBoolParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
bool keepdims;
DMLC_DECLARE_PARAMETER(NumpyReduceAxesBoolParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
"all of the elements of the input array. If axis is negative it counts from the "
"last to the first axis.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
}
};

inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
const dmlc::optional<mxnet::Tuple<int>>& axis,
bool keepdims) {
// If input is a scalar, output should be a scalar too
if (ishape.ndim() == 0) {
if (axis.has_value()) {
const mxnet::Tuple<int>& axes = axis.value();
if (axes.ndim() > 0) {
@@ -173,6 +188,20 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(out_attrs->at(0));
}

inline bool NumpyReduceAxesBoolShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyReduceAxesBoolParam& param = nnvm::get<NumpyReduceAxesBoolParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
return shape_is_known(out_attrs->at(0));
}
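A hedged Python sketch of the output-shape rule that NumpyReduceAxesShapeImpl applies here (reduce_axes_shape is a hypothetical name; it assumes the usual NumPy axis conventions):

def reduce_axes_shape(ishape, axis=None, keepdims=False):
    ndim = len(ishape)
    if axis is None:
        axes = set(range(ndim))                        # reduce everything
    else:
        axis = axis if isinstance(axis, tuple) else (axis,)
        axes = {a + ndim if a < 0 else a for a in axis}
    if keepdims:
        return tuple(1 if i in axes else d for i, d in enumerate(ishape))
    return tuple(d for i, d in enumerate(ishape) if i not in axes)

assert reduce_axes_shape((2, 3), axis=0) == (3,)
assert reduce_axes_shape((2, 3), keepdims=True) == (1, 1)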

inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
@@ -298,6 +327,30 @@ void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}

template<typename xpu, typename reducer, typename OP = op::mshadow_op::NonZero>
void NumpyReduceAxesBoolCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyReduceAxesBoolParam& param = nnvm::get<NumpyReduceAxesBoolParam>(attrs.parsed);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
using namespace mxnet_op;
Kernel<set_false, xpu>::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr<bool>());
return;
}
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
}
ReduceAxesComputeBoolImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}
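In plain terms, the compute path above is modeled by this hedged NumPy sketch (ref_bool_reduce is a hypothetical name; it reuses the reduce_axes_shape sketch shown earlier):

import numpy as np

def ref_bool_reduce(a, axis=None, keepdims=False, reducer=np.sum):
    # reducer=np.sum models _np_any; reducer=np.prod models _np_all
    a = np.asarray(a)
    out_shape = reduce_axes_shape(a.shape, axis, keepdims)
    if a.size == 0 and int(np.prod(out_shape, dtype=np.int64)) != 0:
        # mirrors the set_false fill for empty inputs (official NumPy would
        # instead return True from all() over an empty axis)
        return np.zeros(out_shape, dtype=bool)
    if axis == ():
        # mirrors the IdentityCompute branch for an empty axis tuple
        return a != 0
    res = reducer(a != 0, axis=axis, keepdims=keepdims)
    return np.asarray(res, dtype=bool)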

template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
85 changes: 85 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_boolean.cc
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_broadcast_reduce_op_boolean.cc
* \brief CPU Implementation of broadcast and reduce functions based on boolean.
*/

#include "./np_broadcast_reduce_op.h"

namespace mxnet {
namespace op {

inline bool NumpyReduceAxesBoolType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

DMLC_REGISTER_PARAMETER(NumpyReduceAxesBoolParam);

NNVM_REGISTER_OP(_np_any)
.set_attr_parser(ParamParser<NumpyReduceAxesBoolParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesBoolShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesBoolType)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBoolCompute<cpu,
mshadow_op::sum, mshadow_op::NonZero>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());

NNVM_REGISTER_OP(_np_all)
.set_attr_parser(ParamParser<NumpyReduceAxesBoolParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesBoolShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesBoolType)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBoolCompute<cpu,
mshadow_op::product, mshadow_op::NonZero>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyReduceAxesBoolParam::__FIELDS__());

} // namespace op
} // namespace mxnet
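End to end, these registrations back the Python-level calls. A hedged usage sketch (assumes the mx.np front-end functions wired to _np_all/_np_any, as exercised by the dispatch list above):

from mxnet import np, npx
npx.set_np()

x = np.array([[1.0, 0.0], [3.0, 4.0]])
print(np.all(x))           # array(False): x contains a zero
print(np.any(x, axis=1))   # array([ True,  True])
print(np.any(x).dtype)     # bool: NumpyReduceAxesBoolType pins the output to kBool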
40 changes: 40 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_boolean.cu
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_broadcast_reduce_op_boolean.cu
* \brief GPU Implementation of broadcast and reduce functions based on boolean.
*/

#include "./np_broadcast_reduce_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_np_any)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
mshadow_op::sum, mshadow_op::NonZero>);

NNVM_REGISTER_OP(_np_all)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
mshadow_op::product, mshadow_op::NonZero>);

} // namespace op
} // namespace mxnet
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -403,6 +403,8 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NO
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_bool<false>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_bool<true>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT()
10 changes: 10 additions & 0 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -634,6 +634,16 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
}
}

template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
void ReduceBool(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, bool, DType, bool, OP>(stream, small, req, big, workspace, config);
}
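A hedged NumPy sketch of the shape contract ReduceBool implements: big is reduced into the keepdims-style shape of small, with a bool-typed output (reduce_bool is a hypothetical name):

import numpy as np

def reduce_bool(big, small_shape, reducer=np.sum):
    mapped = big != 0                                  # the NonZero map
    axes = tuple(i for i, (b, s) in enumerate(zip(big.shape, small_shape))
                 if s == 1 and b != 1)                 # dims collapsed in small
    return reducer(mapped, axis=axes, keepdims=True).astype(bool)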

template <typename Reducer, int ndim, typename DType, typename OP>
void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {};
