enable default large tensor in np
szha committed Jul 3, 2020
1 parent 0c8b6b2 commit 0e9bcac
Showing 3 changed files with 189 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -87,7 +87,7 @@ option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." ON)
option(USE_TENSORRT "Enable inference optimization with TensorRT." OFF)
option(USE_ASAN "Enable Clang/GCC ASAN sanitizers." OFF)
cmake_dependent_option(ENABLE_TESTCOVERAGE "Enable compilation with test coverage metric output" OFF "NOT MSVC" OFF)
option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" OFF)
option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" ON)
option(BUILD_CYTHON_MODULES "Build cython modules." OFF)
option(LOG_FATAL_THROW "Log exceptions but do not abort" ON)
cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF)
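
With USE_INT64_TENSOR_SIZE now ON by default, builds represent element counts with int64_t unless a packager opts out. A minimal sketch of how a user might confirm which variant they are running, using the same runtime feature flag that the new _int64_enabled() helper below checks (assumes an installed mxnet build):

from mxnet.runtime import Features

# 'INT64_TENSOR_SIZE' is the feature name queried by _int64_enabled() in multiarray.py.
features = Features()
if features.is_enabled('INT64_TENSOR_SIZE'):
    print("int64 element counts available; tensors may exceed 2**31 - 1 elements")
else:
    print("32-bit element counts; tensors are capped at 2**31 - 1 elements")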
104 changes: 87 additions & 17 deletions python/mxnet/numpy/multiarray.py
@@ -40,8 +40,9 @@
get_oshape_of_gather_nd_op
from ..ndarray._internal import _set_np_ndarray_class
from . import _op as _mx_np_op
from ..base import check_call, _LIB, NDArrayHandle, c_array
from ..base import check_call, _LIB, NDArrayHandle, c_array, mx_int, mx_int64
from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types
from ..runtime import Features
from ..context import Context
from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\
is_np_default_dtype
@@ -92,6 +93,16 @@
_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)

# Caching whether MXNet was built with INT64 support or not
_INT64_TENSOR_SIZE_ENABLED = None

def _int64_enabled():
global _INT64_TENSOR_SIZE_ENABLED
if _INT64_TENSOR_SIZE_ENABLED is None:
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
return _INT64_TENSOR_SIZE_ENABLED

# This function is copied from ndarray.py since pylint
# keeps giving false alarm error of undefined-all-variable
@@ -106,14 +117,37 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): # pylint: disa
A new empty `ndarray` handle.
"""
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[_np.dtype(dtype).type])),
ctypes.byref(hdl)))
if _int64_enabled():
check_call(_LIB.MXNDArrayCreateEx64(
c_array_buf(mx_int64, native_array('q', shape)),
ctypes.c_int(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[_np.dtype(dtype).type])),
ctypes.byref(hdl)))
else:
# When the shape exceeds the uint32 range, the overflow happens on the Python side
# and the call never reaches the backend, so the size has to be validated here.
array_size = 1
for idx in shape:
array_size = array_size * idx
if array_size > _SIGNED_INT32_UPPER_LIMIT:
raise Exception("[_new_alloc_handle] Size of tensor you are trying to allocate is " +
"larger than 2^31 elements. Please build with flag " +
"USE_INT64_TENSOR_SIZE=1")
if _np.dtype(dtype) == _np.dtype([('bfloat16', _np.uint16)]):
dtype_type = _np.dtype(dtype)
else:
dtype_type = _np.dtype(dtype).type
check_call(_LIB.MXNDArrayCreateEx(
c_array_buf(mx_uint, native_array('I', shape)),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[dtype_type])),
ctypes.byref(hdl)))
return hdl
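
The fallback branch above computes the element count in pure Python before calling the 32-bit C API, because an oversized shape would otherwise overflow before ever reaching the backend. A small sketch of that check in isolation (hypothetical helper name, same arithmetic as the loop above):

import functools
import operator

_SIGNED_INT32_UPPER_LIMIT = 2**31 - 1

def fits_in_int32(shape):
    # Total number of elements is the product of all dimensions.
    return functools.reduce(operator.mul, shape, 1) <= _SIGNED_INT32_UPPER_LIMIT

print(fits_in_int32((100000, 21474)))   # True:  2,147,400,000 <= 2**31 - 1
print(fits_in_int32((100000, 21475)))   # False: 2,147,500,000 >  2**31 - 1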


@@ -399,14 +433,24 @@ def _get_np_basic_indexing(self, key):
)
handle = NDArrayHandle()
flat_self = self.reshape_view(-1)
check_call(
_LIB.MXNDArraySlice(
flat_self.handle,
mx_uint(flat_begin),
mx_uint(flat_end),
ctypes.byref(handle),
if _int64_enabled():
check_call(
_LIB.MXNDArraySlice64(
flat_self.handle,
ctypes.c_int64(flat_begin),
ctypes.c_int64(flat_end),
ctypes.byref(handle),
)
)
else:
check_call(
_LIB.MXNDArraySlice(
flat_self.handle,
ctypes.c_uint32(flat_begin),
ctypes.c_uint32(flat_end),
ctypes.byref(handle),
)
)
)
sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape)
sliced = self.__class__(handle=handle, writable=self.writable)
if 0 in sliced_shape:
@@ -2255,7 +2299,33 @@ def _scatter_set_nd(self, value_nd, indices):

@property
def shape(self):
return super(ndarray, self).shape
"""Tuple of array dimensions.
Examples
--------
>>> x = mx.np.array([1, 2, 3, 4])
>>> x.shape
(4L,)
>>> y = mx.np.zeros((2, 3, 4))
>>> y.shape
(2L, 3L, 4L)
>>> z = mx.np.array(3)
>>> z.shape
()
"""
num_dim = mx_int()
if _int64_enabled():
pdata = ctypes.POINTER(mx_int64)()
check_call(_LIB.MXNDArrayGetShapeEx64(
self.handle, ctypes.byref(num_dim), ctypes.byref(pdata)))
else:
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShapeEx(
self.handle, ctypes.byref(num_dim), ctypes.byref(pdata)))
if num_dim.value == -1:
return None
else:
return tuple(pdata[:num_dim.value]) # pylint: disable=invalid-slice-index

@property
def ndim(self):
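Taken together, the frontend changes mean that on an int64-enabled build the 64-bit C calls (MXNDArrayCreateEx64, MXNDArraySlice64, MXNDArrayGetShapeEx64) are used and arrays may exceed 2**31 - 1 elements, while 32-bit builds now fail fast with a clear error. A rough sketch of the user-visible difference (hypothetical session; a 2**31-element float32 array needs roughly 8.6 GB of memory):

import mxnet as mx

n = 2**31  # one element past the signed int32 limit checked in _new_alloc_handle
try:
    a = mx.np.zeros((n,), dtype='float32')
    print(a.shape)  # (2147483648,) on a USE_INT64_TENSOR_SIZE=ON build
except Exception as err:
    # A 32-bit build rejects the allocation on the Python side and asks for a
    # rebuild with USE_INT64_TENSOR_SIZE=1.
    print(err)
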
101 changes: 101 additions & 0 deletions tests/python/unittest/test_np_large_array.py
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys
import tempfile
import math
import numpy as np
import mxnet as mx

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../python/unittest/'))

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, use_np
from mxnet import gluon, nd
from common import with_seed
import pytest


# dimension constants
SMALL_X = 100
MEDIUM_X = 10000
LARGE_X = 100000000
VLARGE_X = 4300000000
SMALL_Y = 50


@use_np
def test_gluon_embedding():
m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
m.initialize()
a = mx.np.zeros((MEDIUM_X, SMALL_Y))
b = m(a)
assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
assert b.asnumpy().size == LARGE_X * SMALL_Y

@use_np
def test_fully_connected():
a = mx.np.ones(shape=(LARGE_X, SMALL_Y))
b = mx.np.ones(shape=(SMALL_Y, SMALL_Y))
c = mx.np.ones(shape=(b.shape[0],))

# w/o bias
res = mx.npx.FullyConnected(a, b, num_hidden=b.shape[0], no_bias=True)
assert np.sum(res[-1].asnumpy() == a.shape[1]) == b.shape[0]

# w/ bias
res = mx.npx.FullyConnected(a, b, c, num_hidden=b.shape[0], no_bias=False)
assert np.sum(res[-1].asnumpy() == a.shape[1] + 1) == b.shape[0]

@use_np
def test_dense():
data = mx.np.ones(shape=(50*1000*1000, 100))
linear = gluon.nn.Dense(100)
linear.initialize()
res = linear(data)
assert res.shape == (50000000, 100)

@use_np
def test_softmax():
input_data = mx.np.ones((SMALL_Y, LARGE_X))
for axis in [0, 1]:
true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
output = mx.npx.softmax(input_data, axis=axis)
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)

@use_np
def test_softmax_cross_entropy():
# MXNet's softmax_cross_entropy is given float64 input explicitly;
# numpy handles double precision implicitly
batch_size = SMALL_Y
num_labels = LARGE_X
input_data = mx.np.ones((batch_size, num_labels), dtype="float64")
input_label = mx.np.zeros((batch_size,), dtype="float64")
true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
# softmax_cross_entropy uses the default axis=1, so the expected value is
# 1/num_labels (it would be 1/batch_size for axis=0)
np_one_hot_label = np.zeros((batch_size, num_labels))
np_one_hot_label[:, 0] = 1
true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
np_one_hot_label)
mx_softmax_cross_entropy = mx.npx.softmax_cross_entropy(input_data,
input_label,
dtype="float64")
assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)
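
The new module plugs into the existing pytest-based unittest suite; a sketch of running just these tests (assumes an int64-enabled build with enough memory for the LARGE_X-sized tensors):

import pytest

# Run only the large-tensor tests for the np namespace; -v prints each test name.
pytest.main(['tests/python/unittest/test_np_large_array.py', '-v'])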
