Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adding tests to verify support for Large Tensors in additional Ops al…
Browse files Browse the repository at this point in the history
…ong with new C_Apis supporting 64bit indexing
  • Loading branch information
Rohit Kumar Srivastava committed Aug 20, 2019
1 parent 3dfb19a commit 6d8ef8f
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 15 deletions.
10 changes: 10 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
void *data,
size_t size);

/*!
* \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0
* This function blocks. Do not use it in performance critical code.
Expand Down Expand Up @@ -790,6 +791,11 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_end,
NDArrayHandle *out);

MXNET_DLL int MXNDArraySlice64(NDArrayHandle handle,
int64_t slice_begin,
int64_t slice_end,
NDArrayHandle *out);

/*!
* \brief Index the NDArray along axis 0.
* \param handle the handle to the NDArray
Expand All @@ -801,6 +807,10 @@ MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out);

MXNET_DLL int MXNDArrayAt64(NDArrayHandle handle,
int64_t idx,
NDArrayHandle *out);

/*!
* \brief get the storage type of the array
*/
Expand Down
32 changes: 23 additions & 9 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,14 +932,24 @@ def _get_nd_basic_indexing(self, key):
)
handle = NDArrayHandle()
flat_self = self.reshape(-1)
check_call(
_LIB.MXNDArraySlice(
flat_self.handle,
mx_uint(flat_begin),
mx_uint(flat_end),
ctypes.byref(handle),
if sys.version_info[0] > 2 and _int64_enabled():
check_call(
_LIB.MXNDArraySlice64(
flat_self.handle,
ctypes.c_int64(flat_begin),
ctypes.c_int64(flat_end),
ctypes.byref(handle),
)
)
else:
check_call(
_LIB.MXNDArraySlice(
flat_self.handle,
ctypes.c_uint32(flat_begin),
ctypes.c_uint32(flat_end),
ctypes.byref(handle),
)
)
)
sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape)
sliced = NDArray(handle=handle, writable=self.writable).reshape(sliced_shape)
else:
Expand Down Expand Up @@ -1235,8 +1245,12 @@ def _at(self, idx):
if idx < 0:
raise IndexError('index %d is out of bounds for axis 0 with size %d'
% (idx-length, length))
check_call(_LIB.MXNDArrayAt(
self.handle, mx_uint(idx), ctypes.byref(handle)))
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXNDArrayAt64(
self.handle, ctypes.c_int64(idx), ctypes.byref(handle)))
else:
check_call(_LIB.MXNDArrayAt(
self.handle, ctypes.c_uint32(idx), ctypes.byref(handle)))
return self.__class__(handle=handle, writable=self.writable)

def reshape(self, *shape, **kwargs):
Expand Down
33 changes: 29 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,20 +451,35 @@ int MXNDArrayFree(NDArrayHandle handle) {
API_END();
}

template<typename dtype>
void SliceArray(NDArrayHandle handle, dtype slice_begin, dtype slice_end, NDArray* ptr,
NDArrayHandle* out) {
*ptr = static_cast<NDArray*>(handle)->SliceWithRecord(slice_begin, slice_end);
*out = ptr;
}

int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
*ptr = static_cast<NDArray*>(handle)->SliceWithRecord(
slice_begin, slice_end);
*out = ptr;
SliceArray<uint32_t>(handle, slice_begin, slice_end, ptr, out);
API_END_HANDLE_ERROR(delete ptr);
}

int MXNDArraySlice64(NDArrayHandle handle,
int64_t slice_begin,
int64_t slice_end,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
SliceArray<int64_t>(handle, slice_begin, slice_end, ptr, out);
API_END_HANDLE_ERROR(delete ptr);
}

int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
uint32_t idx,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
Expand All @@ -473,6 +488,16 @@ int MXNDArrayAt(NDArrayHandle handle,
API_END_HANDLE_ERROR(delete ptr);
}

int MXNDArrayAt64(NDArrayHandle handle,
int64_t idx,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
*ptr = static_cast<NDArray*>(handle)->AtWithRecord(idx);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}

MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
int ndim,
int *dims,
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void DotForward_(const nnvm::NodeAttrs& attrs,
inputs[0].get<xpu, 1, DType>(s),
inputs[1].get<xpu, 1, DType>(s));
} else {
int ma, na, mb, nb, m, n;
index_t ma, na, mb, nb, m, n;
if (param.transpose_a) {
ma = inputs[0].size(0);
na = inputs[0].Size()/ma;
Expand Down
Loading

0 comments on commit 6d8ef8f

Please sign in to comment.