This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

global numpy shape flag
stu1130 committed Oct 6, 2019
1 parent 916fbf2 commit 4bc025d
Showing 7 changed files with 96 additions and 17 deletions.
11 changes: 7 additions & 4 deletions ci/docker/runtime_functions.sh
@@ -1524,12 +1524,15 @@ nightly_scala_demo_test_cpu() {
     bash bin/run_im.sh
 }

-nightly_estimator() {
+nightly_estimator_numpy_global_flag() {
     set -ex
-    cd /work/mxnet/tests/nightly/estimator
+    cd /work/mxnet/tests/nightly
     export PYTHONPATH=/work/mxnet/python/
-    nosetests test_estimator_cnn.py
-    nosetests test_sentiment_rnn.py
+    nosetests estimator/test_estimator_cnn.py
+    nosetests estimator/test_sentiment_rnn.py
+    # test the global numpy shape flag separately: it must not
+    # run in parallel with other tests
+    nosetests test_global_numpy_shape.py
 }

 # For testing PRs
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
@@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
 MXNET_DLL int MXIsNumpyShape(bool* curr);
 /*!
  * \brief set numpy compatibility switch
- * \param is_np_shape 1 when numpy shape semantics is on, 0 when off
+ * \param is_np_shape 1 when numpy shape semantics is on for the current thread only,
+ *        2 when it is on globally, and 0 when it is off
  * \param prev returns the previous status before this set
  * \return 0 when success, -1 when failure happens
  */
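For reference, the new 0/1/2 values can be passed straight to this C entry point. Below is a minimal sketch of calling it from Python over ctypes; it assumes the internal mxnet.base._LIB handle and check_call helper (present in MXNet 1.x) and is only an illustration, not part of this diff. The prev out-parameter receives the previous status described above.

import ctypes
from mxnet.base import _LIB, check_call  # assumed internal helpers

prev = ctypes.c_int()
# 0 = off, 1 = thread-local on, 2 = global on
check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(2), ctypes.byref(prev)))
print('previous numpy shape state:', prev.value)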
44 changes: 36 additions & 8 deletions include/mxnet/imperative.h
@@ -35,6 +35,17 @@
 #include "./ndarray.h"

 namespace mxnet {
+/*! \brief There are three numpy shape flags, listed in order of priority.
+ *  GlobalOn
+ *    turn on the numpy shape flag globally; this also turns on the
+ *    thread-local flag, so the flag is visible in every thread.
+ *  ThreadLocalOn
+ *    turn on the numpy shape flag only for the current thread; it is
+ *    not visible in other threads.
+ *  Off
+ *    turn off the numpy shape flag globally.
+ */
+enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -97,14 +108,30 @@ class Imperative {
     is_recording_ = is_recording;
     return old;
   }
-  /*! brief whether numpy compatibility is on. */
+  /*! \brief whether numpy compatibility is on. */
   bool is_np_shape() const {
-    return is_np_shape_;
+    if (is_np_shape_global_) {
+      return true;
+    }
+    return is_np_shape_thread_local_;
   }
-  /*! brief turn on or turn off numpy compatibility switch. */
-  bool set_is_np_shape(bool is_np_shape) {
-    bool old = is_np_shape_;
-    is_np_shape_ = is_np_shape;
+  /*! \brief specify numpy compatibility: off, thread-local on, or global on. */
+  bool set_is_np_shape(int is_np_shape) {
+    NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
+    bool old = this->is_np_shape();
+    switch (flag) {
+      case GlobalOn:
+        is_np_shape_global_ = true;
+        is_np_shape_thread_local_ = true;
+        break;
+      case ThreadLocalOn:
+        is_np_shape_thread_local_ = true;
+        break;
+      case Off:
+        is_np_shape_global_ = false;
+        is_np_shape_thread_local_ = false;
+        break;
+    }
     return old;
   }
   /*! \brief to record operator, return corresponding node. */
@@ -177,14 +204,15 @@ class Imperative {
   static thread_local bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static thread_local bool is_np_shape_;
+  static thread_local bool is_np_shape_thread_local_;
 #else
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static MX_THREAD_LOCAL bool is_np_shape_;
+  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
 #endif
+  bool is_np_shape_global_{false};
   /*! \brief node count used for naming */
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
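The priority logic in set_is_np_shape() above can be exercised from the Python side. A minimal sketch follows; it assumes mxnet.util.is_np_shape() is available to read back the merged state and that set_np_shape() forwards its integer argument to this switch and returns the previous state, mirroring the prev out-parameter in the C API. Neither detail is spelled out in this diff.

from mxnet.util import set_np_shape, is_np_shape  # is_np_shape assumed available

set_np_shape(1)          # ThreadLocalOn: flag visible only in this thread
assert is_np_shape()

prev = set_np_shape(2)   # GlobalOn: also turns the thread-local flag on
# `prev` is assumed to hold the previous merged state
assert is_np_shape()

set_np_shape(0)          # Off: clears both the global and the thread-local flag
assert not is_np_shape()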
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
@@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {

 int MXSetIsNumpyShape(int is_np_shape, int* prev) {
   API_BEGIN();
-  *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
+  *prev = Imperative::Get()->set_is_np_shape(is_np_shape);
   API_END();
 }

4 changes: 2 additions & 2 deletions src/imperative/imperative.cc
@@ -25,11 +25,11 @@ namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
 thread_local bool Imperative::is_train_ = false;
 thread_local bool Imperative::is_recording_ = false;
-thread_local bool Imperative::is_np_shape_ = false;
+thread_local bool Imperative::is_np_shape_thread_local_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_ = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
-MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
 #endif

 Imperative* Imperative::Get() {
2 changes: 1 addition & 1 deletion tests/nightly/JenkinsfileForBinaries
@@ -131,7 +131,7 @@ core_logic: {
     node(NODE_LINUX_GPU) {
       ws('workspace/estimator-test-gpu') {
         utils.unpack_and_init('gpu', mx_lib)
-        utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+        utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_numpy_global_flag', true)
       }
     }
   }
47 changes: 47 additions & 0 deletions tests/nightly/test_global_numpy_shape.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import threading
+from mxnet.util import set_np_shape
+from mxnet.test_utils import assert_almost_equal
+
+# This test must run in an isolated environment, without other
+# tests running in parallel, because the global numpy shape flag
+# affects every thread.
+def test_np_global_shape():
+    set_np_shape(2)
+    data = []
+
+    def f():
+        # scalar (zero-dimensional) array
+        data.append(mx.np.ones(shape=()))
+        # zero-size array
+        data.append(mx.np.ones(shape=(0, 1, 2)))
+
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+
+    assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
+    assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
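For contrast with the test above, here is a hedged sketch of how the thread-local mode differs from the global mode across threads. It again assumes mxnet.util.is_np_shape() is available and is only an illustration of why the test needs mode 2, not part of this commit.

import threading
from mxnet.util import set_np_shape, is_np_shape  # is_np_shape assumed available

def check_flag(results):
    # record whether numpy shape semantics are active in this worker thread
    results.append(is_np_shape())

for mode in (1, 2):          # 1 = thread-local on, 2 = global on
    set_np_shape(mode)
    results = []
    worker = threading.Thread(target=check_flag, args=(results,))
    worker.start()
    worker.join()
    print('mode', mode, '-> visible in worker thread:', results[0])
    # expected: False for mode 1, True for mode 2

set_np_shape(0)              # reset so later code is unaffected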
