Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Use new shape definition #14453

Merged
merged 6 commits into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);
int *out_dim,
const int **out_pdata);
/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -1481,16 +1481,16 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const int *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const int **in_shape_ndim,
const int ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const int **out_shape_ndim,
const int ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);
/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1520,16 +1520,16 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const int *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const int **in_shape_ndim,
const int ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const int **out_shape_ndim,
const int ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

/*!
Expand Down Expand Up @@ -1808,7 +1808,7 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const int* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
Expand Down Expand Up @@ -1862,7 +1862,7 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
const int* map_dev_ids,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const int* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
mx_uint* num_in_args,
NDArrayHandle** in_args,
Expand Down Expand Up @@ -2538,8 +2538,8 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
* \param dtype data type of NDArray
* \param out constructed NDArray
*/
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape,
int ndim, int dtype, NDArrayHandle *out);


#ifdef __cplusplus
Expand Down
9 changes: 6 additions & 3 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -859,12 +859,15 @@ class NDArray {
Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype)
: static_data(false), delay_alloc(true), ctx(ctx_),
storage_ref_(Storage::_GetSharedRef()) {
auto size = shape.Size();
storage_shape = shape;
if (shape_is_known(storage_shape)) {
shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
}
var = Engine::Get()->NewVariable();
shandle.size = size * mshadow::mshadow_sizeof(dtype);
shandle.ctx = ctx_;
if (!delay_alloc_) this->CheckAndAlloc();
if (!delay_alloc_) {
this->CheckAndAlloc();
}
}

