diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index b2a50f30af4e..b8a362644fc8 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1524,12 +1524,15 @@ nightly_scala_demo_test_cpu() {
     bash bin/run_im.sh
 }
 
-nightly_estimator() {
+nightly_estimator_numpy_global_flag() {
     set -ex
-    cd /work/mxnet/tests/nightly/estimator
+    cd /work/mxnet/tests/nightly
     export PYTHONPATH=/work/mxnet/python/
-    nosetests test_estimator_cnn.py
-    nosetests test_sentiment_rnn.py
+    nosetests estimator/test_estimator_cnn.py
+    nosetests estimator/test_sentiment_rnn.py
+    # test the global numpy shape flag separately; it must not
+    # run in parallel with other tests
+    nosetests test_global_numpy_shape.py
 }
 
 # For testing PRs
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 242d5a4abd36..177ec5d40146 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
 MXNET_DLL int MXIsNumpyShape(bool* curr);
 /*!
  * \brief set numpy compatibility switch
- * \param is_np_shape 1 when numpy shape semantics is on, 0 when off
+ * \param is_np_shape 1 to turn numpy shape semantics on for the current
+ *        thread only, 2 to turn it on globally, and 0 to turn it off
  * \param prev returns the previous status before this set
 * \return 0 when success, -1 when failure happens
 */
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index c588565012a5..18f6424e54f7 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,17 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+  /*! \brief There are three numpy shape flag states, ordered by priority.
+   * GlobalOn
+   *   turns the numpy shape flag on globally; this implies the
+   *   thread-local flag, so the flag is visible in every thread.
+   * ThreadLocalOn
+   *   turns the numpy shape flag on only for the current thread;
+   *   it is not visible in other threads.
+   * Off
+   *   turns the numpy shape flag off globally.
+   */
+  enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -97,14 +108,30 @@ class Imperative {
     is_recording_ = is_recording;
     return old;
   }
-  /*! brief whether numpy compatibility is on. */
+  /*! \brief whether numpy compatibility is on. */
   bool is_np_shape() const {
-    return is_np_shape_;
+    if (is_np_shape_global_) {
+      return true;
+    }
+    return is_np_shape_thread_local_;
   }
-  /*! brief turn on or turn off numpy compatibility switch. */
-  bool set_is_np_shape(bool is_np_shape) {
-    bool old = is_np_shape_;
-    is_np_shape_ = is_np_shape;
+  /*! \brief specify numpy compatibility: off, thread-local on, or global on. */
+  bool set_is_np_shape(int is_np_shape) {
+    NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
+    bool old = this->is_np_shape();
+    switch (flag) {
+      case GlobalOn:
+        is_np_shape_global_ = true;
+        is_np_shape_thread_local_ = true;
+        break;
+      case ThreadLocalOn:
+        is_np_shape_thread_local_ = true;
+        break;
+      case Off:
+        is_np_shape_global_ = false;
+        is_np_shape_thread_local_ = false;
+        break;
+    }
     return old;
   }
   /*! \brief to record operator, return corresponding node. */
@@ -177,14 +204,15 @@ class Imperative {
   static thread_local bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static thread_local bool is_np_shape_;
+  static thread_local bool is_np_shape_thread_local_;
 #else
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static MX_THREAD_LOCAL bool is_np_shape_;
+  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
 #endif
+  bool is_np_shape_global_{false};
   /*! \brief node count used for naming */
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 5863f08762e6..b80e17c18071 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {
 
 int MXSetIsNumpyShape(int is_np_shape, int* prev) {
   API_BEGIN();
-  *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
+  *prev = Imperative::Get()->set_is_np_shape(is_np_shape);
   API_END();
 }
 
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index b2ffd1096b2b..b3924cc4d79e 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -25,11 +25,11 @@ namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
 thread_local bool Imperative::is_train_ = false;
 thread_local bool Imperative::is_recording_ = false;
-thread_local bool Imperative::is_np_shape_ = false;
+thread_local bool Imperative::is_np_shape_thread_local_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_ = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
-MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
 #endif
 
 Imperative* Imperative::Get() {
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index 5158274010b3..1bef1dbcf57c 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -131,7 +131,7 @@ core_logic: {
     node(NODE_LINUX_GPU) {
       ws('workspace/estimator-test-gpu') {
         utils.unpack_and_init('gpu', mx_lib)
-        utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+        utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_numpy_global_flag', true)
       }
     }
   }
diff --git a/tests/nightly/test_global_numpy_shape.py b/tests/nightly/test_global_numpy_shape.py
new file mode 100644
index 000000000000..01cfe261e97f
--- /dev/null
+++ b/tests/nightly/test_global_numpy_shape.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import threading
+from mxnet.util import set_np_shape
+from mxnet.test_utils import assert_almost_equal
+
+# This test must run in an isolated environment, with no other
+# tests running in parallel, because the global flag affects
+# every thread in the process.
+def test_np_global_shape():
+    set_np_shape(2)
+    data = []
+
+    def f():
+        # scalar (zero-dim)
+        data.append(mx.np.ones(shape=()))
+        # zero-size
+        data.append(mx.np.ones(shape=(0, 1, 2)))
+
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+
+    assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
+    assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
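
Reviewer note, not part of the patch: a minimal sketch of how the three flag states (0 = Off, 1 = ThreadLocalOn, 2 = GlobalOn) behave across threads. It assumes `mx.is_np_shape()` (from `mxnet.util`, exported at the package level) reads back the merged flag the same way `Imperative::is_np_shape()` does; `set_np_shape` and `mx.np.ones` are taken from the test above.

import threading

import mxnet as mx
from mxnet.util import set_np_shape

seen = []

def probe():
    # Record what a freshly spawned thread observes.
    seen.append(mx.is_np_shape())

# 1 = ThreadLocalOn: only the calling thread sees the flag,
# so zero-dim shapes work here but not in a new worker thread.
set_np_shape(1)
assert mx.np.ones(shape=()).shape == ()
t = threading.Thread(target=probe)
t.start()
t.join()
assert seen == [False]  # thread-local flag is invisible to the worker

# 2 = GlobalOn: every thread, including new ones, sees the flag.
set_np_shape(2)
t = threading.Thread(target=probe)
t.start()
t.join()
assert seen == [False, True]

# 0 = Off: clears both the global and the thread-local state.
set_np_shape(0)
assert not mx.is_np_shape()

This mirrors the priority order encoded in `Imperative::is_np_shape()`: the global flag, when set, overrides the thread-local one, which is what tests/nightly/test_global_numpy_shape.py verifies end to end.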