diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 242d5a4abd36..177ec5d40146 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr); MXNET_DLL int MXIsNumpyShape(bool* curr); /*! * \brief set numpy compatibility switch - * \param is_np_shape 1 when numpy shape semantics is on, 0 when off + * \param is_np_shape 1 when numpy shape semantics is thread local on, + * 2 when numpy shape semantics is global on and 0 when off * \param prev returns the previous status before this set * \return 0 when success, -1 when failure happens */ diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index c588565012a5..575f94b26f9d 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -35,6 +35,7 @@ #include "./ndarray.h" namespace mxnet { + enum NumpyShape{Off, On, Global}; /*! \brief runtime functions for NDArray */ class Imperative { public: @@ -97,14 +98,30 @@ class Imperative { is_recording_ = is_recording; return old; } - /*! brief whether numpy compatibility is on. */ + /*! \brief whether numpy compatibility is on. */ bool is_np_shape() const { + if (is_np_shape_global_) { + return true; + } return is_np_shape_; } - /*! brief turn on or turn off numpy compatibility switch. */ - bool set_is_np_shape(bool is_np_shape) { - bool old = is_np_shape_; - is_np_shape_ = is_np_shape; + /*! \brief specify numpy compatibility off, thread local on or global on. */ + bool set_is_np_shape(int is_np_shape) { + NumpyShape flag = static_cast(is_np_shape); + bool old = this->is_np_shape(); + switch (flag) { + case Global: + is_np_shape_global_ = true; + is_np_shape_ = true; + break; + case On: + is_np_shape_ = true; + break; + case Off: + is_np_shape_global_ = false; + is_np_shape_ = false; + break; + } return old; } /*! \brief to record operator, return corresponding node. */ @@ -185,6 +202,7 @@ class Imperative { // Delete it in the next major release. static MX_THREAD_LOCAL bool is_np_shape_; #endif + bool is_np_shape_global_{false}; /*! \brief node count used for naming */ std::atomic node_count_{0}; /*! \brief variable count used for naming */ diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 5863f08762e6..b80e17c18071 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) { int MXSetIsNumpyShape(int is_np_shape, int* prev) { API_BEGIN(); - *prev = Imperative::Get()->set_is_np_shape(static_cast(is_np_shape)); + *prev = Imperative::Get()->set_is_np_shape(is_np_shape); API_END(); }