Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adding Large Index Support for slice operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Jul 19, 2019
1 parent 076b2f3 commit fb9540e
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 71 deletions.
156 changes: 149 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
/*! \brief manually define float */
typedef float mx_float;
/*! \brief data type to store dim size */
Expand Down Expand Up @@ -556,6 +558,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int dtype,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
NDArrayHandle *out);

/*!
* \brief create an empty sparse NDArray with specified shape and data type
Expand Down Expand Up @@ -587,6 +596,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *aux_shape,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateSparseExInt64(int storage_type,
const mx_int64 *shape,
mx_int64 ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -634,6 +656,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoadInt64(const char* fname,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Load list / dictionary of narrays from file content loaded into memory.
* This will load a list of ndarrays in a similar
Expand All @@ -649,11 +677,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoadFromBufferInt64(const void *ndarray_buffer,
size_t size,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
Expand Down Expand Up @@ -793,6 +828,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);

MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle,
mx_int64 *out_dim,
const mx_int64 **out_pdata);

/*!
* \brief get the shape of the array
* \param handle the handle to the narray
Expand All @@ -803,6 +843,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
int *out_dim,
const int **out_pdata);

MXNET_DLL int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
int *out_dim,
const int64_t **out_pdata);

/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -886,6 +931,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

MXNET_DLL int MXNDArrayGetAuxTypeInt64(NDArrayHandle handle,
mx_int64 i,
int *out_type);

/*!
* \brief Get a deep copy of the ith aux data blob
* in the form of an NDArray of default storage type.
Expand All @@ -895,6 +944,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayGetAuxNDArrayInt64(NDArrayHandle handle,
mx_int64 i,
NDArrayHandle *out);

/*!
* \brief Get a deep copy of the data blob
* in the form of an NDArray of default storage type.
Expand Down Expand Up @@ -950,6 +1003,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
*/
MXNET_DLL int MXListFunctions(mx_uint *out_size,
FunctionHandle **out_array);

MXNET_DLL int MXListFunctionsInt64(mx_int64 *out_size,
FunctionHandle **out_array);

/*!
* \brief get the function handle by name
* \param name the name of the function
Expand Down Expand Up @@ -1217,6 +1274,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
*/
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
const char ***out_array);

MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size,
const char ***out_array);

/*!
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
Expand All @@ -1226,6 +1287,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);

MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64(mx_int64 *out_size,
AtomicSymbolCreator **out_array);

/*!
* \brief Get the name of an atomic symbol.
* \param creator the AtomicSymbolCreator.
Expand Down Expand Up @@ -1438,6 +1502,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListArgumentsInt64(SymbolHandle symbol,
mx_int64 *out_size,
const char ***out_str_array);

/*!
* \brief List returns in the symbol.
* \param symbol the symbol
Expand All @@ -1449,14 +1518,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListOutputsInt64(SymbolHandle symbol,
mx_int64 *out_size,
const char ***out_str_array);

/*!
* \brief Get number of outputs of the symbol.
* \param symbol The symbol
* \param out_size number of outputs
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
mx_uint *output_count);
mx_uint *output_count);

/*!
* \brief Get a symbol that contains all the internals.
Expand Down Expand Up @@ -1495,6 +1568,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListAuxiliaryStatesInt64(SymbolHandle symbol,
mx_int64 *out_size,
const char ***out_str_array);

/*!
* \brief Compose the symbol on other symbols.
*
Expand Down Expand Up @@ -1566,6 +1644,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapeInt64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
mx_int64 *in_shape_size,
const mx_int64 **in_shape_ndim,
const mx_int64 ***in_shape_data,
mx_int64 *out_shape_size,
const mx_int64 **out_shape_ndim,
const mx_int64 ***out_shape_data,
mx_int64 *aux_shape_size,
const mx_int64 **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief infer shape of unknown input shapes given the known one.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
Expand Down Expand Up @@ -1603,6 +1697,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapeExInt64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const int *arg_shape_data,
mx_uint *in_shape_size,
const int **in_shape_ndim,
const int64_t ***in_shape_data,
mx_uint *out_shape_size,
const int **out_shape_ndim,
const int64_t ***out_shape_data,
mx_uint *aux_shape_size,
const int **aux_shape_ndim,
const int64_t ***aux_shape_data,
int *complete);

/*!
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
* partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1644,6 +1755,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartialInt64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
mx_int64 *in_shape_size,
const mx_int64 **in_shape_ndim,
const mx_int64 ***in_shape_data,
mx_int64 *out_shape_size,
const mx_int64 **out_shape_ndim,
const mx_int64 ***out_shape_data,
mx_int64 *aux_shape_size,
const mx_int64 **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1685,6 +1811,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartialExInt64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const int *arg_shape_data,
mx_int64 *in_shape_size,
const int **in_shape_ndim,
const int ***in_shape_data,
mx_int64 *out_shape_size,
const int **out_shape_ndim,
const int ***out_shape_data,
mx_int64 *aux_shape_size,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

/*!
* \brief infer type of unknown input types given the known one.
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/c_predict_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
/*! \brief manually define float */
typedef float mx_float;
/*! \brief handle to Predictor */
Expand Down
1 change: 0 additions & 1 deletion include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ class Tuple {
}
};


/*! brief check if a shape's ndim is known. */
inline bool ndim_is_known(const int ndim) {
CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _load_lib():
# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
mx_int64 = ctypes.c_int64
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
mx_real_t = _np.float32
Expand Down
40 changes: 28 additions & 12 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
from ..base import ctypes2buffer
from ..runtime import Features
from ..context import Context, current_context
from . import _internal
from . import op
Expand Down Expand Up @@ -131,14 +132,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
A new empty `NDArray` handle.
"""
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
if Features().is_enabled('INT64_TENSOR_SIZE'):
check_call(_LIB.MXNDArrayCreateExInt64(
c_array_buf(mx_int64, native_array('q', shape)),
mx_int64(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
else:
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return hdl


Expand Down Expand Up @@ -1847,9 +1858,14 @@ def shape(self):
(2L, 3L, 4L)
"""
ndim = mx_int()
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if Features().is_enabled('INT64_TENSOR_SIZE'):
pdata = ctypes.POINTER(mx_int64)()
check_call(_LIB.MXNDArrayGetShapeExInt64(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
else:
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if ndim.value == -1:
return None
else:
Expand Down
Loading

0 comments on commit fb9540e

Please sign in to comment.