Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

New set default dtype #18251

Merged
merged 9 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/python/einsum/benchmark_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ def test_np_einsum():


if __name__ == "__main__":
npx.set_np()
npx.set_np(dtype=False)
test_np_einsum()
5 changes: 4 additions & 1 deletion benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def generate_workloads():
def prepare_workloads():
pool = generate_workloads()
OpArgMngr.add_workload("zeros", (2, 2))
OpArgMngr.add_workload("full", (2, 2), 10)
OpArgMngr.add_workload("identity", 3)
OpArgMngr.add_workload("ones", (2, 2))
OpArgMngr.add_workload("einsum", "ii", pool['2x2'], optimize=False)
OpArgMngr.add_workload("unique", pool['1'], return_index=True, return_inverse=True, return_counts=True, axis=-1)
OpArgMngr.add_workload("dstack", (pool['2x1'], pool['2x1'], pool['2x1'], pool['2x1']))
Expand Down Expand Up @@ -252,7 +255,7 @@ def show_results(results):
import numpy as onp
from mxnet import np as dnp

mx.npx.set_np()
mx.npx.set_np(dtype=False)
packages = {
"onp": {
"module": onp,
Expand Down
14 changes: 14 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,20 @@ MXNET_DLL int MXIsNumpyShape(int* curr);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
/*!
* \brief get numpy default data type
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyDefaultDtype(bool* curr);
/*!
* \brief set numpy default data type
* \param dtype_flag false when default dtype is flaot32,
* true when default dtype is flaot64.
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyDefaultDtype(bool dtype_flag, bool* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
Expand Down
26 changes: 24 additions & 2 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace mxnet {
* turn off numpy shape flag globally.
* */
enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
typedef NumpyShape NumpyDefaultDtype;
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -189,9 +190,11 @@ class Imperative {
* */
int is_np_shape() const {
if (is_np_shape_global_) {
return 2;
return NumpyShape::GlobalOn;
}
return is_np_shape_thread_local_ ? 1 : 0;
return is_np_shape_thread_local_ ?
NumpyShape::ThreadLocalOn :
NumpyShape::Off;
}
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
Expand All @@ -212,6 +215,24 @@ class Imperative {
}
return old;
}
/*! \brief return current numpy default dtype compatibility status.
* */
bool is_np_default_dtype() const {
if (is_np_default_dtype_global_) {
return true;
}
return false;
}
/*! \brief specify numpy default dtype off or global on. */
bool set_is_np_default_dtype(bool is_np_default_dtype) {
bool old = this->is_np_default_dtype();
if (is_np_default_dtype) {
is_np_default_dtype_global_ = true;
} else {
is_np_default_dtype_global_ = false;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
Expand Down Expand Up @@ -301,6 +322,7 @@ class Imperative {
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
bool is_np_shape_global_{false};
bool is_np_default_dtype_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from .util import is_np_array, np_array, use_np_array, use_np
from .util import is_np_default_dtype, np_default_dtype, use_np_default_dtype
from . import base

# version info
Expand Down
Loading