From bdfc4329b638915fb0ac59d96dcb2ae52e243f18 Mon Sep 17 00:00:00 2001 From: Haohuan Wang Date: Tue, 23 Jul 2019 22:17:42 -0700 Subject: [PATCH] fix LinearRegressionOutput with empty label (#15620) --- cpp-package/example/CMakeLists.txt | 4 ++ cpp-package/example/test_regress_label.cpp | 56 ++++++++++++++++++++++ cpp-package/tests/ci_test.sh | 3 ++ src/operator/regression_output-inl.h | 3 +- 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 cpp-package/example/test_regress_label.cpp diff --git a/cpp-package/example/CMakeLists.txt b/cpp-package/example/CMakeLists.txt index b4cea68fbd05..643a92d9a3bc 100644 --- a/cpp-package/example/CMakeLists.txt +++ b/cpp-package/example/CMakeLists.txt @@ -27,6 +27,10 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../include) +add_executable(test_regress_label test_regress_label.cpp ${CPP_PACKAGE_HAEDERS}) +target_link_libraries(test_regress_label ${CPP_EXAMPLE_LIBS}) +add_dependencies(test_regress_label ${CPPEX_DEPS}) + add_executable(lenet lenet.cpp ${CPP_PACKAGE_HEADERS}) target_link_libraries(lenet ${CPP_EXAMPLE_LIBS}) add_dependencies(lenet ${CPPEX_DEPS}) diff --git a/cpp-package/example/test_regress_label.cpp b/cpp-package/example/test_regress_label.cpp new file mode 100644 index 000000000000..8d1d6444b138 --- /dev/null +++ b/cpp-package/example/test_regress_label.cpp @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This file is used for testing LinearRegressionOutput can + * still bind if label is not provided + */ + +#include +#include +#include +#include "dmlc/logging.h" +#include "mxnet-cpp/MxNetCpp.h" + +using namespace mxnet::cpp; + +int main() { + LOG(INFO) << "Running LinearRegressionOutput symbol testing, " + "executor should be able to bind without label."; + Symbol data = Symbol::Variable("data"); + Symbol label = Symbol::Variable("regress_label"); + Symbol symbol = LinearRegressionOutput(data, label); + std::map opReqMap; + for (const auto& iter : symbol.ListArguments()) { + opReqMap[iter] = mxnet::cpp::OpReqType::kNullOp; + } + std::map argMap({ + {"data", NDArray(Shape{1, 3}, Context::cpu(), true)} + }); + + try { + symbol.SimpleBind(Context::cpu(), + argMap, + std::map(), + opReqMap, + std::map()); + } catch (const std::exception& e) { + LOG(ERROR) << "Error binding the symbol: " << MXGetLastError() << " " << e.what(); + throw; + } + return 0; +} diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh index ef7fceacfd6e..7e4cf7d4d945 100755 --- a/cpp-package/tests/ci_test.sh +++ b/cpp-package/tests/ci_test.sh @@ -60,6 +60,9 @@ cp ../../build/cpp-package/example/test_score . cp ../../build/cpp-package/example/test_ndarray_copy . ./test_ndarray_copy +cp ../../build/cpp-package/example/test_regress_label . +./test_regress_label + sh unittests/unit_test_mlp_csv.sh cd inference diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h index ba59937a7152..dcee8027dff0 100644 --- a/src/operator/regression_output-inl.h +++ b/src/operator/regression_output-inl.h @@ -59,7 +59,8 @@ inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs, const mxnet::TShape &dshape = in_attrs->at(0); if (!shape_is_known(dshape)) return false; auto &lshape = (*in_attrs)[1]; - if (lshape.ndim() == 0) { + // if label is not defined, manually build the shape based on dshape + if (lshape.ndim() == -1) { // special treatment for 1D output, to allow 1D label by default. // Think about change convention later if (dshape.ndim() == 2 && dshape[1] == 1) {