diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 242d5a4abd36..177ec5d40146 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr); MXNET_DLL int MXIsNumpyShape(bool* curr); /*! * \brief set numpy compatibility switch - * \param is_np_shape 1 when numpy shape semantics is on, 0 when off + * \param is_np_shape 1 when numpy shape semantics is thread local on, + * 2 when numpy shape semantics is global on and 0 when off * \param prev returns the previous status before this set * \return 0 when success, -1 when failure happens */ diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index c588565012a5..18f6424e54f7 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -35,6 +35,17 @@ #include "./ndarray.h" namespace mxnet { + /*! \brief there are three numpy shape flags based on priority. + * GlobalOn + * turn on numpy shape flag globally, it includes thread local. + * The flag can be seen in any thread. + * ThreadLocalOn + * only turn on thread local numpy shape flag, it cannot be seen + * in other threads. + * Off + * turn off numpy shape flag globally. + * */ + enum NumpyShape{Off, ThreadLocalOn, GlobalOn}; /*! \brief runtime functions for NDArray */ class Imperative { public: @@ -97,14 +108,30 @@ class Imperative { is_recording_ = is_recording; return old; } - /*! brief whether numpy compatibility is on. */ + /*! \brief whether numpy compatibility is on. */ bool is_np_shape() const { - return is_np_shape_; + if (is_np_shape_global_) { + return true; + } + return is_np_shape_thread_local_; } - /*! brief turn on or turn off numpy compatibility switch. */ - bool set_is_np_shape(bool is_np_shape) { - bool old = is_np_shape_; - is_np_shape_ = is_np_shape; + /*! \brief specify numpy compatibility off, thread local on or global on. */ + bool set_is_np_shape(int is_np_shape) { + NumpyShape flag = static_cast(is_np_shape); + bool old = this->is_np_shape(); + switch (flag) { + case GlobalOn: + is_np_shape_global_ = true; + is_np_shape_thread_local_ = true; + break; + case ThreadLocalOn: + is_np_shape_thread_local_ = true; + break; + case Off: + is_np_shape_global_ = false; + is_np_shape_thread_local_ = false; + break; + } return old; } /*! \brief to record operator, return corresponding node. */ @@ -177,14 +204,15 @@ class Imperative { static thread_local bool is_recording_; // TOOD(junwu): Added numpy compatibility switch for backward compatibility. // Delete it in the next major release. - static thread_local bool is_np_shape_; + static thread_local bool is_np_shape_thread_local_; #else static MX_THREAD_LOCAL bool is_train_; static MX_THREAD_LOCAL bool is_recording_; // TOOD(junwu): Added numpy compatibility switch for backward compatibility. // Delete it in the next major release. - static MX_THREAD_LOCAL bool is_np_shape_; + static MX_THREAD_LOCAL bool is_np_shape_thread_local_; #endif + bool is_np_shape_global_{false}; /*! \brief node count used for naming */ std::atomic node_count_{0}; /*! \brief variable count used for naming */ diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 5863f08762e6..b80e17c18071 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) { int MXSetIsNumpyShape(int is_np_shape, int* prev) { API_BEGIN(); - *prev = Imperative::Get()->set_is_np_shape(static_cast(is_np_shape)); + *prev = Imperative::Get()->set_is_np_shape(is_np_shape); API_END(); } diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index b2ffd1096b2b..b3924cc4d79e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -25,11 +25,11 @@ namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL thread_local bool Imperative::is_train_ = false; thread_local bool Imperative::is_recording_ = false; -thread_local bool Imperative::is_np_shape_ = false; +thread_local bool Imperative::is_np_shape_thread_local_ = false; #else MX_THREAD_LOCAL bool Imperative::is_train_ = false; MX_THREAD_LOCAL bool Imperative::is_recording_ = false; -MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false; +MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false; #endif Imperative* Imperative::Get() { diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index ee56ba780a95..f0e3c660181b 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -16,14 +16,16 @@ # under the License. import threading +import numpy as np import mxnet as mx from mxnet import context, attribute, name from mxnet.gluon import block from mxnet.context import Context from mxnet.attribute import AttrScope from mxnet.name import NameManager -from mxnet.test_utils import set_default_context -from mxnet.util import _NumpyArrayScope +from mxnet.test_utils import assert_almost_equal, set_default_context +from mxnet.util import _NumpyArrayScope, set_np_shape + def test_context(): ctx_list = [] @@ -199,6 +201,26 @@ def g(): assert status[0], "Spawned thread didn't set status correctly" +def test_np_global_shape(): + set_np_shape(2) + data = [] + + def f(): + # scalar + data.append(mx.np.ones(shape=())) + # zero-dim + data.append(mx.np.ones(shape=(0, 1, 2))) + try: + thread = threading.Thread(target=f) + thread.start() + thread.join() + + assert_almost_equal(data[0].asnumpy(), np.ones(shape=())) + assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2))) + finally: + set_np_shape(0) + + if __name__ == '__main__': import nose nose.runmodule()