From 440da13038d892b71c38e3af50b839efde017df1 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 16 Nov 2019 16:03:07 +0000 Subject: [PATCH 1/7] exception not working --- src/operator/numpy/np_constraint_check.cc | 60 +++++++++++++++++++++++ src/operator/numpy/np_constraint_check.cu | 21 ++++++++ src/operator/numpy/np_constraint_check.h | 47 ++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 src/operator/numpy/np_constraint_check.cc create mode 100644 src/operator/numpy/np_constraint_check.cu create mode 100644 src/operator/numpy/np_constraint_check.h diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc new file mode 100644 index 000000000000..85e33b9dcd3d --- /dev/null +++ b/src/operator/numpy/np_constraint_check.cc @@ -0,0 +1,60 @@ +#include "./np_constraint_check.h" + +namespace mxnet { +namespace op { + +template<> +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output) { + *red_output = static_cast(*output_blob.dptr()); +} + +inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + // Only 1-D support is supported. + CHECK_EQ(in_attrs->at(0).ndim(), 1U) << "Only 1-D input is supported."; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + return true; +} + +inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK(in_attrs->at(0) == mshadow::kBool); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool); + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(ConstraintCheckParam); + +NNVM_REGISTER_OP(_npx_constraint_check) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"input"}; + }) +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + return 0; +}) +.set_attr("FInferShape", ConstraintCheckShape) +.set_attr("FInferType", ConstraintCheckType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", ConstraintCheckForward) +.add_argument("input", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(ConstraintCheckParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_constraint_check.cu b/src/operator/numpy/np_constraint_check.cu new file mode 100644 index 000000000000..79b522fc28de --- /dev/null +++ b/src/operator/numpy/np_constraint_check.cu @@ -0,0 +1,21 @@ +#include "./np_constraint_check.h" + +namespace mxnet { +namespace op { + +template<> +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output) { + bool tmp = true; + cudaStream_t stream = mshadow::Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync(&tmp, output_blob.dptr(), + sizeof(bool), cudaMemcpyDeviceToHost, + stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + *red_output = static_cast(tmp); +} + +NNVM_REGISTER_OP(_npx_constraint_check) +.set_attr("FCompute", ConstraintCheckForward); + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h new file mode 100644 index 000000000000..ed2cf81f6f69 
--- /dev/null +++ b/src/operator/numpy/np_constraint_check.h @@ -0,0 +1,47 @@ +#include +#include +#include +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "./np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +template +void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, bool *red_output); + +struct ConstraintCheckParam : public dmlc::Parameter { + std::string msg; + DMLC_DECLARE_PARAMETER(ConstraintCheckParam) { + DMLC_DECLARE_FIELD(msg) + .set_default("Constraint violated!") + .describe("Error message raised when constraint violated"); + } +}; + +template +void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + // ??? + CHECK(false); + const ConstraintCheckParam& param = + nnvm::get(attrs.parsed); + ReduceAxesComputeImpl(ctx, inputs, req, outputs, + outputs[0].shape_); + std::string msg = param.msg; + bool red_output = true; + GetReduceOutput(ctx.get_stream(), outputs[0], &red_output); + CHECK_EQ(red_output, true) << msg; +} + +} // namespace op +} // namespace mxnet \ No newline at end of file From 1a88ec1bd51bca8f83759daed2f341583eaffe3b Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 18 Nov 2019 03:24:58 +0000 Subject: [PATCH 2/7] code polished --- src/operator/numpy/np_constraint_check.cc | 35 +++++++++++++++++---- src/operator/numpy/np_constraint_check.cu | 26 +++++++++++++++- src/operator/numpy/np_constraint_check.h | 38 ++++++++++++++++++----- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc index 85e33b9dcd3d..65d2d1051a94 100644 --- a/src/operator/numpy/np_constraint_check.cc +++ b/src/operator/numpy/np_constraint_check.cc @@ -1,3 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_constraint_check.cc + * \brief helper function for constraint check + */ + #include "./np_constraint_check.h" namespace mxnet { @@ -17,7 +41,7 @@ inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, return false; } // Only 1-D support is supported. - CHECK_EQ(in_attrs->at(0).ndim(), 1U) << "Only 1-D input is supported."; + // CHECK_EQ(in_attrs->at(0).ndim(), 1U) << "Only 1-D input is supported."; SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) return true; } @@ -35,6 +59,9 @@ inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs, DMLC_REGISTER_PARAMETER(ConstraintCheckParam); NNVM_REGISTER_OP(_npx_constraint_check) +.describe(R"code(Check if all the elements in a 1-D boolean array is true. 
+If not, exception will be raised with given error message. +)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) @@ -42,10 +69,6 @@ NNVM_REGISTER_OP(_npx_constraint_check) [](const NodeAttrs& attrs) { return std::vector{"input"}; }) -.set_attr("FNumVisibleOutputs", - [](const NodeAttrs& attrs) { - return 0; -}) .set_attr("FInferShape", ConstraintCheckShape) .set_attr("FInferType", ConstraintCheckType) .set_attr("FResourceRequest", @@ -53,7 +76,7 @@ NNVM_REGISTER_OP(_npx_constraint_check) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", ConstraintCheckForward) -.add_argument("input", "NDArray-or-Symbol", "Input ndarray") +.add_argument("input", "NDArray-or-Symbol", "Input boolean array") .add_arguments(ConstraintCheckParam::__FIELDS__()); } // namespace op diff --git a/src/operator/numpy/np_constraint_check.cu b/src/operator/numpy/np_constraint_check.cu index 79b522fc28de..f83fca0e5c33 100644 --- a/src/operator/numpy/np_constraint_check.cu +++ b/src/operator/numpy/np_constraint_check.cu @@ -1,3 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_constraint_check.cu + * \brief helper function for constraint check + */ + #include "./np_constraint_check.h" namespace mxnet { @@ -18,4 +42,4 @@ NNVM_REGISTER_OP(_npx_constraint_check) .set_attr("FCompute", ConstraintCheckForward); } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h index ed2cf81f6f69..ab955f2b0409 100644 --- a/src/operator/numpy/np_constraint_check.h +++ b/src/operator/numpy/np_constraint_check.h @@ -1,9 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file np_constraint_check.h + * \brief helper function for constraint check + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_ +#define MXNET_OPERATOR_NUMPY_NP_CONSTRAINT_CHECK_H_ + #include #include #include -#include "../mshadow_op.h" -#include "../mxnet_op.h" -#include "../operator_common.h" #include "./np_broadcast_reduce_op.h" namespace mxnet { @@ -16,7 +40,7 @@ struct ConstraintCheckParam : public dmlc::Parameter { std::string msg; DMLC_DECLARE_PARAMETER(ConstraintCheckParam) { DMLC_DECLARE_FIELD(msg) - .set_default("Constraint violated!") + .set_default("Constraint violated.") .describe("Error message raised when constraint violated"); } }; @@ -30,8 +54,6 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, using namespace mxnet_op; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - // ??? - CHECK(false); const ConstraintCheckParam& param = nnvm::get(attrs.parsed); ReduceAxesComputeImpl Date: Thu, 21 Nov 2019 04:20:17 +0000 Subject: [PATCH 3/7] add zero-grad node --- src/operator/numpy/np_constraint_check.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc index 65d2d1051a94..3b582bc3f379 100644 --- a/src/operator/numpy/np_constraint_check.cc +++ b/src/operator/numpy/np_constraint_check.cc @@ -76,6 +76,7 @@ If not, exception will be raised with given error message. return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", ConstraintCheckForward) +.set_attr("FGradient", MakeZeroGradNodes) .add_argument("input", "NDArray-or-Symbol", "Input boolean array") .add_arguments(ConstraintCheckParam::__FIELDS__()); From 6bf3ace2ebeafd20508a1b200519b85d087af849 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 11 Jan 2020 05:03:21 +0000 Subject: [PATCH 4/7] backend polished --- src/operator/numpy/np_constraint_check.cc | 2 -- src/operator/numpy/np_constraint_check.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc index 3b582bc3f379..3cd05740a905 100644 --- a/src/operator/numpy/np_constraint_check.cc +++ b/src/operator/numpy/np_constraint_check.cc @@ -40,8 +40,6 @@ inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, if (!shape_is_known(in_attrs->at(0))) { return false; } - // Only 1-D support is supported. 
- // CHECK_EQ(in_attrs->at(0).ndim(), 1U) << "Only 1-D input is supported."; SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) return true; } diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h index ab955f2b0409..917c2fad9aba 100644 --- a/src/operator/numpy/np_constraint_check.h +++ b/src/operator/numpy/np_constraint_check.h @@ -62,7 +62,7 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, std::string msg = param.msg; bool red_output = true; GetReduceOutput(ctx.get_stream(), outputs[0], &red_output); - CHECK_EQ(red_output, true) << msg; + CHECK_EQ(red_output, true) << "ValueError: " << msg; } } // namespace op From 53147f2a48131d0f64cabae4e9b95be58f5e2d05 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 11 Jan 2020 07:03:39 +0000 Subject: [PATCH 5/7] add tests --- tests/python/unittest/test_numpy_op.py | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b25c69385e1e..db975b1600ce 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3142,6 +3142,47 @@ def _test_bernoulli_exception(prob, logit): assertRaises(ValueError, _test_bernoulli_exception, scaled_prob, None) +@with_seed() +@use_np +def test_npx_constraint_check(): + msg = "condition violated" + class TestConstraintViolatedCheck(HybridBlock): + def __init__(self): + super(TestConstraintViolatedCheck, self).__init__() + + def hybrid_forward(self, F, boolean_tensor): + return F.npx.constraint_check(boolean_tensor, msg) + + class TestConstraintNotViolatedCheck(HybridBlock): + def __init__(self): + super(TestConstraintNotViolatedCheck, self).__init__() + + def hybrid_forward(self, F, input, boolean_tensor): + return input * F.npx.constraint_check(boolean_tensor, msg) + + def raiseFunc(block): + def executor(boolean_tensor): + out = block(boolean_tensor).asnumpy() + return executor + + shapes = [(1,), (2, 3), 6, (7, 8)] + + expect_success_output = np.array(True) + for shape, hybridize in itertools.product(shapes, [True, False]): + test_constraint = TestConstraintViolatedCheck() + if hybridize: + test_constraint.hybridize() + assertRaises(ValueError, raiseFunc(test_constraint), (np.ones(shape) < 0)) + + for shape, hybridize in itertools.product(shapes, [True, False]): + test_constraint = TestConstraintNotViolatedCheck() + if hybridize: + test_constraint.hybridize() + input_tensor = np.random.normal(size=shape) + out = test_constraint(input_tensor, (np.ones(shape) > 0)) + assert (input_tensor.asnumpy() == out.asnumpy()).all() + + @with_seed() @use_np def test_npx_special_unary_func(): From b361557a83f7cba24da55aeb0d44075a916882a2 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 11 Jan 2020 12:05:41 +0000 Subject: [PATCH 6/7] add doc, modify tests to pass CI with tvm OP --- src/operator/numpy/np_constraint_check.cc | 21 +++++++++++++++++++-- tests/python/unittest/test_numpy_op.py | 4 ++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc index 3cd05740a905..a84171e14985 100644 --- a/src/operator/numpy/np_constraint_check.cc +++ b/src/operator/numpy/np_constraint_check.cc @@ -57,8 +57,25 @@ inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs, DMLC_REGISTER_PARAMETER(ConstraintCheckParam); NNVM_REGISTER_OP(_npx_constraint_check) -.describe(R"code(Check if all the elements in a 1-D boolean array is true. 
-If not, exception will be raised with given error message. +.describe(R"code(This operator will check if all the elements in a boolean tensor is true. +If not, ValueError exception will be raised in the backend with given error message. +In order to evaluate this operator, one should multiply the origin tensor by the return value +of this operator to force this operator become part of the computation graph, otherwise the check +would not be working under symoblic mode. + +Example: + +loc = np.zeros((2,2)) +scale = np.array(#some_value) +constraint = (scale > 0) +np.random.normal(loc, scale * npx.constraint_check(constraint, 'Scale should be larger than zero')) + +If elements in the scale tensor are all bigger than zero, npx.constraint_check would return +`np.array(True)`, which will not change the value of `scale` when multiplied by. +If some of the elements in the scale tensor violate the constraint, i.e. there exists `False` in +the boolean tensor `constraint`, a `ValueError` exception with given message +'Scale should be larger than zero' would be raised. + )code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index db975b1600ce..274f37124262 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3172,14 +3172,14 @@ def executor(boolean_tensor): test_constraint = TestConstraintViolatedCheck() if hybridize: test_constraint.hybridize() - assertRaises(ValueError, raiseFunc(test_constraint), (np.ones(shape) < 0)) + assertRaises(ValueError, raiseFunc(test_constraint), np.zeros(shape, dtype='bool')) for shape, hybridize in itertools.product(shapes, [True, False]): test_constraint = TestConstraintNotViolatedCheck() if hybridize: test_constraint.hybridize() input_tensor = np.random.normal(size=shape) - out = test_constraint(input_tensor, (np.ones(shape) > 0)) + out = test_constraint(input_tensor, np.ones(shape, dtype='bool')) assert (input_tensor.asnumpy() == out.asnumpy()).all() From bd1f295deb234cdfdbe226eddf2479d73fafea78 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sun, 12 Jan 2020 09:00:31 +0000 Subject: [PATCH 7/7] fix doc and style --- python/mxnet/_numpy_op_doc.py | 39 +++++++++++++++++++++++ src/operator/numpy/np_constraint_check.cc | 8 ++--- src/operator/numpy/np_constraint_check.h | 6 ++-- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 190150eb5a0a..65f1f2f9407c 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1132,6 +1132,45 @@ def _np__random_shuffle(x): pass +def _npx_constraint_check(x, msg): + """ + This operator will check if all the elements in a boolean tensor is true. + If not, ValueError exception will be raised in the backend with given error message. + In order to evaluate this operator, one should multiply the origin tensor by the return value + of this operator to force this operator become part of the computation graph, + otherwise the check would not be working under symoblic mode. + + Parameters + ---------- + x : ndarray + A boolean tensor. + msg : string + The error message in the exception. + + Returns + ------- + out : ndarray + If all the elements in the input tensor are true, + array(True) will be returned, otherwise ValueError exception would + be raised before anything got returned. 
+ + Examples + -------- + >>> loc = np.zeros((2,2)) + >>> scale = np.array(#some_value) + >>> constraint = (scale > 0) + >>> np.random.normal(loc, + scale * npx.constraint_check(constraint, 'Scale should be larger than zero')) + + If elements in the scale tensor are all bigger than zero, npx.constraint_check would return + `np.array(True)`, which will not change the value of `scale` when multiplied by. + If some of the elements in the scale tensor violate the constraint, + i.e. there exists `False` in the boolean tensor `constraint`, + a `ValueError` exception with given message 'Scale should be larger than zero' would be raised. + """ + pass + + def _npx_reshape(a, newshape, reverse=False, order='C'): """ Gives a new shape to an array without changing its data. diff --git a/src/operator/numpy/np_constraint_check.cc b/src/operator/numpy/np_constraint_check.cc index a84171e14985..450f4ddb392b 100644 --- a/src/operator/numpy/np_constraint_check.cc +++ b/src/operator/numpy/np_constraint_check.cc @@ -33,8 +33,8 @@ void GetReduceOutput(mshadow::Stream *s, const TBlob &output_blob, boo } inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); if (!shape_is_known(in_attrs->at(0))) { @@ -45,8 +45,8 @@ inline bool ConstraintCheckShape(const nnvm::NodeAttrs& attrs, } inline bool ConstraintCheckType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); CHECK(in_attrs->at(0) == mshadow::kBool); diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h index 917c2fad9aba..80beaa3a0bf5 100644 --- a/src/operator/numpy/np_constraint_check.h +++ b/src/operator/numpy/np_constraint_check.h @@ -47,9 +47,9 @@ struct ConstraintCheckParam : public dmlc::Parameter { template void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; CHECK_EQ(inputs.size(), 1U);
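For reference, below is a runnable version of the usage example sketched in the operator documentation added by this series. The array assigned to scale is an illustrative assumption (the patches leave it as the placeholder #some_value); the npx.constraint_check call, the numpy-compatible np/npx namespaces, and np.random.normal all come from the patches themselves.

from mxnet import np, npx
npx.set_np()  # enable numpy-compatible semantics for the np/npx operators

loc = np.zeros((2, 2))
scale = np.array([[1.0, 2.0], [3.0, -1.0]])  # assumed values; the -1.0 violates the constraint
constraint = (scale > 0)

# Multiplying by the check's output (np.array(True) when every element is true)
# forces the operator into the computation graph, as the operator doc explains.
# Because one element of `scale` is not positive, this call should raise a
# ValueError carrying the message below instead of returning samples.
samples = np.random.normal(
    loc,
    scale * npx.constraint_check(constraint, 'Scale should be larger than zero'))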