Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Rename np_compat to np_shape #15063

Merged
merged 8 commits into from
May 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,14 +1067,14 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyCompatible(bool* curr);
MXNET_DLL int MXIsNumpyShape(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_comp 1 when numpy compatibility is on, 0 when off
* \param is_np_shape 1 when numpy shape semantics is on, 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes how the array format is expected when loading from disk. It would be good to document this on the function call. If NumpyShape is on, then loading an ndarray will fail — it just happened to me now. I think this is a side effect which should be clearly documented; alternatively, can we add additional arguments to Load to select between the two semantics?

/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
Expand Down
14 changes: 7 additions & 7 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ class Imperative {
return old;
}
/*! brief whether numpy compatibility is on. */
bool is_np_comp() const {
return is_np_comp_;
bool is_np_shape() const {
return is_np_shape_;
}
/*! brief turn on or turn off numpy compatibility switch. */
bool set_is_np_comp(bool is_np_comp) {
bool old = is_np_comp_;
is_np_comp_ = is_np_comp;
bool set_is_np_shape(bool is_np_shape) {
bool old = is_np_shape_;
is_np_shape_ = is_np_shape;
return old;
}
/*! \brief to record operator, return corresponding node. */
Expand Down Expand Up @@ -177,13 +177,13 @@ class Imperative {
static thread_local bool is_recording_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_comp_;
static thread_local bool is_np_shape_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_comp_;
static MX_THREAD_LOCAL bool is_np_shape_;
#endif
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
from .base import MXNetError, is_np_compat, set_np_compat, np_compat, use_np_compat
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from . import base
from . import contrib
from . import ndarray
Expand Down
140 changes: 1 addition & 139 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import

from functools import wraps
import atexit
import ctypes
import os
Expand All @@ -31,7 +30,7 @@

from . import libinfo

__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat', 'use_np_compat']
__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
Expand Down Expand Up @@ -735,140 +734,3 @@ def write_all_str(module_file, module_all_list):

ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p


def set_np_compat(active):
    """Turn NumPy compatibility on or off in the backend.

    NumPy compatibility is disabled by default.

    Parameters
    ----------
    active : bool
        True to enable NumPy compatibility, False to disable it.

    Returns
    -------
    bool
        The state of NumPy compatibility before this call.
    """
    # The C API reports the previous state through an output parameter.
    previous = ctypes.c_int()
    check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active), ctypes.byref(previous)))
    return bool(previous.value)


def is_np_compat():
    """Report whether NumPy compatibility is currently enabled.

    NumPy compatibility is disabled by default in the backend.

    Returns
    -------
    bool
        True when NumPy compatibility is currently on.
    """
    state = ctypes.c_bool()
    check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(state)))
    return state.value


class _NumpyCompatibilityStateScope(object):
"""Scope for managing numpy compatibility state.
Do not use this class directly. Use `np_compat(active)` instead.

Example::

with _NumpyCompatibilityStateScope(True):
y = model(x)
backward([y])

"""
def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name
self._enter_is_np_compat = is_np_compat
self._prev_is_np_compat = None

def __enter__(self):
if self._enter_is_np_compat is not None:
self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)

def __exit__(self, ptype, value, trace):
if self._enter_is_np_compat is not None and self._prev_is_np_compat != self._enter_is_np_compat:
set_np_compat(self._prev_is_np_compat)


def np_compat(active=True):
    """Return a NumPy compatibility state scope for use in a 'with' statement,
    activating or deactivating compatibility for the code it encloses.

    Parameters
    ----------
    active : bool, default True
        Whether NumPy compatibility is enabled inside the scope.

    Example::

        with mx.np_compat(active=True):
            # A scalar tensor's shape is `()`, whose `ndim` is `0`.
            scalar = mx.nd.ones(shape=())
            assert scalar.shape == ()

            # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
            data = mx.sym.var("data", shape=(0, 2, 3))
            ret = mx.sym.sin(data)
            arg_shapes, out_shapes, _ = ret.infer_shape()
            assert arg_shapes[0] == (0, 2, 3)
            assert out_shapes[0] == (0, 2, 3)

            # -1 means unknown shape dimension size in the new NumPy-compatible shape definition
            data = mx.sym.var("data", shape=(-1, 2, 3))
            ret = mx.sym.sin(data)
            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
            assert arg_shapes[0] == (-1, 2, 3)
            assert out_shapes[0] == (-1, 2, 3)

            # When a shape is completely unknown in NumPy-compatible mode, it is
            # represented as `None` in Python.
            data = mx.sym.var("data")
            ret = mx.sym.sin(data)
            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
            assert arg_shapes[0] is None
            assert out_shapes[0] is None

        with mx.np_compat(active=False):
            # 0 means unknown shape dimension size in the legacy shape definition.
            data = mx.sym.var("data", shape=(0, 2, 3))
            ret = mx.sym.sin(data)
            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
            assert arg_shapes[0] == (0, 2, 3)
            assert out_shapes[0] == (0, 2, 3)

            # When a shape is completely unknown in the legacy mode (default), its ndim is
            # equal to 0 and it is represented as `()` in Python.
            data = mx.sym.var("data")
            ret = mx.sym.sin(data)
            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
            assert arg_shapes[0] == ()
            assert out_shapes[0] == ()
    """
    return _NumpyCompatibilityStateScope(active)


def use_np_compat(func):
    """Decorator that executes *func* inside an active NumPy-compatibility
    scope, guaranteeing NumPy-compatible semantics (e.g. zero-dim and
    zero-size tensors) for the duration of the call.

    Example::
        import mxnet as mx
        @mx.use_np_compat
        def scalar_one():
            return mx.nd.ones(())
        print(scalar_one())

    Parameters
    ----------
    func : callable
        A user-provided function to be run under NumPy compatibility.

    Returns
    -------
    Function
        A wrapper that runs `func` within the NumPy compatibility scope.
    """
    @wraps(func)
    def _scoped(*args, **kwargs):
        with np_compat(active=True):
            return func(*args, **kwargs)

    return _scoped
5 changes: 3 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_compat
from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
Expand All @@ -45,6 +45,7 @@
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
from ..util import is_np_shape

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
Expand Down Expand Up @@ -1078,7 +1079,7 @@ def infer_shape(self, *args, **kwargs):
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
if is_np_compat():
if is_np_shape():
shape_is_none = not shape or -1 in shape
else:
shape_is_none = not shape or 0 in shape
Expand Down
Loading