diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 190150eb5a0a..65f1f2f9407c 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1132,6 +1132,45 @@ def _np__random_shuffle(x): pass +def _npx_constraint_check(x, msg): + """ + This operator will check if all the elements in a boolean tensor is true. + If not, ValueError exception will be raised in the backend with given error message. + In order to evaluate this operator, one should multiply the origin tensor by the return value + of this operator to force this operator become part of the computation graph, + otherwise the check would not be working under symoblic mode. + + Parameters + ---------- + x : ndarray + A boolean tensor. + msg : string + The error message in the exception. + + Returns + ------- + out : ndarray + If all the elements in the input tensor are true, + array(True) will be returned, otherwise ValueError exception would + be raised before anything got returned. + + Examples + -------- + >>> loc = np.zeros((2,2)) + >>> scale = np.array(#some_value) + >>> constraint = (scale > 0) + >>> np.random.normal(loc, + scale * npx.constraint_check(constraint, 'Scale should be larger than zero')) + + If elements in the scale tensor are all bigger than zero, npx.constraint_check would return + `np.array(True)`, which will not change the value of `scale` when multiplied by. + If some of the elements in the scale tensor violate the constraint, + i.e. there exists `False` in the boolean tensor `constraint`, + a `ValueError` exception with given message 'Scale should be larger than zero' would be raised. + """ + pass + + def _npx_reshape(a, newshape, reverse=False, order='C'): """ Gives a new shape to an array without changing its data. 
diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc new file mode 100644 index 000000000000..450f4ddb392b --- /dev/null +++ b/src/operator/numpy/np_constraint_check.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_constraint_check.cc + * \brief helper function for constraint check + */ + +#include "./np_constraint_check.h" + +namespace mxnet { +namespace op { + +template<> +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output) { + *red_output = static_cast(*output_blob.dptr()); +} + +inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + return true; +} + +inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK(in_attrs->at(0) == mshadow::kBool); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(ConstraintCheckParam); + +NNVM_REGISTER_OP(_npx_constraint_check) +.describe(R"code(This operator will check if all the elements in a boolean tensor is true. +If not, ValueError exception will be raised in the backend with given error message. +In order to evaluate this operator, one should multiply the origin tensor by the return value +of this operator to force this operator become part of the computation graph, otherwise the check +would not be working under symoblic mode. + +Example: + +loc = np.zeros((2,2)) +scale = np.array(#some_value) +constraint = (scale > 0) +np.random.normal(loc, scale * npx.constraint_check(constraint, 'Scale should be larger than zero')) + +If elements in the scale tensor are all bigger than zero, npx.constraint_check would return +`np.array(True)`, which will not change the value of `scale` when multiplied by. +If some of the elements in the scale tensor violate the constraint, i.e. 
there exists `False` in +the boolean tensor `constraint`, a `ValueError` exception with given message +'Scale should be larger than zero' would be raised. + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"input"}; + }) +.set_attr("FInferShape", ConstraintCheckShape) +.set_attr("FInferType", ConstraintCheckType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", ConstraintCheckForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input", "NDArray-or-Symbol", "Input boolean array") +.add_arguments(ConstraintCheckParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_constraint_check.cu b/src/operator/numpy/np_constraint_check.cu new file mode 100644 index 000000000000..f83fca0e5c33 --- /dev/null +++ b/src/operator/numpy/np_constraint_check.cu @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_constraint_check.cu + * \brief helper function for constraint check + */ + +#include "./np_constraint_check.h" + +namespace mxnet { +namespace op { + +template<> +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output) { + bool tmp = true; + cudaStream_t stream = mshadow::Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync(&tmp, output_blob.dptr(), + sizeof(bool), cudaMemcpyDeviceToHost, + stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + *red_output = static_cast(tmp); +} + +NNVM_REGISTER_OP(_npx_constraint_check) +.set_attr("FCompute", ConstraintCheckForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h new file mode 100644 index 000000000000..80beaa3a0bf5 --- /dev/null +++ b/src/operator/numpy/np_constraint_check.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_constraint_check.h + * \brief helper function for constraint check + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_ +#define MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_ + +#include +#include +#include +#include "./np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +template +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output); + +struct ConstraintCheckParam : public dmlc::Parameter { + std::string msg; + DMLC_DECLARE_PARAMETER(ConstraintCheckParam) { + DMLC_DECLARE_FIELD(msg) + .set_default("Constraint violated.") + .describe("Error message raised when constraint violated"); + } +}; + +template +void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const ConstraintCheckParam& param = + nnvm::get(attrs.parsed); + ReduceAxesComputeImpl(ctx, inputs, req, outputs, + outputs[0].shape_); + std::string msg = param.msg; + bool red_output = true; + GetReduceOutput(ctx.get_stream(), outputs[0], &red_output); + CHECK_EQ(red_output, true) << "ValueError: " << msg; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b25c69385e1e..274f37124262 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3142,6 +3142,47 @@ def _test_bernoulli_exception(prob, logit): assertRaises(ValueError, _test_bernoulli_exception, scaled_prob, None) +@with_seed() +@use_np +def test_npx_constraint_check(): + msg = "condition violated" + class TestConstraintViolatedCheck(HybridBlock): + def __init__(self): + super(TestConstraintViolatedCheck, self).__init__() + + def hybrid_forward(self, F, boolean_tensor): + 
@with_seed()
@use_np
def test_npx_constraint_check():
    """Check that npx.constraint_check raises ValueError on violated
    constraints and multiplies as array(True) (identity) otherwise, in both
    imperative and hybridized (symbolic) mode."""
    msg = "condition violated"

    class TestConstraintViolatedCheck(HybridBlock):
        # Returns only the check result; evaluating it forces the backend check.
        def hybrid_forward(self, F, boolean_tensor):
            return F.npx.constraint_check(boolean_tensor, msg)

    class TestConstraintNotViolatedCheck(HybridBlock):
        # Multiplies the input by the check result; array(True) must be a no-op.
        def hybrid_forward(self, F, input, boolean_tensor):
            return input * F.npx.constraint_check(boolean_tensor, msg)

    def raiseFunc(block):
        # asnumpy() forces evaluation so the backend check actually runs.
        def executor(boolean_tensor):
            block(boolean_tensor).asnumpy()
        return executor

    shapes = [(1,), (2, 3), 6, (7, 8)]

    # All-false inputs must raise, hybridized or not.
    for shape, hybridize in itertools.product(shapes, [True, False]):
        test_constraint = TestConstraintViolatedCheck()
        if hybridize:
            test_constraint.hybridize()
        assertRaises(ValueError, raiseFunc(test_constraint),
                     np.zeros(shape, dtype='bool'))

    # All-true inputs must leave the multiplied tensor unchanged.
    for shape, hybridize in itertools.product(shapes, [True, False]):
        test_constraint = TestConstraintNotViolatedCheck()
        if hybridize:
            test_constraint.hybridize()
        input_tensor = np.random.normal(size=shape)
        out = test_constraint(input_tensor, np.ones(shape, dtype='bool'))
        assert (input_tensor.asnumpy() == out.asnumpy()).all()