Skip to content

Commit

Permalink
[WIP] Use new shape definition (apache#14453)
Browse files Browse the repository at this point in the history
* Init checkin

* Fix ndarray alloc bug

* Use TShape(0) as default empty tuple params

* Fix bugs

* Fix TShape init value

* Fix infer shape pass shape type and reshape infer shape func
  • Loading branch information
reminisce committed Apr 5, 2019
1 parent 02c9f8b commit 89a3037
Show file tree
Hide file tree
Showing 87 changed files with 574 additions and 557 deletions.
40 changes: 20 additions & 20 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);
int *out_dim,
const int **out_pdata);
/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -1481,16 +1481,16 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const int *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const int **in_shape_ndim,
const int ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const int **out_shape_ndim,
const int ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);
/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1520,16 +1520,16 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const int *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const int **in_shape_ndim,
const int ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const int **out_shape_ndim,
const int ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

/*!
Expand Down Expand Up @@ -1808,7 +1808,7 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const int* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
Expand Down Expand Up @@ -1862,7 +1862,7 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
const int* map_dev_ids,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const int* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
mx_uint* num_in_args,
NDArrayHandle** in_args,
Expand Down Expand Up @@ -2538,8 +2538,8 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
* \param dtype data type of NDArray
* \param out constructed NDArray
*/
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape,
int ndim, int dtype, NDArrayHandle *out);


#ifdef __cplusplus
Expand Down
9 changes: 6 additions & 3 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -859,12 +859,15 @@ class NDArray {
Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype)
: static_data(false), delay_alloc(true), ctx(ctx_),
storage_ref_(Storage::_GetSharedRef()) {
auto size = shape.Size();
storage_shape = shape;
if (shape_is_known(storage_shape)) {
shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
}
var = Engine::Get()->NewVariable();
shandle.size = size * mshadow::mshadow_sizeof(dtype);
shandle.ctx = ctx_;
if (!delay_alloc_) this->CheckAndAlloc();
if (!delay_alloc_) {
this->CheckAndAlloc();
}
}

