Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Add infra for performing constraint check #17272

Merged
merged 9 commits into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,45 @@ def _np__random_shuffle(x):
pass


def _npx_constraint_check(x, msg):
"""
This operator will check if all the elements in a boolean tensor is true.
If not, ValueError exception will be raised in the backend with given error message.
In order to evaluate this operator, one should multiply the origin tensor by the return value
of this operator to force this operator become part of the computation graph,
otherwise the check would not be working under symoblic mode.

Parameters
----------
x : ndarray
A boolean tensor.
msg : string
The error message in the exception.

Returns
-------
out : ndarray
If all the elements in the input tensor are true,
array(True) will be returned, otherwise ValueError exception would
be raised before anything got returned.

Examples
--------
>>> loc = np.zeros((2,2))
>>> scale = np.array(#some_value)
>>> constraint = (scale > 0)
>>> np.random.normal(loc,
scale * npx.constraint_check(constraint, 'Scale should be larger than zero'))

If elements in the scale tensor are all bigger than zero, npx.constraint_check would return
`np.array(True)`, which will not change the value of `scale` when multiplied by.
If some of the elements in the scale tensor violate the constraint,
i.e. there exists `False` in the boolean tensor `constraint`,
a `ValueError` exception with given message 'Scale should be larger than zero' would be raised.
"""
pass


def _npx_reshape(a, newshape, reverse=False, order='C'):
"""
Gives a new shape to an array without changing its data.
Expand Down
99 changes: 99 additions & 0 deletions src/operator/numpy/np_constraint_check.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_constraint_check.cc
* \brief helper function for constraint check
*/

#include "./np_constraint_check.h"

namespace mxnet {
namespace op {

template<>
void GetReduceOutput<cpu>(mshadow::Stream<cpu> *s, const TBlob &output_blob, bool *red_output) {
*red_output = static_cast<bool>(*output_blob.dptr<bool>());
}

inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
return true;
}

inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK(in_attrs->at(0) == mshadow::kBool);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

DMLC_REGISTER_PARAMETER(ConstraintCheckParam);

NNVM_REGISTER_OP(_npx_constraint_check)
xidulu marked this conversation as resolved.
Show resolved Hide resolved
.describe(R"code(This operator will check if all the elements in a boolean tensor is true.
If not, ValueError exception will be raised in the backend with given error message.
In order to evaluate this operator, one should multiply the origin tensor by the return value
of this operator to force this operator become part of the computation graph, otherwise the check
would not be working under symoblic mode.

Example:

loc = np.zeros((2,2))
scale = np.array(#some_value)
constraint = (scale > 0)
np.random.normal(loc, scale * npx.constraint_check(constraint, 'Scale should be larger than zero'))

If elements in the scale tensor are all bigger than zero, npx.constraint_check would return
`np.array(True)`, which will not change the value of `scale` when multiplied by.
If some of the elements in the scale tensor violate the constraint, i.e. there exists `False` in
the boolean tensor `constraint`, a `ValueError` exception with given message
'Scale should be larger than zero' would be raised.

)code" ADD_FILELINE)
.set_attr_parser(ParamParser<ConstraintCheckParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"input"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ConstraintCheckShape)
.set_attr<nnvm::FInferType>("FInferType", ConstraintCheckType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", ConstraintCheckForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("input", "NDArray-or-Symbol", "Input boolean array")
.add_arguments(ConstraintCheckParam::__FIELDS__());

} // namespace op
} // namespace mxnet
45 changes: 45 additions & 0 deletions src/operator/numpy/np_constraint_check.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_constraint_check.cu
* \brief helper function for constraint check
*/

#include "./np_constraint_check.h"

namespace mxnet {
namespace op {

template<>
void GetReduceOutput<gpu>(mshadow::Stream<gpu> *s, const TBlob &output_blob, bool *red_output) {
bool tmp = true;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
CUDA_CALL(cudaMemcpyAsync(&tmp, output_blob.dptr<bool>(),
sizeof(bool), cudaMemcpyDeviceToHost,
stream));
CUDA_CALL(cudaStreamSynchronize(stream));
*red_output = static_cast<bool>(tmp);
}

NNVM_REGISTER_OP(_npx_constraint_check)
.set_attr<FCompute>("FCompute<gpu>", ConstraintCheckForward<gpu>);

} // namespace op
} // namespace mxnet
71 changes: 71 additions & 0 deletions src/operator/numpy/np_constraint_check.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_constraint_check.h
* \brief helper function for constraint check
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_
#define MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_

#include <algorithm>
#include <string>
#include <vector>
#include "./np_broadcast_reduce_op.h"

namespace mxnet {
namespace op {

template<typename xpu>
void GetReduceOutput(mshadow::Stream<xpu> *s, const TBlob &output_blob, bool *red_output);

struct ConstraintCheckParam : public dmlc::Parameter<ConstraintCheckParam> {
std::string msg;
DMLC_DECLARE_PARAMETER(ConstraintCheckParam) {
DMLC_DECLARE_FIELD(msg)
.set_default("Constraint violated.")
.describe("Error message raised when constraint violated");
}
};

template <typename xpu>
void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const ConstraintCheckParam& param =
nnvm::get<ConstraintCheckParam>(attrs.parsed);
ReduceAxesComputeImpl<xpu, mshadow_op::product, false, false,
op::mshadow_op::identity>(ctx, inputs, req, outputs,
outputs[0].shape_);
std::string msg = param.msg;
bool red_output = true;
GetReduceOutput(ctx.get_stream<xpu>(), outputs[0], &red_output);
CHECK_EQ(red_output, true) << "ValueError: " << msg;
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_
41 changes: 41 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,47 @@ def _test_bernoulli_exception(prob, logit):
assertRaises(ValueError, _test_bernoulli_exception, scaled_prob, None)


@with_seed()
@use_np
def test_npx_constraint_check():
msg = "condition violated"
class TestConstraintViolatedCheck(HybridBlock):
def __init__(self):
super(TestConstraintViolatedCheck, self).__init__()

def hybrid_forward(self, F, boolean_tensor):
return F.npx.constraint_check(boolean_tensor, msg)

class TestConstraintNotViolatedCheck(HybridBlock):
def __init__(self):
super(TestConstraintNotViolatedCheck, self).__init__()

def hybrid_forward(self, F, input, boolean_tensor):
return input * F.npx.constraint_check(boolean_tensor, msg)

def raiseFunc(block):
def executor(boolean_tensor):
out = block(boolean_tensor).asnumpy()
return executor

shapes = [(1,), (2, 3), 6, (7, 8)]

expect_success_output = np.array(True)
for shape, hybridize in itertools.product(shapes, [True, False]):
test_constraint = TestConstraintViolatedCheck()
if hybridize:
test_constraint.hybridize()
assertRaises(ValueError, raiseFunc(test_constraint), np.zeros(shape, dtype='bool'))

for shape, hybridize in itertools.product(shapes, [True, False]):
test_constraint = TestConstraintNotViolatedCheck()
if hybridize:
test_constraint.hybridize()
input_tensor = np.random.normal(size=shape)
out = test_constraint(input_tensor, np.ones(shape, dtype='bool'))
assert (input_tensor.asnumpy() == out.asnumpy()).all()


@with_seed()
@use_np
def test_npx_special_unary_func():
Expand Down