Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
global numpy shape flag
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Oct 4, 2019
1 parent 916fbf2 commit e7e19cb
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param is_np_shape 1 when numpy shape semantics is thread local on,
* 2 when numpy shape semantics is global on and 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
Expand Down
44 changes: 36 additions & 8 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@
#include "./ndarray.h"

namespace mxnet {
/*! \brief there are three numpy shape flags based on priority.
* GlobalOn
* turn on numpy shape flag globally, it includes thread local.
* The flag can be seen in any thread.
* ThreadLocalOn
* only turn on thread local numpy shape flag, it cannot be seen
* in other threads.
* Off
* turn off numpy shape flag globally.
* */
enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -97,14 +108,30 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! brief whether numpy compatibility is on. */
/*! \brief whether numpy compatibility is on. */
bool is_np_shape() const {
return is_np_shape_;
if (is_np_shape_global_) {
return true;
}
return is_np_shape_thread_local_;
}
/*! brief turn on or turn off numpy compatibility switch. */
bool set_is_np_shape(bool is_np_shape) {
bool old = is_np_shape_;
is_np_shape_ = is_np_shape;
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
bool old = this->is_np_shape();
switch (flag) {
case GlobalOn:
is_np_shape_global_ = true;
is_np_shape_thread_local_ = true;
break;
case ThreadLocalOn:
is_np_shape_thread_local_ = true;
break;
case Off:
is_np_shape_global_ = false;
is_np_shape_thread_local_ = false;
break;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
Expand Down Expand Up @@ -177,14 +204,15 @@ class Imperative {
static thread_local bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_;
static thread_local bool is_np_shape_thread_local_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_;
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
bool is_np_shape_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {

int MXSetIsNumpyShape(int is_np_shape, int* prev) {
API_BEGIN();
*prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
*prev = Imperative::Get()->set_is_np_shape(is_np_shape);
API_END();
}

Expand Down
4 changes: 2 additions & 2 deletions src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
thread_local bool Imperative::is_np_shape_ = false;
thread_local bool Imperative::is_np_shape_thread_local_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
#endif

Imperative* Imperative::Get() {
Expand Down

0 comments on commit e7e19cb

Please sign in to comment.