Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
global numpy shape flag
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Oct 5, 2019
1 parent 916fbf2 commit 40125aa
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 12 deletions.
7 changes: 7 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,13 @@ test_ubuntu_cpu_python3() {
popd
}

test_global() {
set -ex
export PYTHONPATH=./python/
cd /work/mxnet/tests/python/other
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_unittest.xml --verbose test_global_numpy_shape.py
}

# Functions that run the nightly Tests:

#Runs Apache RAT Check on MXNet Source for License Headers
Expand Down
14 changes: 14 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,20 @@ def test_unix_distributed_kvstore_gpu() {
}]
}

def test_unix_other_cpu() {
return ['unix-other tests CPU': {
node(NODE_LINUX_CPU) {
ws('workspace/unix-other-tests') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('cpu', mx_lib)
utils.docker_run('ubuntu_cpu', 'test_global', true)
utils.publish_test_coverage()
}
}
}
}]
}

def test_centos7_python3_cpu() {
return ['Python3: CentOS 7 CPU': {
node(NODE_LINUX_CPU) {
Expand Down
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param is_np_shape 1 when numpy shape semantics is thread local on,
* 2 when numpy shape semantics is global on and 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
Expand Down
44 changes: 36 additions & 8 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@
#include "./ndarray.h"

namespace mxnet {
/*! \brief there are three numpy shape flags based on priority.
* GlobalOn
* turn on numpy shape flag globally, it includes thread local.
* The flag can be seen in any thread.
* ThreadLocalOn
* only turn on thread local numpy shape flag, it cannot be seen
* in other threads.
* Off
* turn off numpy shape flag globally.
* */
enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -97,14 +108,30 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! brief whether numpy compatibility is on. */
/*! \brief whether numpy compatibility is on. */
bool is_np_shape() const {
return is_np_shape_;
if (is_np_shape_global_) {
return true;
}
return is_np_shape_thread_local_;
}
/*! brief turn on or turn off numpy compatibility switch. */
bool set_is_np_shape(bool is_np_shape) {
bool old = is_np_shape_;
is_np_shape_ = is_np_shape;
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
bool old = this->is_np_shape();
switch (flag) {
case GlobalOn:
is_np_shape_global_ = true;
is_np_shape_thread_local_ = true;
break;
case ThreadLocalOn:
is_np_shape_thread_local_ = true;
break;
case Off:
is_np_shape_global_ = false;
is_np_shape_thread_local_ = false;
break;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
Expand Down Expand Up @@ -177,14 +204,15 @@ class Imperative {
static thread_local bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_;
static thread_local bool is_np_shape_thread_local_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_;
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
bool is_np_shape_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {

int MXSetIsNumpyShape(int is_np_shape, int* prev) {
API_BEGIN();
*prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
*prev = Imperative::Get()->set_is_np_shape(is_np_shape);
API_END();
}

Expand Down
4 changes: 2 additions & 2 deletions src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
thread_local bool Imperative::is_np_shape_ = false;
thread_local bool Imperative::is_np_shape_thread_local_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
#endif

Imperative* Imperative::Get() {
Expand Down
45 changes: 45 additions & 0 deletions tests/python/other/test_global_numpy_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np
import threading
from mxnet.util import set_np_shape
from mxnet.test_utils import assert_almost_equal


def test_np_global_shape():
set_np_shape(2)
data = []

def f():
# scalar
data.append(mx.np.ones(shape=()))
# zero-dim
data.append(mx.np.ones(shape=(0, 1, 2)))

thread = threading.Thread(target=f)
thread.start()
thread.join()

assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 40125aa

Please sign in to comment.