Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Large Index Support for Slice (#15593)
Browse files Browse the repository at this point in the history
* Adding Large Index Support for slice operator

* adding changes to fix py2 related error in CI/CD

* fixing base.py

* rearrange system call and slower Feature() call

* refactoring c_api, c_symbolic_api, c_api_common

* templatizing code

* caching results of runtime features and minor refactoring

* fixing local caching in ndarray shape
  • Loading branch information
access2rohit authored and apeforest committed Aug 13, 2019
1 parent 795990b commit 05f3ae1
Show file tree
Hide file tree
Showing 17 changed files with 549 additions and 183 deletions.
156 changes: 149 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
/*! \brief manually define float */
typedef float mx_float;
/*! \brief data type to store dim size */
Expand Down Expand Up @@ -572,6 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int dtype,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape,
int ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
NDArrayHandle *out);

/*!
* \brief create an empty sparse NDArray with specified shape and data type
Expand Down Expand Up @@ -603,6 +612,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *aux_shape,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
const mx_int64 *shape,
int ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
int *aux_ndims,
const mx_int64 *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -650,6 +672,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoad64(const char* fname,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Load list / dictionary of narrays from file content loaded into memory.
* This will load a list of ndarrays in a similar
Expand All @@ -665,11 +693,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer,
size_t size,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
Expand Down Expand Up @@ -809,6 +844,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);

MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle,
int *out_dim,
const int64_t **out_pdata);

/*!
* \brief get the shape of the array
* \param handle the handle to the narray
Expand All @@ -819,6 +859,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
int *out_dim,
const int **out_pdata);

MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
int *out_dim,
const mx_int64 **out_pdata);

/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -902,6 +947,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
mx_int64 i,
int *out_type);

/*!
* \brief Get a deep copy of the ith aux data blob
* in the form of an NDArray of default storage type.
Expand All @@ -911,6 +960,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
mx_int64 i,
NDArrayHandle *out);

/*!
* \brief Get a deep copy of the data blob
* in the form of an NDArray of default storage type.
Expand Down Expand Up @@ -966,6 +1019,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
*/
MXNET_DLL int MXListFunctions(mx_uint *out_size,
FunctionHandle **out_array);

MXNET_DLL int MXListFunctions64(mx_int64 *out_size,
FunctionHandle **out_array);

/*!
* \brief get the function handle by name
* \param name the name of the function
Expand Down Expand Up @@ -1233,6 +1290,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
*/
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
const char ***out_array);

MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size,
const char ***out_array);

/*!
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
Expand All @@ -1242,6 +1303,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);

MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size,
AtomicSymbolCreator **out_array);

/*!
* \brief Get the name of an atomic symbol.
* \param creator the AtomicSymbolCreator.
Expand Down Expand Up @@ -1454,6 +1518,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief List returns in the symbol.
* \param symbol the symbol
Expand All @@ -1465,14 +1534,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief Get number of outputs of the symbol.
* \param symbol The symbol
* \param out_size number of outputs
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
mx_uint *output_count);
mx_uint *output_count);

/*!
* \brief Get a symbol that contains all the internals.
Expand Down Expand Up @@ -1511,6 +1584,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief Compose the symbol on other symbols.
*
Expand Down Expand Up @@ -1582,6 +1660,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief infer shape of unknown input shapes given the known one.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
Expand Down Expand Up @@ -1619,6 +1713,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
* partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1660,6 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1701,6 +1827,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief infer type of unknown input types given the known one.
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/c_predict_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
/*! \brief manually define float */
typedef float mx_float;
/*! \brief handle to Predictor */
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _load_lib():
# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
mx_int64 = ctypes.c_int64
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
mx_real_t = _np.float32
Expand Down
49 changes: 37 additions & 12 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import warnings
import operator
from functools import reduce # pylint: disable=redefined-builtin
import sys
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
from ..base import ctypes2buffer
from ..runtime import Features
from ..context import Context, current_context
from . import _internal
from . import op
Expand Down Expand Up @@ -105,6 +107,14 @@
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1

# Caching whether MXNet was built with INT64 support or not
_INT64_TENSOR_SIZE_ENABLED = None

def _int64_enabled():
global _INT64_TENSOR_SIZE_ENABLED
if _INT64_TENSOR_SIZE_ENABLED is None:
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
return _INT64_TENSOR_SIZE_ENABLED

def _new_empty_handle():
"""Returns a new empty handle.
Expand Down Expand Up @@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
A new empty `NDArray` handle.
"""
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXNDArrayCreateEx64(
c_array_buf(mx_int64, native_array('q', shape)),
ctypes.c_int(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
else:
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return hdl


Expand Down Expand Up @@ -2218,9 +2238,14 @@ def shape(self):
(2L, 3L, 4L)
"""
ndim = mx_int()
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if _int64_enabled():
pdata = ctypes.POINTER(mx_int64)()
check_call(_LIB.MXNDArrayGetShapeEx64(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
else:
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if ndim.value == -1:
return None
else:
Expand Down
Loading

0 comments on commit 05f3ae1

Please sign in to comment.