Skip to content

Commit

Permalink
Rename np_compat to np_shape (apache#15063)
Browse files Browse the repository at this point in the history
* Change np_compat to np_shape

* Fix scala

* Fix pylint

* Add examples and fix documentation

* Fix doc

* More doc

* Rename np_compat to np_shape in test_operatory.py

* Rename in ndarray.cc
  • Loading branch information
reminisce authored and haohuw committed Jun 23, 2019
1 parent fe046bf commit 87c589c
Show file tree
Hide file tree
Showing 25 changed files with 285 additions and 220 deletions.
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,14 +1067,14 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyCompatible(bool* curr);
MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_comp 1 when numpy compatibility is on, 0 when off
* \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
Expand Down
14 changes: 7 additions & 7 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ class Imperative {
return old;
}
/*! brief whether numpy compatibility is on. */
bool is_np_comp() const {
return is_np_comp_;
bool is_np_shape() const {
return is_np_shape_;
}
/*! brief turn on or turn off numpy compatibility switch. */
bool set_is_np_comp(bool is_np_comp) {
bool old = is_np_comp_;
is_np_comp_ = is_np_comp;
bool set_is_np_shape(bool is_np_shape) {
bool old = is_np_shape_;
is_np_shape_ = is_np_shape;
return old;
}
/*! \brief to record operator, return corresponding node. */
Expand Down Expand Up @@ -177,13 +177,13 @@ class Imperative {
static thread_local bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_comp_;
static thread_local bool is_np_shape_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_comp_;
static MX_THREAD_LOCAL bool is_np_shape_;
#endif
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
from .base import MXNetError, is_np_compat, set_np_compat, np_compat, use_np_compat
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from . import base
from . import contrib
from . import ndarray
Expand Down
140 changes: 1 addition & 139 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import

from functools import wraps
import atexit
import ctypes
import os
Expand All @@ -31,7 +30,7 @@

from . import libinfo

__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat', 'use_np_compat']
__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
Expand Down Expand Up @@ -735,140 +734,3 @@ def write_all_str(module_file, module_all_list):

ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p


def set_np_compat(active):
"""
Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend.
Parameters
----------
active : bool
Indicates whether to turn on/off NumPy compatibility.
Returns
-------
A bool value indicating the previous state of NumPy compatibility.
"""
prev = ctypes.c_int()
check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active), ctypes.byref(prev)))
return bool(prev.value)


def is_np_compat():
"""
Checks whether the NumPy compatibility is currently turned on.
NumPy-compatibility is turned off by default in backend.
Returns
-------
A bool value indicating whether the NumPy compatibility is currently on.
"""
curr = ctypes.c_bool()
check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
return curr.value


class _NumpyCompatibilityStateScope(object):
"""Scope for managing numpy compatibility state.
Do not use this class directly. Use `np_compat(active)` instead.
Example::
with _NumpyCompatibilityStateScope(True):
y = model(x)
backward([y])
"""
def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name
self._enter_is_np_compat = is_np_compat
self._prev_is_np_compat = None

def __enter__(self):
if self._enter_is_np_compat is not None:
self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)

def __exit__(self, ptype, value, trace):
if self._enter_is_np_compat is not None and self._prev_is_np_compat != self._enter_is_np_compat:
set_np_compat(self._prev_is_np_compat)


def np_compat(active=True):
"""Returns an activated/deactivated NumPy compatibility state scope to be used in 'with' statement
and captures code that needs the compatibility.
Example::
with mx.np_compat(active=True):
# A scalar tensor's shape is `()`, whose `ndim` is `0`.
scalar = mx.nd.ones(shape=())
assert scalar.shape == ()
# In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
data = mx.sym.var("data", shape=(0, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape()
assert arg_shapes[0] == (0, 2, 3)
assert out_shapes[0] == (0, 2, 3)
# -1 means unknown shape dimension size in the new NumPy-compatible shape definition
data = mx.sym.var("data", shape=(-1, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == (-1, 2, 3)
assert out_shapes[0] == (-1, 2, 3)
# When a shape is completely unknown in NumPy-compatible mode, it is
# represented as `None` in Python.
data = mx.sym.var("data")
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] is None
assert out_shapes[0] is None
with mx.np_compat(active=False):
# 0 means unknown shape dimension size in the legacy shape definition.
data = mx.sym.var("data", shape=(0, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == (0, 2, 3)
assert out_shapes[0] == (0, 2, 3)
# When a shape is completely unknown in the legacy mode (default), its ndim is
# equal to 0 and it is represented as `()` in Python.
data = mx.sym.var("data")
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == ()
assert out_shapes[0] == ()
"""
return _NumpyCompatibilityStateScope(active)


def use_np_compat(func):
"""Wraps a function with an activated NumPy-compatibility scope. This ensures
that the execution of the function is guaranteed with NumPy compatible semantics,
such as zero-dim and zero size tensors.
Example::
import mxnet as mx
@mx.use_np_compat
def scalar_one():
return mx.nd.ones(())
print(scalar_one())
Parameters
----------
func : a user-provided callable function to be scoped by the NumPy compatibility state.
Returns
-------
Function
A function for wrapping the user functions in the NumPy compatibility scope.
"""
@wraps(func)
def _with_np_compat(*args, **kwargs):
with np_compat(active=True):
return func(*args, **kwargs)

return _with_np_compat
5 changes: 3 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_compat
from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
Expand All @@ -45,6 +45,7 @@
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
from ..util import is_np_shape

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
Expand Down Expand Up @@ -1078,7 +1079,7 @@ def infer_shape(self, *args, **kwargs):
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
if is_np_compat():
if is_np_shape():
shape_is_none = not shape or -1 in shape
else:
shape_is_none = not shape or 0 in shape
Expand Down
Loading

0 comments on commit 87c589c

Please sign in to comment.