Chunk(const TBlob &data, int dev_id)
Expand Down
37 changes: 23 additions & 14 deletions include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class Tuple {
* \return the corresponding dimension size
*/
inline ValueType& operator[](int i) {
CHECK(i >= 0 && i < ndim());
CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")";
return begin()[i];
}
/*!
Expand All @@ -208,7 +208,7 @@ class Tuple {
* \return the corresponding dimension size
*/
inline const ValueType& operator[](int i) const {
CHECK(i >= 0 && i < ndim());
CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")";
return begin()[i];
}
/*!
Expand Down Expand Up @@ -271,14 +271,16 @@ class Tuple {
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
}
// Handle empty tuple
// Handle empty tuple. A tensor whose shape is an empty tuple
// represents a scalar with ndim = 0.
while (isspace(is.peek())) {
is.get();
}
if (is.peek() == ')' || is.peek() == ']') {
is.get();
t.SetDim(0);
return is;
}
// Handle non-empty tuple
Expand Down Expand Up @@ -352,7 +354,7 @@ class Tuple {
delete [] data_heap_;
data_heap_ = new ValueType[ndim];
num_heap_allocated_ = ndim;
} else if (ndim == -1 && data_heap_ != nullptr) {
} else if (ndim <= 0 && data_heap_ != nullptr) {
delete [] data_heap_;
data_heap_ = nullptr;
num_heap_allocated_ = 0;
Expand Down Expand Up @@ -381,14 +383,11 @@ class TShape : public Tuple<dim_t> {
this->SetDim(-1);
}
/*!
* constructor to construct a shape with all 1.
* TODO(junwu): The value should default to -1. Need to keep 1 for now
* for backward compatibility. Change it to -1 in the future when we can
* break backward compatibility.
* constructor to construct a shape with all `value`.
* \param ndim the number of dimension
* \param value the dimension size for all dims
*/
inline TShape(int ndim, int value = 1) { // NOLINT(*)
inline TShape(int ndim, int value = -1) { // NOLINT(*)
this->SetDim(ndim);
if (ndim > 0) {
std::fill_n(begin(), ndim, value);
Expand Down Expand Up @@ -458,7 +457,7 @@ class TShape : public Tuple<dim_t> {
dim_t size = 1;
const dim_t* start = begin(), *fin = end();
for (const dim_t* it = start; it != fin; ++it) {
CHECK_GE(*it, 0) << "Shape dim size cannot be -1, which means unknown.";
CHECK_GE(*it, 0) << "Shape dim size cannot be a negative value " << *it;
size *= *it;
}
return size;
Expand All @@ -473,7 +472,7 @@ class TShape : public Tuple<dim_t> {
dim_t num = 1;
const dim_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
CHECK_GE(d[i], 0) << "Shape dim size cannot be -1, which means unknown.";
CHECK_GE(d[i], 0) << "Shape dim size cannot be a negative value " << d[i];
num *= d[i];
}
return num;
Expand Down Expand Up @@ -608,6 +607,16 @@ class TShape : public Tuple<dim_t> {
#endif
};

/*!
 * \brief Check whether a shape is fully known under the NumPy compatible definition.
 *
 * Zero-dim and zero-size tensors are valid shapes. A value of -1, either as
 * ndim() or as any individual dimension size, means unknown.
 * \param x the shape to inspect
 * \return false if ndim() is -1 or any dimension size is -1, true otherwise
 */
inline bool shape_is_known(const TShape& x) {
  const int num_dims = x.ndim();
  // ndim == -1 marks the shape itself as unknown; no dimensions to inspect.
  if (num_dims == -1) {
    return false;
  }
  int axis = 0;
  while (axis < num_dims) {
    // A single unknown dimension makes the whole shape unknown.
    if (x[axis] == -1) {
      return false;
    }
    ++axis;
  }
  return true;
}

/*! \brief helper function to cast type of container elements */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
Expand All @@ -623,7 +632,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin,
template<typename SrcIter>
inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
size_t ndim = std::distance(begin, end);
TShape res(ndim);
TShape res(ndim, -1);
ShapeTypeCast(begin, end, res.begin());
return res;
}
Expand Down Expand Up @@ -669,7 +678,7 @@ struct hash<mxnet::Tuple<T> > {
size_t operator()(const mxnet::Tuple<T>& val) const {
std::hash<uint32_t> hash_uint;
size_t res = hash_uint(val.ndim());
for (uint32_t i = 0; i < val.ndim(); ++i) {
for (int i = 0; i < val.ndim(); ++i) {
res = dmlc::HashCombine(res, val[i]);
}
return res;
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _load_lib():
_LIB = _load_lib()

# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import copy
import numpy as np
from .base import _LIB
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str, mx_int
from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
Expand Down Expand Up @@ -445,8 +445,8 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
py_array('i', ctx_map_dev_ids)),
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_data)),
c_array_buf(mx_int,
py_array('i', provided_arg_shape_data)),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_idx)),
ctypes.byref(num_in_args),
Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
from ..base import ctypes2buffer
from ..context import Context, current_context
from . import _internal
Expand Down Expand Up @@ -146,8 +146,8 @@ def _new_from_shared_mem(shared_pid, shared_id, shape, dtype):
check_call(_LIB.MXNDArrayCreateFromSharedMem(
ctypes.c_int(shared_pid),
ctypes.c_int(shared_id),
c_array(mx_uint, shape),
mx_uint(len(shape)),
c_array(mx_int, shape),
mx_int(len(shape)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return hdl
Expand Down Expand Up @@ -1848,8 +1848,8 @@ def shape(self):
>>> y.shape
(2L, 3L, 4L)
"""
ndim = mx_uint()
pdata = ctypes.POINTER(mx_uint)()
ndim = mx_int()
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index
Expand Down
20 changes: 10 additions & 10 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
from ..base import mx_uint, py_str, string_types, integer_types
from ..base import mx_uint, py_str, string_types, integer_types, mx_int
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
Expand Down Expand Up @@ -1174,14 +1174,14 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
indptr.append(len(sdata))
keys = c_str_array(str_keys)
arg_shape_size = mx_uint()
arg_shape_ndim = ctypes.POINTER(mx_uint)()
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
arg_shape_ndim = ctypes.POINTER(mx_int)()
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
out_shape_size = mx_uint()
out_shape_ndim = ctypes.POINTER(mx_uint)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
out_shape_ndim = ctypes.POINTER(mx_int)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
aux_shape_size = mx_uint()
aux_shape_ndim = ctypes.POINTER(mx_uint)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
aux_shape_ndim = ctypes.POINTER(mx_int)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
complete = ctypes.c_int()
if partial:
infer_func = _LIB.MXSymbolInferShapePartial
Expand All @@ -1192,7 +1192,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
mx_uint(len(indptr) - 1),
keys,
c_array_buf(mx_uint, array('I', indptr)),
c_array_buf(mx_uint, array('I', sdata)),
c_array_buf(mx_int, array('i', sdata)),
ctypes.byref(arg_shape_size),
ctypes.byref(arg_shape_ndim),
ctypes.byref(arg_shape_data),
Expand Down Expand Up @@ -1576,10 +1576,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_uint,
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('I', provided_arg_shape_idx)),
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
Expand Down
15 changes: 8 additions & 7 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
mxnet::Tuple<dim_t> shape(dims, dims+ndim);
CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
<< arr->shape();
mxnet::TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
Expand All @@ -492,17 +492,18 @@ int MXNDArrayGetStorageType(NDArrayHandle handle,
}

int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata) {
int *out_dim,
const int **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const mxnet::TShape &s = arr->shape();
*out_dim = s.ndim();
std::vector<uint32_t>& buffer = ret->arg_shape_buffer;
CHECK_GE(s.ndim(), 0);
std::vector<int>& buffer = ret->arg_shape_buffer;
buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
} else {
*out_dim = 0;
Expand Down Expand Up @@ -1394,8 +1395,8 @@ int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shar
API_END();
}

int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out) {
int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape,
int ndim, int dtype, NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
API_END();
Expand Down
Loading