
Commit

[numpy] Fix unit tests after introducing numpy compatible shapes (#14487)

* Fix infer shape rnn

* Fix boolean mask and custom op unit tests

* Fix multi proposal

* Fix diag

* Add global switch for backward compatibility and fix infer shape bugs

* Fix slice op infer shape

* Fix rnn infer shape

* Add util funcs for ndim_is_known and dim_size_is_known

* Revert rnn_cell.py
reminisce committed Apr 6, 2019
1 parent f659034 commit 19434dd
Showing 44 changed files with 405 additions and 157 deletions.
15 changes: 14 additions & 1 deletion include/mxnet/c_api.h
@@ -170,7 +170,7 @@ typedef int (*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/,
typedef int (*CustomOpDelFunc)(void* /*state*/);
typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
unsigned** /*shapes*/, void* /*state*/);
int** /*shapes*/, void* /*state*/);
typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
int * /*stypes*/,
@@ -1036,6 +1036,19 @@ MXNET_DLL int MXAutogradIsRecording(bool* curr);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXAutogradIsTraining(bool* curr);
/*!
* \brief get whether numpy compatibility is on
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyCompatible(bool* curr);
/*!
* \brief set numpy compatibility switch
* \param is_np_comp 1 when numpy compatibility is on, 0 when off
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
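
These two entry points are the C-ABI surface that the Python frontend (see python/mxnet/numpy/__init__.py below) drives through ctypes. A minimal sketch of the raw calls, assuming the usual _LIB/check_call plumbing from mxnet.base:

import ctypes
from mxnet.base import _LIB, check_call

# Turn the switch on; the previous state is written into `prev`.
prev = ctypes.c_int()
check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(1), ctypes.byref(prev)))

# Query the current state.
curr = ctypes.c_bool()
check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
print(curr.value)  # True
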
16 changes: 16 additions & 0 deletions include/mxnet/imperative.h
@@ -97,6 +97,16 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
/*! \brief whether numpy compatibility is on. */
bool is_np_comp() const {
return is_np_comp_;
}
/*! \brief turn on or turn off numpy compatibility switch. */
bool set_is_np_comp(bool is_np_comp) {
bool old = is_np_comp_;
is_np_comp_ = is_np_comp;
return old;
}
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
@@ -165,9 +175,15 @@
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
static thread_local bool is_recording_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_comp_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_comp_;
#endif
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
28 changes: 26 additions & 2 deletions include/mxnet/tuple.h
@@ -607,12 +607,36 @@ class TShape : public Tuple<dim_t> {
#endif
};

/*! \brief check if a shape's ndim is known. */
inline bool ndim_is_known(const int ndim) {
CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;
return ndim != -1;
}

/*! \brief check if a shape's ndim is known. */
inline bool ndim_is_known(const TShape& x) {
return ndim_is_known(x.ndim());
}

/*! \brief check if a shape's dim size is known. */
inline bool dim_size_is_known(const int dim_size) {
CHECK_GE(dim_size, -1) << "shape dim size must be >= -1, while received " << dim_size;
return dim_size != -1;
}

/*! \brief check if a shape's dim size is known. */
inline bool dim_size_is_known(const TShape& x, const int idx) {
CHECK(idx >= 0 && idx < x.ndim())
<< "idx = " << idx << " exceeds shape dimension range [0, " << x.ndim() << ")";
return dim_size_is_known(x[idx]);
}

/*! \brief check if shape is known using the NumPy compatible definition.
 * zero-dim and zero-size tensors are valid. -1 means unknown. */
inline bool shape_is_known(const TShape& x) {
if (x.ndim() == -1) return false;
if (!ndim_is_known(x)) return false;
for (int i = 0; i < x.ndim(); ++i) {
if (x[i] == -1) return false;
if (!dim_size_is_known(x, i)) return false;
}
return true;
}
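
To summarize the convention these helpers encode: ndim == -1 means the rank itself is unknown, a dimension of -1 means that axis is unknown, and 0 is now a legitimate (zero-size) dimension. A hypothetical Python mirror, for illustration only (not part of this commit):

def ndim_is_known(ndim):
    assert ndim >= -1, "shape ndim must be >= -1, while received %d" % ndim
    return ndim != -1

def dim_size_is_known(dim_size):
    assert dim_size >= -1, "shape dim size must be >= -1, while received %d" % dim_size
    return dim_size != -1

def shape_is_known(shape):
    # `shape` is a tuple of ints, or None when even the rank is unknown.
    if shape is None:
        return False
    return all(dim_size_is_known(d) for d in shape)

assert shape_is_known((2, 0, 3))    # zero-size dims are valid
assert not shape_is_known((2, -1))  # -1 marks an unknown dim
assert shape_is_known(())           # zero-dim (scalar) shapes are valid
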
5 changes: 4 additions & 1 deletion python/mxnet/ndarray/ndarray.py
@@ -1852,7 +1852,10 @@ def shape(self):
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index
if ndim.value == -1:
return None
else:
return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index


@property
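
A small sketch of the user-visible effect; the unknown-shape case assumes an array whose shape the backend has not determined yet, which is only reported as ndim == -1 when numpy compatibility is enabled:

import mxnet as mx

a = mx.nd.ones((2, 3))
print(a.shape)  # (2, 3) -- unchanged for fully known shapes

# If the backend reports ndim == -1 for an array (numpy compatibility on,
# shape not yet inferred), the property now returns None instead of a tuple.
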
46 changes: 46 additions & 0 deletions python/mxnet/numpy/__init__.py
@@ -17,4 +17,50 @@
# specific language governing permissions and limitations
# under the License.

import ctypes
from ..base import _LIB, check_call

__all__ = []


def set_np_comp(is_np_comp):
prev = ctypes.c_int()
check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev)))
return bool(prev.value)


def is_np_comp():
curr = ctypes.c_bool()
check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
return curr.value


class _NumpyCompatibilityStateScope(object):
"""Scope for managing numpy compatibility state.
Example::
with _NumpyCompatibilityStateScope(True):
y = model(x)
backward([y])
"""
def __init__(self, is_np_comp): #pylint: disable=redefined-outer-name
self._enter_is_np_comp = is_np_comp
self._prev_is_np_comp = None

def __enter__(self):
if self._enter_is_np_comp is not None:
self._prev_is_np_comp = set_np_comp(self._enter_is_np_comp)

def __exit__(self, ptype, value, trace):
if self._enter_is_np_comp is not None and self._prev_is_np_comp != self._enter_is_np_comp:
set_np_comp(self._prev_is_np_comp)


def enable_np_comp():
return _NumpyCompatibilityStateScope(True)


def disable_np_comp():
return _NumpyCompatibilityStateScope(False)
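
A usage sketch for the new frontend switch; `net` and `x` are placeholders, everything else comes from the helpers defined above:

from mxnet.numpy import set_np_comp, is_np_comp, enable_np_comp

prev = set_np_comp(True)   # flip the global switch; returns the previous state
assert is_np_comp()
set_np_comp(prev)          # restore it

with enable_np_comp():     # scoped form; the previous state is restored on exit
    y = net(x)             # placeholder model call
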
26 changes: 13 additions & 13 deletions python/mxnet/operator.py
@@ -28,7 +28,7 @@
from ctypes import CFUNCTYPE, POINTER, Structure, pointer
from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool

from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf
from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int
from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
from . import symbol, context
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
@@ -164,7 +164,7 @@ def get_symbol(self, *args, **kwargs):
fb_functype = CFUNCTYPE(None, c_int, POINTER(POINTER(mx_float)), POINTER(c_int),
POINTER(POINTER(mx_uint)), POINTER(c_int), c_void_p)
infer_functype = CFUNCTYPE(None, c_int, POINTER(c_int),
POINTER(POINTER(mx_uint)), c_void_p)
POINTER(POINTER(mx_int)), c_void_p)
list_functype = CFUNCTYPE(None, POINTER(POINTER(POINTER(c_char))), c_void_p)
class NumpyOpInfo(Structure):
"""Structure that holds Callback information. Passed to NumpyOpProp"""
@@ -214,9 +214,9 @@ def infer_shape_entry(num_tensor, tensor_dims,
assert len(ishape) == n_in
rshape = list(ishape) + list(oshape)
for i in range(n_in+n_out):
tensor_shapes[i] = cast(c_array_buf(mx_uint,
array('I', rshape[i])),
POINTER(mx_uint))
tensor_shapes[i] = cast(c_array_buf(mx_int,
array('i', rshape[i])),
POINTER(mx_int))
tensor_dims[i] = len(rshape[i])

def list_outputs_entry(out, _):
@@ -266,7 +266,7 @@ def __init__(self, need_top_grad=True):
def get_symbol(self, *args, **kwargs):
fb_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_void_p), POINTER(c_int), c_void_p)
infer_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_int),
POINTER(POINTER(mx_uint)), c_void_p)
POINTER(POINTER(mx_int)), c_void_p)
list_functype = CFUNCTYPE(c_bool, POINTER(POINTER(POINTER(c_char))), c_void_p)
deps_functype = CFUNCTYPE(c_bool, c_int_p, c_int_p, c_int_p,
c_int_p, POINTER(c_int_p), c_void_p)
@@ -335,9 +335,9 @@ def infer_shape_entry(num_tensor, tensor_dims,
assert len(ishape) == n_in
rshape = list(ishape) + list(oshape)
for i in range(n_in+n_out):
tensor_shapes[i] = cast(c_array_buf(mx_uint,
array('I', rshape[i])),
POINTER(mx_uint))
tensor_shapes[i] = cast(c_array_buf(mx_int,
array('i', rshape[i])),
POINTER(mx_int))
tensor_dims[i] = len(rshape[i])
except Exception:
print('Error in NDArrayOp.infer_shape: %s' % traceback.format_exc())
@@ -698,7 +698,7 @@ def do_register(prop_cls):
del_functype = CFUNCTYPE(c_int, c_void_p)

infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
POINTER(POINTER(mx_uint)), c_void_p)
POINTER(POINTER(mx_int)), c_void_p)
infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), \
@@ -747,9 +747,9 @@ def infer_shape_entry(num_tensor, tensor_dims,
"shapes, got %d."%(n_aux, len(ashape))
rshape = list(ishape) + list(oshape) + list(ashape)
for i in range(n_in+n_out+n_aux):
tensor_shapes[i] = cast(c_array_buf(mx_uint,
array('I', rshape[i])),
POINTER(mx_uint))
tensor_shapes[i] = cast(c_array_buf(mx_int,
array('i', rshape[i])),
POINTER(mx_int))
tensor_dims[i] = len(rshape[i])

infer_shape_entry._ref_holder = [tensor_shapes]
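
Because custom-op shapes are now marshalled as signed 32-bit ints, an infer_shape written in Python can, in principle, see and propagate -1 for unknown dimensions. A hypothetical fragment showing only the shape-inference part of a CustomOpProp (the op name and class are illustrative; a complete custom op also needs create_operator and a CustomOp implementation):

import mxnet as mx

@mx.operator.register("identity_like")  # hypothetical op name
class IdentityLikeProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(IdentityLikeProp, self).__init__(need_top_grad=True)

    def infer_shape(self, in_shape):
        # in_shape entries may now carry -1 for axes the framework has not
        # resolved; the output simply mirrors the (possibly partial) input shape.
        return in_shape, [in_shape[0]], []
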
22 changes: 15 additions & 7 deletions python/mxnet/symbol/symbol.py
@@ -42,6 +42,7 @@
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from ..ndarray import _ndarray_cls
from ..executor import Executor
from ..numpy import is_np_comp
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
@@ -1078,7 +1079,11 @@ def infer_shape(self, *args, **kwargs):
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
if not shape or not _numpy.prod(shape):
if is_np_comp():
shape_is_none = not shape or -1 in shape
else:
shape_is_none = not shape or 0 in shape
if shape_is_none:
if len(unknowns) >= 10:
unknowns.append('...')
break
@@ -1204,12 +1209,15 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
ctypes.byref(aux_shape_data),
ctypes.byref(complete)))
if complete.value != 0:
arg_shapes = [
tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
out_shapes = [
tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
aux_shapes = [
tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)]
arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]])
if arg_shape_ndim[i] >= 0 else None
for i in range(arg_shape_size.value)]
out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]])
if out_shape_ndim[i] >= 0 else None
for i in range(out_shape_size.value)]
aux_shapes = [tuple(aux_shape_data[i][:aux_shape_ndim[i]])
if aux_shape_ndim[i] >= 0 else None
for i in range(aux_shape_size.value)]
return (arg_shapes, out_shapes, aux_shapes)
else:
return (None, None, None)
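
A hedged illustration of the frontend effect, adapted from the classic infer_shape_partial example in the Symbol docs (output not re-verified here): with numpy compatibility on, arguments whose ndim comes back as -1 are reported as None rather than as the empty tuples of the legacy convention.

import mxnet as mx
from mxnet.numpy import enable_np_comp

data = mx.sym.Variable('data')
prev = mx.sym.Variable('prev')
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
fc2 = mx.sym.FullyConnected(data=prev, name='fc2', num_hidden=128)
out = mx.sym.elemwise_add(fc1, fc2)

with enable_np_comp():
    arg_shapes, _, _ = out.infer_shape_partial(data=(10, 64))

# The fc2 branch cannot be inferred from `data` alone; its entries should now
# show up as None instead of ().
print(dict(zip(out.list_arguments(), arg_shapes)))
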
24 changes: 17 additions & 7 deletions src/c_api/c_api.cc
@@ -44,9 +44,11 @@
#include "mxnet/rtc.h"
#include "mxnet/storage.h"
#include "mxnet/libinfo.h"
#include "mxnet/imperative.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "../common/utils.h"

using namespace mxnet;

@@ -499,15 +501,23 @@ int MXNDArrayGetShape(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const mxnet::TShape &s = arr->shape();
mxnet::TShape s = arr->shape();
if (!Imperative::Get()->is_np_comp()) {
common::ConvertToLegacyShape(&s);
}
*out_dim = s.ndim();
CHECK_GE(s.ndim(), 0);
std::vector<int>& buffer = ret->arg_shape_buffer;
buffer.resize(s.ndim());
mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
if (s.ndim() >= 0) {
std::vector<int> &buffer = ret->arg_shape_buffer;
buffer.resize(s.ndim());
mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
}
} else {
*out_dim = 0;
if (Imperative::Get()->is_np_comp()) {
*out_dim = -1;
} else {
*out_dim = 0;
}
}
API_END();
}
2 changes: 1 addition & 1 deletion src/c_api/c_api_common.h
@@ -91,7 +91,7 @@ struct MXAPIThreadLocalEntry {
data->resize(shapes.size());
size_t size = 0;
for (const auto& s : shapes) {
CHECK_GE(s.ndim(), 0);
if (s.ndim() > 0)
size += s.ndim();
}
buffer->resize(size);
7 changes: 7 additions & 0 deletions src/c_api/c_api_executor.cc
@@ -25,8 +25,10 @@
#include <mxnet/base.h>
#include <mxnet/c_api.h>
#include <mxnet/executor.h>
#include <mxnet/imperative.h>
#include "./c_api_common.h"
#include "../executor/graph_executor.h"
#include "../common/utils.h"
#if MXNET_USE_TENSORRT
#include "../executor/trt_graph_executor.h"
#endif // MXNET_USE_TENSORRT
@@ -416,6 +418,11 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
CHECK(p.second) << "Duplicate shapes are provided for argument "
<< provided_arg_shape_names[i] << " in simple_bind";
}
if (!Imperative::Get()->is_np_comp()) {
for (auto &kv : arg_shape_map) {
common::ConvertToNumpyShape(&kv.second);
}
}

// create para name set for sharing data array memory
std::unordered_set<std::string> shared_arg_name_set(num_shared_arg_names);
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -276,6 +276,18 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}

int MXIsNumpyCompatible(bool* curr) {
API_BEGIN();
*curr = Imperative::Get()->is_np_comp();
API_END();
}

int MXSetIsNumpyCompatible(int is_np_comp, int* prev) {
API_BEGIN();
*prev = Imperative::Get()->set_is_np_comp(static_cast<bool>(is_np_comp));
API_END();
}

int MXAutogradMarkVariables(mx_uint num_var,
NDArrayHandle *var_handles,
mx_uint *reqs_array,
9 changes: 8 additions & 1 deletion src/c_api/c_api_symbolic.cc
@@ -24,6 +24,7 @@
*/
#include "mxnet/base.h"
#include "mxnet/c_api.h"
#include "mxnet/imperative.h"
#include "nnvm/c_api.h"
#include "nnvm/pass.h"
#include "nnvm/pass_functions.h"
@@ -543,8 +544,14 @@ int MXSymbolInferShape(SymbolHandle sym,
throw dmlc::Error(err.msg);
}

// if the legacy shape definition is in use, convert the numpy shapes back to legacy shapes
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
if (!Imperative::Get()->is_np_comp()) {
common::ConvertToLegacyShape(&shapes);
}

// copy back
CopyAttr(g.indexed_graph(), g.GetAttr<mxnet::ShapeVector>("shape"),
CopyAttr(g.indexed_graph(), shapes,
&(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));

// copy data back
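
The conversions used throughout this commit (common::ConvertToNumpyShape on user-provided shapes, common::ConvertToLegacyShape before results are handed back) essentially swap the "unknown" sentinel between the legacy value 0 and the numpy-compatible value -1. A hypothetical Python illustration of that mapping (the real helpers are C++ and operate on TShape in place):

def to_numpy_shape(legacy_shape):
    # legacy convention: ndim 0 means unknown rank, a dim of 0 means unknown size
    if len(legacy_shape) == 0:
        return None
    return tuple(-1 if d == 0 else d for d in legacy_shape)

def to_legacy_shape(np_shape):
    # numpy-compatible convention: None / -1 mark unknowns; note that genuine
    # zero-size dims cannot be expressed in legacy mode, which is why the
    # global switch exists during the transition.
    if np_shape is None:
        return ()
    return tuple(0 if d == -1 else d for d in np_shape)
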