Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
global numpy shape flag
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Sep 30, 2019
1 parent ea440c7 commit 652f650
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,8 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param is_np_shape 1 when numpy shape semantics is thread local on,
* 2 when numpy shape semantics is global on and 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
Expand Down
28 changes: 23 additions & 5 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "./ndarray.h"

namespace mxnet {
enum NumpyShape{Off, On, Global};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -97,14 +98,30 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! brief whether numpy compatibility is on. */
/*! \brief whether numpy compatibility is on. */
bool is_np_shape() const {
if (is_np_shape_global_) {
return true;
}
return is_np_shape_;
}
/*! brief turn on or turn off numpy compatibility switch. */
bool set_is_np_shape(bool is_np_shape) {
bool old = is_np_shape_;
is_np_shape_ = is_np_shape;
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
bool old = this->is_np_shape();
switch (flag) {
case Global:
is_np_shape_global_ = true;
is_np_shape_ = true;
break;
case On:
is_np_shape_ = true;
break;
case Off:
is_np_shape_global_ = false;
is_np_shape_ = false;
break;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
Expand Down Expand Up @@ -185,6 +202,7 @@ class Imperative {
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_;
#endif
bool is_np_shape_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ int MXIsNumpyShape(bool* curr) {

int MXSetIsNumpyShape(int is_np_shape, int* prev) {
API_BEGIN();
*prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
*prev = Imperative::Get()->set_is_np_shape(is_np_shape);
API_END();
}

Expand Down

0 comments on commit 652f650

Please sign in to comment.