From 40125aa452324d00c2473057668a90b9ca6eddbe Mon Sep 17 00:00:00 2001 From: stu1130 Date: Sun, 1 Sep 2019 20:12:16 -0700 Subject: [PATCH] global numpy shape flag --- ci/docker/runtime_functions.sh | 7 +++ ci/jenkins/Jenkins_steps.groovy | 14 ++++++ include/mxnet/c_api.h | 3 +- include/mxnet/imperative.h | 44 ++++++++++++++---- src/c_api/c_api_ndarray.cc | 2 +- src/imperative/imperative.cc | 4 +- tests/python/other/test_global_numpy_shape.py | 45 +++++++++++++++++++ 7 files changed, 107 insertions(+), 12 deletions(-) create mode 100644 tests/python/other/test_global_numpy_shape.py diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index b2a50f30af4e..70e3a6494d2e 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1350,6 +1350,13 @@ test_ubuntu_cpu_python3() { popd } +test_global() { + set -ex + export PYTHONPATH=./python/ + cd /work/mxnet/tests/python/other + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_unittest.xml --verbose test_global_numpy_shape.py +} + # Functions that run the nightly Tests: #Runs Apache RAT Check on MXNet Source for License Headers diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 30db32252e66..2edbbebeadf5 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -1184,6 +1184,20 @@ def test_unix_distributed_kvstore_gpu() { }] } +def test_unix_other_cpu() { + return ['unix-other tests CPU': { + node(NODE_LINUX_CPU) { + ws('workspace/unix-other-tests') { + timeout(time: max_time, unit: 'MINUTES') { + utils.unpack_and_init('cpu', mx_lib) + utils.docker_run('ubuntu_cpu', 'test_global', true) + utils.publish_test_coverage() + } + } + } + }] +} + def test_centos7_python3_cpu() { return ['Python3: CentOS 7 CPU': { node(NODE_LINUX_CPU) { diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 242d5a4abd36..177ec5d40146 100644 --- a/include/mxnet/c_api.h +++ 
b/include/mxnet/c_api.h @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr); MXNET_DLL int MXIsNumpyShape(bool* curr); /*! * \brief set numpy compatibility switch - * \param is_np_shape 1 when numpy shape semantics is on, 0 when off + * \param is_np_shape 1 when numpy shape semantics is thread local on, + * 2 when numpy shape semantics is global on and 0 when off * \param prev returns the previous status before this set * \return 0 when success, -1 when failure happens */ diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index c588565012a5..18f6424e54f7 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -35,6 +35,17 @@ #include "./ndarray.h" namespace mxnet { + /*! \brief there are three numpy shape flags based on priority. + * GlobalOn + * turn on numpy shape flag globally, it includes thread local. + * The flag can be seen in any thread. + * ThreadLocalOn + * only turn on thread local numpy shape flag, it cannot be seen + * in other threads. + * Off + * turn off numpy shape flag globally. + * */ + enum NumpyShape{Off, ThreadLocalOn, GlobalOn}; /*! \brief runtime functions for NDArray */ class Imperative { public: @@ -97,14 +108,30 @@ class Imperative { is_recording_ = is_recording; return old; } - /*! brief whether numpy compatibility is on. */ + /*! \brief whether numpy compatibility is on. */ bool is_np_shape() const { - return is_np_shape_; + if (is_np_shape_global_) { + return true; + } + return is_np_shape_thread_local_; } - /*! brief turn on or turn off numpy compatibility switch. */ - bool set_is_np_shape(bool is_np_shape) { - bool old = is_np_shape_; - is_np_shape_ = is_np_shape; + /*! \brief specify numpy compatibility off, thread local on or global on. 
*/ + bool set_is_np_shape(int is_np_shape) { + NumpyShape flag = static_cast<NumpyShape>(is_np_shape); + bool old = this->is_np_shape(); + switch (flag) { + case GlobalOn: + is_np_shape_global_ = true; + is_np_shape_thread_local_ = true; + break; + case ThreadLocalOn: + is_np_shape_thread_local_ = true; + break; + case Off: + is_np_shape_global_ = false; + is_np_shape_thread_local_ = false; + break; + } return old; } /*! \brief to record operator, return corresponding node. */ @@ -177,14 +204,15 @@ class Imperative { static thread_local bool is_recording_; // TOOD(junwu): Added numpy compatibility switch for backward compatibility. // Delete it in the next major release. - static thread_local bool is_np_shape_; + static thread_local bool is_np_shape_thread_local_; #else static MX_THREAD_LOCAL bool is_train_; static MX_THREAD_LOCAL bool is_recording_; // TOOD(junwu): Added numpy compatibility switch for backward compatibility. // Delete it in the next major release. - static MX_THREAD_LOCAL bool is_np_shape_; + static MX_THREAD_LOCAL bool is_np_shape_thread_local_; #endif + bool is_np_shape_global_{false}; /*! \brief node count used for naming */ std::atomic<uint64_t> node_count_{0}; /*! 
\brief variable count used for naming */ diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 5863f08762e6..b80e17c18071 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) { int MXSetIsNumpyShape(int is_np_shape, int* prev) { API_BEGIN(); - *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape)); + *prev = Imperative::Get()->set_is_np_shape(is_np_shape); API_END(); } diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index b2ffd1096b2b..b3924cc4d79e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -25,11 +25,11 @@ namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL thread_local bool Imperative::is_train_ = false; thread_local bool Imperative::is_recording_ = false; -thread_local bool Imperative::is_np_shape_ = false; +thread_local bool Imperative::is_np_shape_thread_local_ = false; #else MX_THREAD_LOCAL bool Imperative::is_train_ = false; MX_THREAD_LOCAL bool Imperative::is_recording_ = false; -MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false; +MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false; #endif Imperative* Imperative::Get() { diff --git a/tests/python/other/test_global_numpy_shape.py b/tests/python/other/test_global_numpy_shape.py new file mode 100644 index 000000000000..de6e3c82f5f1 --- /dev/null +++ b/tests/python/other/test_global_numpy_shape.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np +import threading +from mxnet.util import set_np_shape +from mxnet.test_utils import assert_almost_equal + + +def test_np_global_shape(): + set_np_shape(2) + data = [] + + def f(): + # scalar + data.append(mx.np.ones(shape=())) + # zero-dim + data.append(mx.np.ones(shape=(0, 1, 2))) + + thread = threading.Thread(target=f) + thread.start() + thread.join() + + assert_almost_equal(data[0].asnumpy(), np.ones(shape=())) + assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2))) + + +if __name__ == '__main__': + import nose + nose.runmodule()