Chunk(const TBlob &data, int dev_id)
Expand Down
37 changes: 23 additions & 14 deletions include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class Tuple {
* \return the corresponding dimension size
*/
inline ValueType& operator[](int i) {
CHECK(i >= 0 && i < ndim());
CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")";
return begin()[i];
}
/*!
Expand All @@ -208,7 +208,7 @@ class Tuple {
* \return the corresponding dimension size
*/
inline const ValueType& operator[](int i) const {
CHECK(i >= 0 && i < ndim());
CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")";
return begin()[i];
}
/*!
Expand Down Expand Up @@ -271,14 +271,16 @@ class Tuple {
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
}
// Handle empty tuple
// Handle empty tuple. A tensor whose shape is an empty tuple
// represents a scalar with ndim = 0.
while (isspace(is.peek())) {
is.get();
}
if (is.peek() == ')' || is.peek() == ']') {
is.get();
t.SetDim(0);
return is;
}
// Handle non-empty tuple
Expand Down Expand Up @@ -352,7 +354,7 @@ class Tuple {
delete [] data_heap_;
data_heap_ = new ValueType[ndim];
num_heap_allocated_ = ndim;
} else if (ndim == -1 && data_heap_ != nullptr) {
} else if (ndim <= 0 && data_heap_ != nullptr) {
delete [] data_heap_;
data_heap_ = nullptr;
num_heap_allocated_ = 0;
Expand Down Expand Up @@ -381,14 +383,11 @@ class TShape : public Tuple<dim_t> {
this->SetDim(-1);
}
/*!
* constructor to construct a shape with all 1.
* TODO(junwu): The value should default to -1. Need to keep 1 for now
* for backward compatibility. Change it to -1 in the future when we can
* break backward compatibility.
* constructor to construct a shape with all `value`.
* \param ndim the number of dimension
* \param value the dimension size for all dims
*/
inline TShape(int ndim, int value = 1) { // NOLINT(*)
inline TShape(int ndim, int value = -1) { // NOLINT(*)
this->SetDim(ndim);
if (ndim > 0) {
std::fill_n(begin(), ndim, value);
Expand Down Expand Up @@ -458,7 +457,7 @@ class TShape : public Tuple<dim_t> {
dim_t size = 1;
const dim_t* start = begin(), *fin = end();
for (const dim_t* it = start; it != fin; ++it) {
CHECK_GE(*it, 0) << "Shape dim size cannot be -1, which means unknown.";
CHECK_GE(*it, 0) << "Shape dim size cannot be a negative value " << *it;
size *= *it;
}
return size;
Expand All @@ -473,7 +472,7 @@ class TShape : public Tuple<dim_t> {
dim_t num = 1;
const dim_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
CHECK_GE(d[i], 0) << "Shape dim size cannot be -1, which means unknown.";
CHECK_GE(d[i], 0) << "Shape dim size cannot be a negative value " << d[i];
num *= d[i];
}
return num;
Expand Down Expand Up @@ -608,6 +607,16 @@ class TShape : public Tuple<dim_t> {
#endif
};

/*! brief check if shape is known using the NumPy compatible definition.
* zero-dim and zero-size tensors are valid. -1 means unknown.*/
inline bool shape_is_known(const TShape& x) {
if (x.ndim() == -1) return false;
for (int i = 0; i < x.ndim(); ++i) {
if (x[i] == -1) return false;
}
return true;
}

/*! \brief helper function to cast type of container elements */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
Expand All @@ -623,7 +632,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin,
template<typename SrcIter>
inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
size_t ndim = std::distance(begin, end);
TShape res(ndim);
TShape res(ndim, -1);
ShapeTypeCast(begin, end, res.begin());
return res;
}
Expand Down Expand Up @@ -669,7 +678,7 @@ struct hash<mxnet::Tuple<T> > {
size_t operator()(const mxnet::Tuple<T>& val) const {
std::hash<uint32_t> hash_uint;
size_t res = hash_uint(val.ndim());
for (uint32_t i = 0; i < val.ndim(); ++i) {
for (int i = 0; i < val.ndim(); ++i) {
res = dmlc::HashCombine(res, val[i]);
}
return res;
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _load_lib():
_LIB = _load_lib()

# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import copy
import numpy as np
from .base import _LIB
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str, mx_int
from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
Expand Down Expand Up @@ -445,8 +445,8 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
py_array('i', ctx_map_dev_ids)),
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_data)),
c_array_buf(mx_int,
py_array('i', provided_arg_shape_data)),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_idx)),
ctypes.byref(num_in_args),
Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
from ..base import ctypes2buffer
from ..context import Context, current_context
from . import _internal
Expand Down Expand Up @@ -146,8 +146,8 @@ def _new_from_shared_mem(shared_pid, shared_id, shape, dtype):
check_call(_LIB.MXNDArrayCreateFromSharedMem(
ctypes.c_int(shared_pid),
ctypes.c_int(shared_id),
c_array(mx_uint, shape),
mx_uint(len(shape)),
c_array(mx_int, shape),
mx_int(len(shape)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return hdl
Expand Down Expand Up @@ -1848,8 +1848,8 @@ def shape(self):
>>> y.shape
(2L, 3L, 4L)
"""
ndim = mx_uint()
pdata = ctypes.POINTER(mx_uint)()
ndim = mx_int()
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index
Expand Down
20 changes: 10 additions & 10 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
from ..base import mx_uint, py_str, string_types, integer_types
from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
Expand Down Expand Up @@ -1174,14 +1174,14 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
indptr.append(len(sdata))
keys = c_str_array(str_keys)
arg_shape_size = mx_uint()
arg_shape_ndim = ctypes.POINTER(mx_uint)()
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
arg_shape_ndim = ctypes.POINTER(mx_int)()
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
out_shape_size = mx_uint()
out_shape_ndim = ctypes.POINTER(mx_uint)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
out_shape_ndim = ctypes.POINTER(mx_int)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
aux_shape_size = mx_uint()
aux_shape_ndim = ctypes.POINTER(mx_uint)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
aux_shape_ndim = ctypes.POINTER(mx_int)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
complete = ctypes.c_int()
if partial:
infer_func = _LIB.MXSymbolInferShapePartial
Expand All @@ -1192,7 +1192,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
mx_uint(len(indptr) - 1),
keys,
c_array_buf(mx_uint, array('I', indptr)),
c_array_buf(mx_uint, array('I', sdata)),
c_array_buf(mx_int, array('i', sdata)),
ctypes.byref(arg_shape_size),
ctypes.byref(arg_shape_ndim),
ctypes.byref(arg_shape_data),
Expand Down Expand Up @@ -1576,10 +1576,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_uint,
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('I', provided_arg_shape_idx)),
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
Expand Down
15 changes: 8 additions & 7 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
mxnet::Tuple<dim_t> shape(dims, dims+ndim);
CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
<< arr->shape();
mxnet::TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
Expand All @@ -493,17 +493,18 @@ int MXNDArrayGetStorageType(NDArrayHandle handle,
}

int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata) {
int *out_dim,
const int **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const mxnet::TShape &s = arr->shape();
*out_dim = s.ndim();
std::vector<uint32_t>& buffer = ret->arg_shape_buffer;
CHECK_GE(s.ndim(), 0);
std::vector<int>& buffer = ret->arg_shape_buffer;
buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
} else {
*out_dim = 0;
Expand Down Expand Up @@ -1395,8 +1396,8 @@ int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shar
API_END();
}

int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out) {
int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape,
int ndim, int dtype, NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
API_END();
Expand Down
Loading

0 comments on commit 89a3037

Please sign in to comment.