global numpy shape flag (apache#16335)
stu1130 authored and aaronmarkham committed Oct 16, 2019
1 parent 4547cc4 commit d7870c1
Showing 5 changed files with 65 additions and 14 deletions.
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
@@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
 MXNET_DLL int MXIsNumpyShape(bool* curr);
 /*!
  * \brief set numpy compatibility switch
- * \param is_np_shape 1 when numpy shape semantics is on, 0 when off
+ * \param is_np_shape 1 when numpy shape semantics is thread-local on,
+ *        2 when it is global on, and 0 when off
  * \param prev returns the previous status before this set
  * \return 0 when success, -1 when failure happens
  */
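
A minimal sketch of exercising the new flag values from Python through this C API. The helper _set_np_shape_raw is illustrative only (not part of this commit) and assumes an MXNet build containing this change, with _LIB and check_call from mxnet.base:

import ctypes

from mxnet.base import _LIB, check_call

def _set_np_shape_raw(flag):
    """Set the numpy shape mode: 0 = Off, 1 = ThreadLocalOn, 2 = GlobalOn.

    Returns the previous status as a bool. The backend reports the old
    state via is_np_shape(), so a previous GlobalOn comes back simply as
    True and cannot be distinguished from ThreadLocalOn.
    """
    prev = ctypes.c_int()
    check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(flag), ctypes.byref(prev)))
    return bool(prev.value)

was_on = _set_np_shape_raw(2)           # GlobalOn: numpy shapes everywhere
_set_np_shape_raw(1 if was_on else 0)   # restore (modulo the caveat above)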
44 changes: 36 additions & 8 deletions include/mxnet/imperative.h
@@ -35,6 +35,17 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief There are three numpy shape flags, ordered by priority:
+ *  GlobalOn
+ *    turns on the numpy shape flag globally; this subsumes the
+ *    thread-local flag, so it is visible in every thread.
+ *  ThreadLocalOn
+ *    turns on the numpy shape flag only for the current thread; it
+ *    cannot be seen from other threads.
+ *  Off
+ *    turns the numpy shape flag off globally.
+ */
+enum NumpyShape { Off, ThreadLocalOn, GlobalOn };
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -97,14 +108,30 @@ class Imperative {
     is_recording_ = is_recording;
     return old;
   }
-  /*! brief whether numpy compatibility is on. */
+  /*! \brief whether numpy compatibility is on. */
   bool is_np_shape() const {
-    return is_np_shape_;
+    if (is_np_shape_global_) {
+      return true;
+    }
+    return is_np_shape_thread_local_;
   }
-  /*! brief turn on or turn off numpy compatibility switch. */
-  bool set_is_np_shape(bool is_np_shape) {
-    bool old = is_np_shape_;
-    is_np_shape_ = is_np_shape;
+  /*! \brief specify numpy compatibility: off, thread-local on, or global on. */
+  bool set_is_np_shape(int is_np_shape) {
+    NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
+    bool old = this->is_np_shape();
+    switch (flag) {
+      case GlobalOn:
+        is_np_shape_global_ = true;
+        is_np_shape_thread_local_ = true;
+        break;
+      case ThreadLocalOn:
+        is_np_shape_thread_local_ = true;
+        break;
+      case Off:
+        is_np_shape_global_ = false;
+        is_np_shape_thread_local_ = false;
+        break;
+    }
     return old;
   }
   /*! \brief to record operator, return corresponding node. */
@@ -177,14 +204,15 @@ class Imperative {
   static thread_local bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static thread_local bool is_np_shape_;
+  static thread_local bool is_np_shape_thread_local_;
 #else
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static MX_THREAD_LOCAL bool is_np_shape_;
+  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
 #endif
+  bool is_np_shape_global_{false};
   /*! \brief node count used for naming */
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
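
The priority scheme above is small enough to model outside MXNet. Below is a self-contained sketch in plain Python (illustrative names, no MXNet required) that mirrors the is_np_shape()/set_is_np_shape() behavior: a process-wide flag overrides a thread-local one.

import threading

_tls = threading.local()   # per-thread flag storage
_np_shape_global = False   # shared across all threads

def is_np_shape():
    # GlobalOn wins; otherwise fall back to this thread's local flag.
    return _np_shape_global or getattr(_tls, "on", False)

def set_is_np_shape(flag):
    # flag: 0 = Off, 1 = ThreadLocalOn, 2 = GlobalOn
    global _np_shape_global
    old = is_np_shape()
    if flag == 2:            # GlobalOn: every thread sees the flag
        _np_shape_global = True
        _tls.on = True
    elif flag == 1:          # ThreadLocalOn: this thread only
        _tls.on = True
    else:                    # Off: clear both levels
        _np_shape_global = False
        _tls.on = False
    return old

set_is_np_shape(1)           # thread-local only
t = threading.Thread(target=lambda: print(is_np_shape()))  # prints False
t.start(); t.join()

set_is_np_shape(2)           # global: visible everywhere
t = threading.Thread(target=lambda: print(is_np_shape()))  # prints True
t.start(); t.join()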
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
@@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {

 int MXSetIsNumpyShape(int is_np_shape, int* prev) {
   API_BEGIN();
-  *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
+  *prev = Imperative::Get()->set_is_np_shape(is_np_shape);
   API_END();
 }

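
Dropping the static_cast<bool> is the substance of this change: the cast collapsed every nonzero mode to true before it reached set_is_np_shape. A small Python illustration of the old re-encoding:

# Illustration only: how the removed cast mangled the mode value.
flag = 2              # caller requests GlobalOn
as_bool = bool(flag)  # what static_cast<bool>(is_np_shape) produced
print(int(as_bool))   # 1 -> the backend used to see ThreadLocalOn instead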
4 changes: 2 additions & 2 deletions src/imperative/imperative.cc
@@ -25,11 +25,11 @@ namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
 thread_local bool Imperative::is_train_ = false;
 thread_local bool Imperative::is_recording_ = false;
-thread_local bool Imperative::is_np_shape_ = false;
+thread_local bool Imperative::is_np_shape_thread_local_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_ = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
-MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
 #endif
 
 Imperative* Imperative::Get() {
26 changes: 24 additions & 2 deletions tests/python/unittest/test_thread_local.py
@@ -16,14 +16,16 @@
 # under the License.
 
 import threading
+import numpy as np
 import mxnet as mx
 from mxnet import context, attribute, name
 from mxnet.gluon import block
 from mxnet.context import Context
 from mxnet.attribute import AttrScope
 from mxnet.name import NameManager
-from mxnet.test_utils import set_default_context
-from mxnet.util import _NumpyArrayScope
+from mxnet.test_utils import assert_almost_equal, set_default_context
+from mxnet.util import _NumpyArrayScope, set_np_shape
 
 
 def test_context():
     ctx_list = []
@@ -199,6 +201,26 @@ def g():
     assert status[0], "Spawned thread didn't set status correctly"
 
 
+def test_np_global_shape():
+    set_np_shape(2)  # 2 == GlobalOn, so spawned threads see it too
+    data = []
+
+    def f():
+        # zero-dim (scalar) array, only valid under numpy shape semantics
+        data.append(mx.np.ones(shape=()))
+        # zero-size array
+        data.append(mx.np.ones(shape=(0, 1, 2)))
+    try:
+        thread = threading.Thread(target=f)
+        thread.start()
+        thread.join()
+
+        assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
+        assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))
+    finally:
+        set_np_shape(0)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
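
The try/finally pattern in this test generalizes. Here is a hypothetical convenience wrapper (not part of this commit), assuming set_np_shape returns the previous on/off status the way the C API's prev out-parameter suggests:

from contextlib import contextmanager

import mxnet as mx
from mxnet.util import set_np_shape

@contextmanager
def np_shape_mode(flag):
    # flag: 0 = Off, 1 = ThreadLocalOn, 2 = GlobalOn
    prev = set_np_shape(flag)  # assumed to report the previous status
    try:
        yield
    finally:
        # prev is collapsed to on/off, so a prior GlobalOn is restored
        # as ThreadLocalOn at best.
        set_np_shape(1 if prev else 0)

# Usage: zero-dim shapes are only representable while the mode is on.
with np_shape_mode(2):
    assert mx.np.ones(shape=()).shape == ()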
