Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
draft 1
Browse files Browse the repository at this point in the history
Preliminary completion

fix rebase mistake
  • Loading branch information
JiangZhaoh committed Jan 13, 2020
1 parent 28e053e commit 8fb4f1f
Show file tree
Hide file tree
Showing 31 changed files with 807 additions and 222 deletions.
15 changes: 15 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,21 @@ MXNET_DLL int MXIsNumpyShape(int* curr);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
/*!
 * \brief get numpy default data type flag
 * \param curr returns the current status:
 *        0 when default dtype is float32,
 *        1 when default dtype is float64.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXIsNumpyDefaultDtype(int* curr);
/*!
 * \brief set numpy default data type
 * \param dtype_flag 0 when default dtype is float32,
 *        1 when default dtype is float64.
 * \param prev returns the previous status before this set
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSetIsNumpyDefaultDtype(int dtype_flag, int* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
Expand Down
42 changes: 39 additions & 3 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace mxnet {
* turn off numpy shape flag globally.
* */
// Tri-state flag, in increasing priority: Off(0), ThreadLocalOn(1), GlobalOn(2).
enum NumpyShape { Off, ThreadLocalOn, GlobalOn };
// The numpy default-dtype switch reuses the same tri-state semantics,
// so expose it as a type alias (modern `using` form rather than `typedef`).
using NumpyDefaultDtype = NumpyShape;
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -108,14 +109,16 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! \brief Report the current numpy shape compatibility mode:
 * GlobalOn(2), ThreadLocalOn(1), or Off(0).
 * The global flag takes precedence over the thread-local one.
 * */
int is_np_shape() const {
  if (is_np_shape_global_) {
    return NumpyShape::GlobalOn;
  }
  if (is_np_shape_thread_local_) {
    return NumpyShape::ThreadLocalOn;
  }
  return NumpyShape::Off;
}
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
Expand All @@ -136,6 +139,36 @@ class Imperative {
}
return old;
}
/*! \brief Report the current numpy default-dtype mode:
 * GlobalOn(2), ThreadLocalOn(1), or Off(0).
 * Mirrors is_np_shape(): the global flag wins over the thread-local one.
 * */
int is_np_default_dtype() const {
  if (is_np_default_dtype_global_) {
    return NumpyDefaultDtype::GlobalOn;
  } else if (is_np_default_dtype_thread_local_) {
    return NumpyDefaultDtype::ThreadLocalOn;
  } else {
    return NumpyDefaultDtype::Off;
  }
}
/*! \brief specify numpy default dtype off, thread local on or global on. */
bool set_is_np_default_dtype(int is_np_default_dtype) {
NumpyDefaultDtype flag = static_cast<NumpyDefaultDtype>(is_np_default_dtype);
bool old = this->is_np_default_dtype();
switch (flag) {
case GlobalOn:
is_np_default_dtype_global_ = true;
is_np_default_dtype_thread_local_ = true;
break;
case ThreadLocalOn:
is_np_default_dtype_thread_local_ = true;
break;
case Off:
is_np_default_dtype_global_ = false;
is_np_default_dtype_thread_local_ = false;
break;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
Expand Down Expand Up @@ -207,14 +240,17 @@ class Imperative {
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_thread_local_;
static thread_local bool is_np_default_dtype_thread_local_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
static MX_THREAD_LOCAL bool is_np_default_dtype_thread_local_;
#endif
bool is_np_shape_global_{false};
bool is_np_default_dtype_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from .util import is_np_array, np_array, use_np_array, use_np
from .util import is_np_default_dtype, np_default_dtype, use_np_default_dtype
from . import base

# version info
Expand Down
Loading

0 comments on commit 8fb4f1f

Please sign in to comment.