diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 1c2300a6aaeb..1b1c10e79fea 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1067,14 +1067,14 @@ MXNET_DLL int MXAutogradIsTraining(bool* curr);
  * \param curr returns the current status
  * \return 0 when success, -1 when failure happens
  */
-MXNET_DLL int MXIsNumpyCompatible(bool* curr);
+MXNET_DLL int MXIsNumpyShape(bool* curr);
 /*!
  * \brief set numpy compatibility switch
- * \param is_np_comp 1 when numpy compatibility is on, 0 when off
+ * \param is_np_shape 1 when numpy shape semantics is on, 0 when off
  * \param prev returns the previous status before this set
  * \return 0 when success, -1 when failure happens
  */
-MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
+MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
 /*!
  * \brief mark NDArrays as variables to compute gradient for autograd
  * \param num_var number of variable NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index ad209913ac53..a86cc085a34b 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -98,13 +98,13 @@ class Imperative {
     return old;
   }
   /*! \brief whether numpy compatibility is on. */
-  bool is_np_comp() const {
-    return is_np_comp_;
+  bool is_np_shape() const {
+    return is_np_shape_;
   }
   /*! \brief turn on or turn off numpy compatibility switch. */
-  bool set_is_np_comp(bool is_np_comp) {
-    bool old = is_np_comp_;
-    is_np_comp_ = is_np_comp;
+  bool set_is_np_shape(bool is_np_shape) {
+    bool old = is_np_shape_;
+    is_np_shape_ = is_np_shape;
     return old;
   }
   /*! \brief to record operator, return corresponding node. */
@@ -177,13 +177,13 @@ class Imperative {
   static thread_local bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static thread_local bool is_np_comp_;
+  static thread_local bool is_np_shape_;
 #else
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
   // TODO(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
-  static MX_THREAD_LOCAL bool is_np_comp_;
+  static MX_THREAD_LOCAL bool is_np_shape_;
 #endif
   /*! \brief node count used for naming */
   std::atomic<uint64_t> node_count_{0};
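Because the backing store above is thread-local (in both the C++11 `thread_local` and the `MX_THREAD_LOCAL` branches), the switch is per-thread: a new thread never inherits the setting from the thread that spawned it. A minimal sketch of that consequence, using the renamed Python front end introduced later in this patch and assuming nothing else toggles the switch:

    import threading
    import mxnet as mx

    mx.set_np_shape(True)  # flips only the calling thread's switch

    def probe(out):
        # A fresh thread starts from the thread-local default, which is False.
        out.append(mx.is_np_shape())

    result = []
    t = threading.Thread(target=probe, args=(result,))
    t.start()
    t.join()
    print(mx.is_np_shape(), result[0])  # expected: True False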
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 79eb1f10f427..ab4bffde28a9 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -23,7 +23,8 @@
 from .context import Context, current_context, cpu, gpu, cpu_pinned
 from . import engine
-from .base import MXNetError, is_np_compat, set_np_compat, np_compat, use_np_compat
+from .base import MXNetError
+from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
 from . import base
 from . import contrib
 from . import ndarray
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 53414016e39e..73fae4876873 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,7 +20,6 @@
 """ctypes library of mxnet and helper functions."""
 from __future__ import absolute_import
-from functools import wraps
 import atexit
 import ctypes
 import os
@@ -31,7 +30,7 @@
 from . import libinfo
 
-__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat', 'use_np_compat']
+__all__ = ['MXNetError']
 #----------------------------
 # library loading
 #----------------------------
@@ -735,140 +734,3 @@ def write_all_str(module_file, module_all_list):
 ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
 ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
-
-
-def set_np_compat(active):
-    """
-    Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend.
-
-    Parameters
-    ----------
-    active : bool
-        Indicates whether to turn on/off NumPy compatibility.
-
-    Returns
-    -------
-    A bool value indicating the previous state of NumPy compatibility.
-    """
-    prev = ctypes.c_int()
-    check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active), ctypes.byref(prev)))
-    return bool(prev.value)
-
-
-def is_np_compat():
-    """
-    Checks whether the NumPy compatibility is currently turned on.
-    NumPy-compatibility is turned off by default in backend.
-
-    Returns
-    -------
-    A bool value indicating whether the NumPy compatibility is currently on.
-    """
-    curr = ctypes.c_bool()
-    check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
-    return curr.value
-
-
-class _NumpyCompatibilityStateScope(object):
-    """Scope for managing numpy compatibility state.
-    Do not use this class directly. Use `np_compat(active)` instead.
-
-    Example::
-
-        with _NumpyCompatibilityStateScope(True):
-            y = model(x)
-            backward([y])
-
-    """
-    def __init__(self, is_np_compat):  #pylint: disable=redefined-outer-name
-        self._enter_is_np_compat = is_np_compat
-        self._prev_is_np_compat = None
-
-    def __enter__(self):
-        if self._enter_is_np_compat is not None:
-            self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat)
-
-    def __exit__(self, ptype, value, trace):
-        if self._enter_is_np_compat is not None and self._prev_is_np_compat != self._enter_is_np_compat:
-            set_np_compat(self._prev_is_np_compat)
-
-
-def np_compat(active=True):
-    """Returns an activated/deactivated NumPy compatibility state scope to be used in 'with' statement
-    and captures code that needs the compatibility.
-
-    Example::
-
-        with mx.np_compat(active=True):
-            # A scalar tensor's shape is `()`, whose `ndim` is `0`.
-            scalar = mx.nd.ones(shape=())
-            assert scalar.shape == ()
-
-            # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
-            data = mx.sym.var("data", shape=(0, 2, 3))
-            ret = mx.sym.sin(data)
-            arg_shapes, out_shapes, _ = ret.infer_shape()
-            assert arg_shapes[0] == (0, 2, 3)
-            assert out_shapes[0] == (0, 2, 3)
-
-            # -1 means unknown shape dimension size in the new NumPy-compatible shape definition
-            data = mx.sym.var("data", shape=(-1, 2, 3))
-            ret = mx.sym.sin(data)
-            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
-            assert arg_shapes[0] == (-1, 2, 3)
-            assert out_shapes[0] == (-1, 2, 3)
-
-            # When a shape is completely unknown in NumPy-compatible mode, it is
-            # represented as `None` in Python.
-            data = mx.sym.var("data")
-            ret = mx.sym.sin(data)
-            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
-            assert arg_shapes[0] is None
-            assert out_shapes[0] is None
-
-        with mx.np_compat(active=False):
-            # 0 means unknown shape dimension size in the legacy shape definition.
-            data = mx.sym.var("data", shape=(0, 2, 3))
-            ret = mx.sym.sin(data)
-            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
-            assert arg_shapes[0] == (0, 2, 3)
-            assert out_shapes[0] == (0, 2, 3)
-
-            # When a shape is completely unknown in the legacy mode (default), its ndim is
-            # equal to 0 and it is represented as `()` in Python.
-            data = mx.sym.var("data")
-            ret = mx.sym.sin(data)
-            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
-            assert arg_shapes[0] == ()
-            assert out_shapes[0] == ()
-    """
-    return _NumpyCompatibilityStateScope(active)
-
-
-def use_np_compat(func):
-    """Wraps a function with an activated NumPy-compatibility scope. This ensures
-    that the execution of the function is guaranteed with NumPy compatible semantics,
-    such as zero-dim and zero size tensors.
-
-    Example::
-        import mxnet as mx
-        @mx.use_np_compat
-        def scalar_one():
-            return mx.nd.ones(())
-        print(scalar_one())
-
-    Parameters
-    ----------
-    func : a user-provided callable function to be scoped by the NumPy compatibility state.
-
-    Returns
-    -------
-    Function
-        A function for wrapping the user functions in the NumPy compatibility scope.
-    """
-    @wraps(func)
-    def _with_np_compat(*args, **kwargs):
-        with np_compat(active=True):
-            return func(*args, **kwargs)
-
-    return _with_np_compat
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 0ea7c9f01d7c..d3cd519b9a8c 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -34,7 +34,7 @@
 from ..attribute import AttrScope
 from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
-from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_compat
+from ..base import mx_uint, py_str, string_types, integer_types, mx_int
 from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
 from ..base import check_call, MXNetError, NotImplementedForSymbol
 from ..context import Context, current_context
@@ -45,6 +45,7 @@
 from . import _internal
 from . import op
 from ._internal import SymbolBase, _set_symbol_class
+from ..util import is_np_shape
 
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
            "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
@@ -1078,7 +1079,7 @@ def infer_shape(self, *args, **kwargs):
         arg_names = self.list_arguments()
         unknowns = []
         for name, shape in zip(arg_names, arg_shapes):
-            if is_np_compat():
+            if is_np_shape():
                 shape_is_none = not shape or -1 in shape
             else:
                 shape_is_none = not shape or 0 in shape
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index fc8d985b9566..29f5b78e454e 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -19,6 +19,7 @@
 import ctypes
 import os
 import sys
+import functools
 
 from .base import _LIB, check_call
 
@@ -44,3 +45,203 @@ def get_gpu_memory(gpu_dev_id):
     total_mem = ctypes.c_uint64(0)
     check_call(_LIB.MXGetGPUMemoryInformation64(gpu_dev_id, ctypes.byref(free_mem), ctypes.byref(total_mem)))
     return free_mem.value, total_mem.value
+
+
+def set_np_shape(active):
+    """
+    Turns on/off NumPy shape semantics, in which `()` represents the shape of scalar tensors,
+    and tuples containing `0`, for example, `(0,)` and `(1, 0, 2)`, represent the shapes
+    of zero-size tensors. This is turned off by default to keep backward compatibility.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy under this semantics.
+
+    Parameters
+    ----------
+    active : bool
+        Indicates whether to turn on/off NumPy shape semantics.
+
+    Returns
+    -------
+    A bool value indicating the previous state of NumPy shape semantics.
+
+    Example
+    -------
+    >>> import mxnet as mx
+    >>> prev_state = mx.set_np_shape(True)
+    >>> print(prev_state)
+    False
+    >>> print(mx.is_np_shape())
+    True
+    """
+    prev = ctypes.c_int()
+    check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(active), ctypes.byref(prev)))
+    return bool(prev.value)
+
+
+def is_np_shape():
+    """
+    Checks whether NumPy shape semantics is currently turned on.
+    In NumPy shape semantics, `()` represents the shape of scalar tensors,
+    and tuples containing `0`, for example, `(0,)` and `(1, 0, 2)`, represent
+    the shapes of zero-size tensors. This is turned off by default to keep
+    backward compatibility.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy under this semantics.
+
+    Returns
+    -------
+    A bool value indicating whether NumPy shape semantics is currently on.
+
+    Example
+    -------
+    >>> import mxnet as mx
+    >>> prev_state = mx.set_np_shape(True)
+    >>> print(prev_state)
+    False
+    >>> print(mx.is_np_shape())
+    True
+    """
+    curr = ctypes.c_bool()
+    check_call(_LIB.MXIsNumpyShape(ctypes.byref(curr)))
+    return curr.value
+
+
+class _NumpyShapeScope(object):
+    """Scope for managing NumPy shape semantics.
+    In NumPy shape semantics, `()` represents the shape of scalar tensors,
+    and tuples containing `0`, for example, `(0,)` and `(1, 0, 2)`, represent
+    the shapes of zero-size tensors.
+
+    Do not use this class directly. Use `np_shape(active)` instead.
+
+    Example::
+
+        with _NumpyShapeScope(True):
+            y = model(x)
+            backward([y])
+
+    """
+    def __init__(self, is_np_shape):  #pylint: disable=redefined-outer-name
+        self._enter_is_np_shape = is_np_shape
+        self._prev_is_np_shape = None
+
+    def __enter__(self):
+        if self._enter_is_np_shape is not None:
+            self._prev_is_np_shape = set_np_shape(self._enter_is_np_shape)
+
+    def __exit__(self, ptype, value, trace):
+        if self._enter_is_np_shape is not None and self._prev_is_np_shape != self._enter_is_np_shape:
+            set_np_shape(self._prev_is_np_shape)
+
+
+def np_shape(active=True):
+    """Returns an activated/deactivated NumPy shape scope that can be used in a 'with'
+    statement to capture code that needs NumPy shape semantics, i.e. support of
+    scalar and zero-size tensors.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy even within this scope.
+
+    Parameters
+    ----------
+    active : bool
+        Indicates whether to activate NumPy shape semantics.
+
+    Returns
+    -------
+    _NumpyShapeScope
+        A scope object for wrapping code with or without NumPy shape semantics.
+
+    Example::
+
+        with mx.np_shape(active=True):
+            # A scalar tensor's shape is `()`, whose `ndim` is `0`.
+            scalar = mx.nd.ones(shape=())
+            assert scalar.shape == ()
+
+            # If NumPy shape semantics is enabled, 0 in a shape means that
+            # dimension contains zero elements.
+            data = mx.sym.var("data", shape=(0, 2, 3))
+            ret = mx.sym.sin(data)
+            arg_shapes, out_shapes, _ = ret.infer_shape()
+            assert arg_shapes[0] == (0, 2, 3)
+            assert out_shapes[0] == (0, 2, 3)
+
+            # -1 means unknown shape dimension size in the new NumPy shape definition
+            data = mx.sym.var("data", shape=(-1, 2, 3))
+            ret = mx.sym.sin(data)
+            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+            assert arg_shapes[0] == (-1, 2, 3)
+            assert out_shapes[0] == (-1, 2, 3)
+
+            # When a shape is completely unknown while NumPy shape semantics is on,
+            # it is represented as `None` in Python.
+            data = mx.sym.var("data")
+            ret = mx.sym.sin(data)
+            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+            assert arg_shapes[0] is None
+            assert out_shapes[0] is None
+
+        with mx.np_shape(active=False):
+            # 0 means unknown shape dimension size in the legacy shape definition.
+            data = mx.sym.var("data", shape=(0, 2, 3))
+            ret = mx.sym.sin(data)
+            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+            assert arg_shapes[0] == (0, 2, 3)
+            assert out_shapes[0] == (0, 2, 3)
+
+            # When a shape is completely unknown in the legacy mode (default), its ndim is
+            # equal to 0 and it is represented as `()` in Python.
+            data = mx.sym.var("data")
+            ret = mx.sym.sin(data)
+            arg_shapes, out_shapes, _ = ret.infer_shape_partial()
+            assert arg_shapes[0] == ()
+            assert out_shapes[0] == ()
+    """
+    return _NumpyShapeScope(active)
+
+
+def use_np_shape(func):
+    """Wraps a function with an activated NumPy shape scope. This ensures
+    that the wrapped function executes with support for scalar and
+    zero-size tensors, as in NumPy.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy even within this scope.
+
+    Parameters
+    ----------
+    func : a user-provided callable function to be scoped by the NumPy shape semantics.
+
+    Returns
+    -------
+    Function
+        The wrapped function, which runs with NumPy shape semantics enabled.
+
+    Examples
+    --------
+    >>> import mxnet as mx
+    >>> @mx.use_np_shape
+    ... def scalar_one():
+    ...     return mx.nd.ones(())
+    ...
+    >>> print(scalar_one())
+    """
+    @functools.wraps(func)
+    def _with_np_shape(*args, **kwargs):
+        with np_shape(active=True):
+            return func(*args, **kwargs)
+
+    return _with_np_shape
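Taken together, the entry points added above cover the scope, decorator, and imperative styles. A minimal usage sketch mirroring the docstring examples (scalar and zero-size shapes are only accepted while the semantics is active, as the unit tests further down verify):

    import mxnet as mx

    # Scope form: scalar and zero-size shapes are legal only inside.
    with mx.np_shape(active=True):
        assert mx.nd.ones(shape=()).shape == ()
        assert mx.nd.zeros(shape=(0, 4)).shape == (0, 4)

    # Decorator form: the wrapped function always runs with the semantics on.
    @mx.use_np_shape
    def scalar_one():
        return mx.nd.ones(shape=())

    assert scalar_one().shape == ()

    # Imperative form: returns the previous state so callers can restore it.
    prev = mx.set_np_shape(True)
    assert mx.is_np_shape()
    mx.set_np_shape(prev)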
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index aba618540141..640ecf5d5978 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -350,6 +350,6 @@ private[mxnet] class LibInfo {
   @native def mxDumpProfile(finished: Int): Int
 
   // Numpy
-  @native def mxIsNumpyCompatible(compatible: RefInt): Int
-  @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int
+  @native def mxIsNumpyShape(compatible: RefInt): Int
+  @native def mxSetIsNumpyShape(isNpComp: Int, prev: RefInt): Int
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
index d3e76f1044a7..b63095a10cc1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
@@ -25,24 +25,24 @@ import org.apache.mxnet.Base._
  * is introduced first to support zero-dim and zero-size tensors as in NumPy.
  */
 object NumpyScope {
-  def setNumpyCompatible(isNpComp: Boolean): Boolean = {
+  def setNumpyShape(isNpComp: Boolean): Boolean = {
     val prev = new RefInt()
-    checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev))
+    checkCall(_LIB.mxSetIsNumpyShape(if (isNpComp) 1 else 0, prev))
     if (prev.value != 0) true else false
   }
 
-  def isNumpyCompatible: Boolean = {
+  def isNumpyShape: Boolean = {
     val curr = new RefInt
-    checkCall(_LIB.mxIsNumpyCompatible(curr))
+    checkCall(_LIB.mxIsNumpyShape(curr))
     if (curr.value != 0) true else false
   }
 
-  def enableNumpyCompatible: NumpyScope = {
+  def enableNumpyShape: NumpyScope = {
     new NumpyScope(true)
   }
 
-  def disableNumpyCompatible: NumpyScope = {
+  def disableNumpyShape: NumpyScope = {
     new NumpyScope(false)
   }
 }
@@ -51,12 +51,12 @@ class NumpyScope(var isCompatible: Boolean) {
   private var prev: Boolean = false
 
   def withScope[T](body: => T): T = {
-    prev = NumpyScope.setNumpyCompatible(isCompatible)
+    prev = NumpyScope.setNumpyShape(isCompatible)
     try {
       body
     } finally {
       if (prev != isCompatible) {
-        NumpyScope.setNumpyCompatible(prev)
+        NumpyScope.setNumpyShape(prev)
       }
     }
   }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 68db2b1d9144..80f4dc935282 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -293,7 +293,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource {
     val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, values)
     val argNames = listArguments()
     val unknown = (argNames zip argShapes).map { case (name, shape) =>
-      val shapeIsNone = if (NumpyScope.isNumpyCompatible) {
+      val shapeIsNone = if (NumpyScope.isNumpyShape) {
         shape == null || shape.toVector.contains(-1)
       } else {
         shape == null || shape.toVector.contains(0)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
index bf6627ac7e91..0581a9890d84 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
@@ -21,14 +21,14 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll {
   test("compatible") {
-    NumpyScope.enableNumpyCompatible.withScope {
-      assert(NumpyScope.isNumpyCompatible === true)
+    NumpyScope.enableNumpyShape.withScope {
+      assert(NumpyScope.isNumpyShape === true)
     }
   }
 
   test("incompatible") {
-    NumpyScope.disableNumpyCompatible.withScope {
-      assert(NumpyScope.isNumpyCompatible === false)
+    NumpyScope.disableNumpyShape.withScope {
+      assert(NumpyScope.isNumpyShape === false)
     }
   }
 }
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 7323d23ac556..9b19fd360fc4 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -2707,18 +2707,18 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
 }
 
 // Numpy
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
   (JNIEnv *env, jobject obj, jobject compatibleRef) {
-  bool isCompatible;
-  int ret = MXIsNumpyCompatible(&isCompatible);
-  SetIntField(env, compatibleRef, static_cast<int>(isCompatible));
+  bool isNumpyShape;
+  int ret = MXIsNumpyShape(&isNumpyShape);
+  SetIntField(env, compatibleRef, static_cast<int>(isNumpyShape));
   return ret;
 }
 
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
   (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
   int prev;
-  int ret = MXSetIsNumpyCompatible(isNpComp, &prev);
+  int ret = MXSetIsNumpyShape(isNpComp, &prev);
   SetIntField(env, prevRef, prev);
   return ret;
 }
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 467272cea9cf..fac32bb0a410 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -873,18 +873,18 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
 
 /*
  * Class:     org_apache_mxnet_LibInfo
- * Method:    mxIsNumpyCompatible
+ * Method:    mxIsNumpyShape
  * Signature: (Lorg/apache/mxnet/Base/RefInt;)I
 */
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
   (JNIEnv *, jobject, jobject);
 
 /*
  * Class:     org_apache_mxnet_LibInfo
- * Method:    mxSetIsNumpyCompatible
+ * Method:    mxSetIsNumpyShape
  * Signature: (ILorg/apache/mxnet/Base/RefInt;)I
 */
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
   (JNIEnv *, jobject, jint, jobject);
 
 #ifdef __cplusplus
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 7f8d5f590a4c..f5d72d53d2b7 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -521,7 +521,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
   NDArray *arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
     mxnet::TShape s = arr->shape();
-    if (!Imperative::Get()->is_np_comp()) {
+    if (!Imperative::Get()->is_np_shape()) {
       common::ConvertToLegacyShape(&s);
     }
     *out_dim = s.ndim();
@@ -532,7 +532,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
     *out_pdata = buffer.data();
     }
   } else {
-    if (Imperative::Get()->is_np_comp()) {
+    if (Imperative::Get()->is_np_shape()) {
       *out_dim = -1;
     } else {
       *out_dim = 0;
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 8fade7df223e..ebe3f17d7d90 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -415,7 +415,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
     CHECK(p.second) << "Duplicate shapes are provided for argument "
                     << provided_arg_shape_names[i] << " in simple_bind";
   }
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     for (auto &kv : arg_shape_map) {
       common::ConvertToNumpyShape(&kv.second);
     }
@@ -749,7 +749,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
     CHECK(p.second) << "Duplicate shapes are provided for argument "
                     << provided_arg_shape_names[i] << " in simple_bind";
   }
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     for (auto &kv : arg_shape_map) {
       common::ConvertToNumpyShape(&kv.second);
     }
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 0e136b03ecd7..c9c6000e2f6f 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -276,15 +276,15 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
   API_END();
 }
 
-int MXIsNumpyCompatible(bool* curr) {
+int MXIsNumpyShape(bool* curr) {
   API_BEGIN();
-  *curr = Imperative::Get()->is_np_comp();
+  *curr = Imperative::Get()->is_np_shape();
   API_END();
 }
 
-int MXSetIsNumpyCompatible(int is_np_comp, int* prev) {
+int MXSetIsNumpyShape(int is_np_shape, int* prev) {
   API_BEGIN();
-  *prev = Imperative::Get()->set_is_np_comp(static_cast<bool>(is_np_comp));
+  *prev = Imperative::Get()->set_is_np_shape(static_cast<bool>(is_np_shape));
   API_END();
 }
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index a3b9fce6057a..4c6229ee29b0 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -556,7 +556,7 @@ int MXSymbolInferShape(SymbolHandle sym,
 
   // if use legacy shape definition, need to convert numpy shape to legacy shape
   mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToLegacyShape(&shapes);
   }
 
@@ -629,7 +629,7 @@ int MXSymbolInferShapeEx(SymbolHandle sym,
 
   // if use legacy shape definition, need to convert numpy shape to legacy shape
   mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToLegacyShape(&shapes);
   }
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index a71e5ecbdd6f..d72325392604 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -470,7 +470,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
   std::vector<bool> is_dynamic(rshape.size(), 0);
 
   // convert to numpy compatible shape to use operator's infer shape function
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToNumpyShape(&rshape);
   }
@@ -490,7 +490,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
       if (it != inode.source->attrs.dict.end()) {
         std::istringstream is(it->second);
         CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
-        if (!Imperative::Get()->is_np_comp()) {
+        if (!Imperative::Get()->is_np_shape()) {
           common::ConvertToNumpyShape(&rshape[out_ent_id]);
         }
       }
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index f014ab9dcf3e..d8fba1c169ec 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -25,11 +25,11 @@ namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
 thread_local bool Imperative::is_train_ = false;
 thread_local bool Imperative::is_recording_ = false;
-thread_local bool Imperative::is_np_comp_ = false;
+thread_local bool Imperative::is_np_shape_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_ = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
-MX_THREAD_LOCAL bool Imperative::is_np_comp_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_shape_ = false;
 #endif
 
 Imperative* Imperative::Get() {
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 5c9706834b2d..5cb805c5abcb 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -122,7 +122,7 @@ inline void SetShapeType(const Context& ctx,
   if (!infershape.count(attrs.op)) {
     is_dynamic_shape_existing = true;
   } else {
-    if (!Imperative::Get()->is_np_comp()) {
+    if (!Imperative::Get()->is_np_shape()) {
      common::ConvertToNumpyShape(&in_shapes);
      common::ConvertToNumpyShape(&out_shapes);
    }
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 9474d0ce40c2..16c579fefa32 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1582,7 +1582,7 @@ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
 
 void NDArray::Save(dmlc::Stream *strm) const {
   // TODO(junwu): Support this after NumPy operators are merged
-  CHECK(!Imperative::Get()->is_np_comp())
+  CHECK(!Imperative::Get()->is_np_shape())
       << "Saving ndarray within the scope of np_shape is not supported.";
   // write magic number to mark this version
   // for storage type
@@ -1702,7 +1702,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
 
 bool NDArray::Load(dmlc::Stream *strm) {
   // TODO(junwu): Support this after NumPy operators are merged
-  CHECK(!Imperative::Get()->is_np_comp())
+  CHECK(!Imperative::Get()->is_np_shape())
      << "Loading ndarray within the scope of np_shape is not supported.";
   uint32_t magic;
   if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
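The two CHECKs above mean that NDArray serialization is rejected while the semantics is active; from Python, the failed CHECK surfaces as an MXNetError at the C API boundary. A minimal sketch of the behavior these checks enforce (the file name is illustrative only):

    import mxnet as mx
    from mxnet.base import MXNetError

    with mx.np_shape():
        arr = mx.nd.ones((2, 2))
        try:
            mx.nd.save('/tmp/arr.nd', [arr])  # reaches NDArray::Save
        except MXNetError as err:
            print('save rejected under np_shape:', err)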
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index e4b090db933e..fd491534f83a 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -242,7 +242,7 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 0U);
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape param_shape = param.shape;
-  if (!Imperative::Get()->is_np_comp()) {
+  if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToNumpyShape(&param_shape);
   }
   if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9c004cdfdab1..064f783ec6c8 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2012,14 +2012,14 @@ def check_proposal_consistency(op, batch_size, with_nms=False):
 # The following 2 functions launch 0-thread kernels, an error that should be caught and signaled.
 def kernel_error_check_imperative():
     os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
-    with mx.np_compat(active=True):
+    with mx.np_shape(active=True):
         a = mx.nd.array([1,2,3],ctx=mx.gpu(0))
         b = mx.nd.array([],ctx=mx.gpu(0))
         c = (a / b).asnumpy()
 
 def kernel_error_check_symbolic():
     os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
-    with mx.np_compat(active=True):
+    with mx.np_shape(active=True):
         a = mx.sym.Variable('a')
         b = mx.sym.Variable('b')
         c = a / b
diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py
index 2bf7e8bf9d71..1312be0c0081 100644
--- a/tests/python/unittest/test_infer_shape.py
+++ b/tests/python/unittest/test_infer_shape.py
@@ -154,7 +154,7 @@ def test_shape_completely_unknown():
     assert arg_shapes[0] == ()
     assert out_shapes[0] == ()
 
-    with mx.np_compat():
+    with mx.np_shape():
         data = mx.sym.var("data")
         ret = mx.sym.sin(data)
         arg_shapes, out_shapes, _ = ret.infer_shape_partial()
@@ -169,7 +169,7 @@ def test_dot_partial_shape():
     # batch size(first dim) of lhs unknown
     _, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
     assert result_shape == [(0, 3, 5)]
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(4, 5))
         assert result_shape == [(-1, 3, 5)]
@@ -184,7 +184,7 @@ def test_batch_dot_partial_shape():
     # rhs second dim unknown
     _, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
     assert result_shape == [()]
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
         assert result_shape == [(-1, 3, 5)]
         _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
@@ -198,7 +198,7 @@ def test_embedding_partial_shape():
     y = mx.sym.Embedding(data=x, weight=w, input_dim=100, output_dim=10)
     _, result_shape, _ = y.infer_shape_partial(x=(0, 5), w=(100, 10))
     assert result_shape == [(0, 5, 10)]
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result_shape, _ = y.infer_shape_partial(x=(-1, 5), w=(100, 10))
         assert result_shape == [(-1, 5, 10)]
@@ -213,7 +213,7 @@ def test_transpose_partial_shape():
     _, result, _ = y.infer_shape_partial(x=(0, 3, 224, 224))
     assert result == [(0, 224, 224, 3)]
 
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result, _ = y.infer_shape_partial(x=(-1, 3, 224, 224))
         assert result == [(-1, 224, 224, 3)]
@@ -225,7 +225,7 @@ def test_pick_partial_shape():
     # batch size unknown
     _, result, _ = y.infer_shape_partial(x=(0, 3, 3), index=(0, 3,))
     assert result == [(0, 3)]
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result, _ = y.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3,))
         assert result == [(-1, 3)]
@@ -240,7 +240,7 @@ def test_where_partial_shape():
     assert result == [()]
     _, result, _ = where_op.infer_shape_partial(cond=(0,), x=(2, 2), y =(2, 2))
     assert result == [()]
-    with mx.np_compat(True):
+    with mx.np_shape(True):
         _, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2))
         assert result == [None]
         _, result, _ = where_op.infer_shape_partial(cond=(-1,), x=(2, 2), y=(2, 2))
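The test changes above all follow one pattern: under the legacy semantics `0` marks an unknown dimension, while under NumPy shape semantics `-1` takes over that role and `0` becomes a real (zero) size. A condensed sketch of the flip, reusing the Embedding case from the hunk above (the variable definitions are assumptions, since the hunk does not show them):

    import mxnet as mx

    x = mx.sym.var('x')
    w = mx.sym.var('w')
    y = mx.sym.Embedding(data=x, weight=w, input_dim=100, output_dim=10)

    # Legacy mode: 0 marks the unknown batch dimension.
    _, out, _ = y.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert out == [(0, 5, 10)]

    # NumPy shape semantics: -1 marks the unknown dimension instead.
    with mx.np_shape(True):
        _, out, _ = y.infer_shape_partial(x=(-1, 5), w=(100, 10))
        assert out == [(-1, 5, 10)]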
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 8998b215d704..8b2a270a34a2 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -123,7 +123,7 @@ def test_ndarray_setitem():
     # numpy assignment for empty axis
     for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
         if trivial_shape == tuple():
-            with mx.np_compat():
+            with mx.np_shape():
                 x = mx.nd.zeros(trivial_shape)
         else:
             x = mx.nd.zeros(trivial_shape)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7767863668a2..52fe69bbd434 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4610,7 +4610,7 @@ def test_invalid_reps():
         assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3))
 
     test_normal_case()
-    with mx.np_compat():
+    with mx.np_shape():
         test_empty_tensor()
     test_empty_reps()
     test_tile_backward()
@@ -4671,7 +4671,7 @@ def test_zero_depth():
     test_normal_case(index_type=np.float64)
     test_normal_case(index_type=np.float32)
     test_normal_case(index_type=np.float16)
-    with mx.np_compat():
+    with mx.np_shape():
         test_empty_indices()
         test_zero_depth()
@@ -7222,7 +7222,7 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape):
     check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0))
     check_slice_axis_partial_infer(var1, 1, 0, 5, (10, 0))
 
-    with mx.np_compat():
+    with mx.np_shape():
         var1 = mx.sym.var(name="data", shape=(-1, 20))
         check_slice_partial_infer(var1, (None, None), (None, 10), [], (-1, 10))
         check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (-1, 5))
@@ -7247,7 +7247,7 @@ def test_float16_min_max():
 
 @with_seed()
-@mx.use_np_compat
+@mx.use_np_shape
 def test_zero_size_min_max():
     def min():
         a = mx.nd.zeros(shape=(5, 0))
@@ -8457,7 +8457,7 @@ def test_index_array_default():
         check_symbolic_forward(index_array, [input_array], [expected])
         check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
 
-    @mx.use_np_compat
+    @mx.use_np_shape
     def test_index_array_default_zero_dim():
         data = mx.symbol.Variable("data")
        index_array = mx.sym.contrib.index_array(data)
@@ -8468,7 +8468,7 @@ def test_index_array_default_zero_dim():
         check_symbolic_forward(index_array, [input_array], [expected])
         check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
 
-    @mx.use_np_compat
+    @mx.use_np_shape
     def test_index_array_default_zero_size():
         data = mx.symbol.Variable("data")
         index_array = mx.sym.contrib.index_array(data)
@@ -8492,7 +8492,7 @@ def test_index_array_select_axes():
         check_symbolic_forward(index_array, [input_array], [expected])
         check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
 
-    @mx.use_np_compat
+    @mx.use_np_shape
     def test_index_array_select_axes_zero_size():
         data = mx.symbol.Variable("data")
         index_array = mx.sym.contrib.index_array(data, axes=(2, 1))
@@ -8502,7 +8502,7 @@ def test_index_array_select_axes_zero_size():
         check_symbolic_forward(index_array, [input_array], [expected])
         check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
 
-    
+
     test_index_array_default()
     test_index_array_default_zero_dim()
     test_index_array_default_zero_size()
@@ -8514,7 +8514,7 @@ def test_scalar_tensor_creation():
     assertRaises(MXNetError, mx.nd.zeros, shape=())
     assertRaises(MXNetError, mx.nd.ones, shape=())
-    with mx.np_compat():
+    with mx.np_shape():
         data_mx = mx.nd.ones(shape=())
         data_np = np.ones((), dtype=data_mx.dtype)
         assert same(data_mx.asnumpy(), data_np)
@@ -8524,7 +8524,7 @@ def test_zero_size_tensor_creation():
     assertRaises(MXNetError, mx.nd.zeros, shape=(0, 1, 3, 0))
     assertRaises(MXNetError, mx.nd.ones, shape=(0, 1, 3, 0))
-    with mx.np_compat():
+    with mx.np_shape():
         data_mx = mx.nd.ones(shape=(0, 1, 0, 4))
         data_np = np.ones(shape=data_mx.shape, dtype=data_mx.dtype)
         assert same(data_mx.asnumpy(), data_np)
@@ -8532,7 +8532,7 @@ def test_zero_size_tensor_creation():
 
 @with_seed()
 def test_concat_with_zero_size_tensor():
-    with mx.np_compat():
+    with mx.np_shape():
         data1 = mx.nd.ones((0, 8, 12))
         data2 = mx.nd.ones((3, 8, 12))
         data3 = mx.nd.ones((0, 8, 12))
@@ -8547,8 +8547,8 @@ def test_concat_with_zero_size_tensor():
 
 @with_seed()
-def test_np_compat_decorator():
-    @mx.use_np_compat
+def test_np_shape_decorator():
+    @mx.use_np_shape
     def check_scalar_one():
         """Generate scalar one tensor"""
         return mx.nd.ones(shape=())
@@ -8556,12 +8556,12 @@ def check_scalar_one():
     assert check_scalar_one.__doc__ == "Generate scalar one tensor"
     assert check_scalar_one().shape == ()
 
     for active in [True, False]:
-        with mx.np_compat(active=active):
+        with mx.np_shape(active=active):
             assert check_scalar_one.__name__ == "check_scalar_one"
             assert check_scalar_one.__doc__ == "Generate scalar one tensor"
             assert check_scalar_one().shape == ()
 
-    @mx.use_np_compat
+    @mx.use_np_shape
     def check_concat(shape1, shape2, axis):
         data1 = mx.nd.ones(shape1)
         data2 = mx.nd.ones(shape2)