Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Large Index Support for Slice #15593

Merged
merged 8 commits into from
Aug 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 149 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
apeforest marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief manually define float */
typedef float mx_float;
/*! \brief data type to store dim size */
Expand Down Expand Up @@ -572,6 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int dtype,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape,
int ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
NDArrayHandle *out);

/*!
* \brief create an empty sparse NDArray with specified shape and data type
Expand Down Expand Up @@ -603,6 +612,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *aux_shape,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
const mx_int64 *shape,
int ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
int *aux_ndims,
const mx_int64 *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -650,6 +672,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoad64(const char* fname,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Load list / dictionary of narrays from file content loaded into memory.
* This will load a list of ndarrays in a similar
Expand All @@ -665,11 +693,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);

MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer,
size_t size,
mx_int64 *out_size,
NDArrayHandle** out_arr,
mx_int64 *out_name_size,
const char*** out_names);

/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
Expand Down Expand Up @@ -809,6 +844,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);

MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle,
int *out_dim,
const int64_t **out_pdata);

/*!
* \brief get the shape of the array
* \param handle the handle to the narray
Expand All @@ -819,6 +859,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
int *out_dim,
const int **out_pdata);

MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
int *out_dim,
const mx_int64 **out_pdata);

/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -902,6 +947,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
mx_int64 i,
int *out_type);

/*!
* \brief Get a deep copy of the ith aux data blob
* in the form of an NDArray of default storage type.
Expand All @@ -911,6 +960,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
mx_int64 i,
NDArrayHandle *out);

/*!
* \brief Get a deep copy of the data blob
* in the form of an NDArray of default storage type.
Expand Down Expand Up @@ -966,6 +1019,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
*/
MXNET_DLL int MXListFunctions(mx_uint *out_size,
FunctionHandle **out_array);

MXNET_DLL int MXListFunctions64(mx_int64 *out_size,
FunctionHandle **out_array);

/*!
* \brief get the function handle by name
* \param name the name of the function
Expand Down Expand Up @@ -1233,6 +1290,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
*/
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
const char ***out_array);

MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size,
const char ***out_array);

/*!
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
Expand All @@ -1242,6 +1303,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);

MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size,
AtomicSymbolCreator **out_array);

/*!
* \brief Get the name of an atomic symbol.
* \param creator the AtomicSymbolCreator.
Expand Down Expand Up @@ -1454,6 +1518,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief List returns in the symbol.
* \param symbol the symbol
Expand All @@ -1465,14 +1534,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief Get number of outputs of the symbol.
* \param symbol The symbol
* \param out_size number of outputs
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
mx_uint *output_count);
mx_uint *output_count);

/*!
* \brief Get a symbol that contains all the internals.
Expand Down Expand Up @@ -1511,6 +1584,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);

MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol,
size_t *out_size,
const char ***out_str_array);

/*!
* \brief Compose the symbol on other symbols.
*
Expand Down Expand Up @@ -1582,6 +1660,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym,
apeforest marked this conversation as resolved.
Show resolved Hide resolved
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief infer shape of unknown input shapes given the known one.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
Expand Down Expand Up @@ -1619,6 +1713,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
* partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1660,6 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
const mx_uint ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1701,6 +1827,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
const int ***aux_shape_data,
int *complete);

MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_int64 *arg_ind_ptr,
const mx_int64 *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const mx_int64 ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const mx_int64 ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const mx_int64 ***aux_shape_data,
int *complete);

/*!
* \brief infer type of unknown input types given the known one.
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/c_predict_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ extern "C" {
#endif

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
typedef uint32_t mx_uint;
apeforest marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief manually define 64-bit int */
typedef int64_t mx_int64;
/*! \brief manually define float */
typedef float mx_float;
/*! \brief handle to Predictor */
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _load_lib():
# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
mx_int64 = ctypes.c_int64
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
mx_real_t = _np.float32
Expand Down
49 changes: 37 additions & 12 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import warnings
import operator
from functools import reduce # pylint: disable=redefined-builtin
import sys
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
from ..base import ctypes2buffer
from ..runtime import Features
from ..context import Context, current_context
from . import _internal
from . import op
Expand Down Expand Up @@ -105,6 +107,14 @@
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1

# Caching whether MXNet was built with INT64 support or not
_INT64_TENSOR_SIZE_ENABLED = None

def _int64_enabled():
global _INT64_TENSOR_SIZE_ENABLED
if _INT64_TENSOR_SIZE_ENABLED is None:
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
return _INT64_TENSOR_SIZE_ENABLED

def _new_empty_handle():
"""Returns a new empty handle.
Expand Down Expand Up @@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
A new empty `NDArray` handle.
"""
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXNDArrayCreateEx64(
c_array_buf(mx_int64, native_array('q', shape)),
ctypes.c_int(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
else:
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return hdl


Expand Down Expand Up @@ -2218,9 +2238,14 @@ def shape(self):
(2L, 3L, 4L)
"""
ndim = mx_int()
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if _int64_enabled():
pdata = ctypes.POINTER(mx_int64)()
check_call(_LIB.MXNDArrayGetShapeEx64(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
else:
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
if ndim.value == -1:
return None
else:
Expand Down
Loading