diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f1a5106a95d..cb22c5976122 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -642,7 +642,6 @@ if(UNIX) endif() target_link_libraries(mxnet PUBLIC mshadow) target_link_libraries(mxnet PUBLIC ${CMAKE_DL_LIBS}) - target_compile_definitions(mxnet PUBLIC DMLC_LOG_FATAL_THROW=$) if(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") target_compile_options(mxnet PRIVATE "$<$:-Werror>") # Ignore erroneous compiler warnings: @@ -669,7 +668,6 @@ elseif(MSVC) foreach(arch ${arch_code_list}) add_library(mxnet_${arch} SHARED ${SOURCE}) target_link_libraries(mxnet_${arch} PUBLIC mshadow) - target_compile_definitions(mxnet_${arch} PUBLIC DMLC_LOG_FATAL_THROW=$) target_compile_options( mxnet_${arch} PRIVATE @@ -705,10 +703,10 @@ elseif(MSVC) endif(USE_SPLIT_ARCH_DLL) else() add_library(mxnet SHARED ${SOURCE}) - target_compile_definitions(mxnet PUBLIC DMLC_LOG_FATAL_THROW=$) target_link_libraries(mxnet PUBLIC mshadow) endif() endif() +target_compile_definitions(mxnet PUBLIC DMLC_LOG_FATAL_THROW=$) # extension libraries (custom operators, custom subgraphs) are built by default add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc) diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 65687fff54a9..0b4bdf9a97c2 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -309,7 +309,6 @@ def _load_lib(): CudaModuleHandle = ctypes.c_void_p CudaKernelHandle = ctypes.c_void_p ProfileHandle = ctypes.c_void_p -DLPackHandle = ctypes.c_void_p #---------------------------- diff --git a/python/mxnet/dlpack.py b/python/mxnet/dlpack.py new file mode 100644 index 000000000000..b5e8ee83304e --- /dev/null +++ b/python/mxnet/dlpack.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=protected-access +# pylint: disable=import-error, no-name-in-module, undefined-variable + +"""DLPack API of MXNet.""" + +import ctypes +from .base import _LIB, c_str, check_call, NDArrayHandle + +DLPackHandle = ctypes.c_void_p + +PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +_c_str_dltensor = c_str('dltensor') +_c_str_used_dltensor = c_str('used_dltensor') + +def _dlpack_deleter(pycapsule): + pycapsule = ctypes.c_void_p(pycapsule) + if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): + ptr = ctypes.c_void_p( + ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)) + check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr)) + +_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter) + +class DLContext(ctypes.Structure): + _fields_ = [("device_type", ctypes.c_int), + ("device_id", ctypes.c_int)] + +class DLDataType(ctypes.Structure): + _fields_ = [("type_code", ctypes.c_uint8), + ("bits", ctypes.c_uint8), + ("lanes", ctypes.c_uint16)] + TYPE_MAP = { + "int32": (0, 32, 1), + "int64": (0, 64, 1), + "bool": (1, 1, 1), + "uint8": (1, 8, 1), + "uint32": (1, 32, 1), + "uint64": (1, 64, 1), + 'float16': (2, 16, 1), + "float32": (2, 32, 1), + "float64": (2, 64, 1), + } + + +class DLTensor(ctypes.Structure): + _fields_ = [("data", ctypes.c_void_p), + ("ctx", DLContext), + ("ndim", ctypes.c_int), + ("dtype", DLDataType), + ("shape", ctypes.POINTER(ctypes.c_int64)), + ("strides", ctypes.POINTER(ctypes.c_int64)), + ("byte_offset", ctypes.c_uint64)] + +class DLManagedTensor(ctypes.Structure): + pass + + +DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor)) + + +DLManagedTensor._fields_ = [("dl_tensor", DLTensor), # pylint: disable=protected-access + ("manager_ctx", ctypes.c_void_p), + ("deleter", DeleterFunc)] + +@DeleterFunc +def dl_managed_tensor_deleter(dl_managed_tensor_handle): + void_p = dl_managed_tensor_handle.contents.manager_ctx + pyobj = ctypes.cast(void_p, ctypes.py_object) + ctypes.pythonapi.Py_DecRef(pyobj) + +def ndarray_from_dlpack(array_cls): + """Returns a function that returns specified array_cls from dlpack. + + Returns + ------- + fn : dlpack -> array_cls + """ + def from_dlpack(dlpack): + handle = NDArrayHandle() + dlpack = ctypes.py_object(dlpack) + assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( + 'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.') + dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor)) + check_call(_LIB.MXNDArrayFromDLPackEx(dlpack_handle, False, ctypes.byref(handle))) + # Rename PyCapsule (DLPack) + ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor) + # delete the deleter of the old dlpack + ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None) + return array_cls(handle=handle) + return from_dlpack + + +def ndarray_to_dlpack_for_read(): + """Returns a function that returns dlpack for reading from mxnet array. + + Returns + ------- + fn : tensor -> dlpack + """ + def to_dlpack_for_read(data): + data.wait_to_read() + dlpack = DLPackHandle() + check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) + return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) + return to_dlpack_for_read + +def ndarray_to_dlpack_for_write(): + """Returns a function that returns dlpack for writing from mxnet array. + + Returns + ------- + fn : tensor -> dlpack + """ + def to_dlpack_for_write(data): + + check_call(_LIB.MXNDArrayWaitToWrite(data.handle)) + dlpack = DLPackHandle() + check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) + return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) + return to_dlpack_for_write + +def ndarray_from_numpy(array_cls, array_create_fn): + """Returns a function that creates array_cls from numpy array. + + Returns + ------- + fn : tensor -> dlpack + """ + def from_numpy(ndarray, zero_copy=True): + def _make_manager_ctx(obj): + pyobj = ctypes.py_object(obj) + void_p = ctypes.c_void_p.from_buffer(pyobj) + ctypes.pythonapi.Py_IncRef(pyobj) + return void_p + + def _make_dl_tensor(array): + if str(array.dtype) not in DLDataType.TYPE_MAP: + raise ValueError(str(array.dtype) + " is not supported.") + dl_tensor = DLTensor() + dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p) + dl_tensor.ctx = DLContext(1, 0) + dl_tensor.ndim = array.ndim + dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)] + dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64) + dl_tensor.strides = None + dl_tensor.byte_offset = 0 + return dl_tensor + + def _make_dl_managed_tensor(array): + c_obj = DLManagedTensor() + c_obj.dl_tensor = _make_dl_tensor(array) + c_obj.manager_ctx = _make_manager_ctx(array) + c_obj.deleter = dl_managed_tensor_deleter + return c_obj + + if not zero_copy: + return array_create_fn(ndarray, dtype=ndarray.dtype) + + if not ndarray.flags['C_CONTIGUOUS']: + raise ValueError("Only c-contiguous arrays are supported for zero-copy") + + ndarray.flags['WRITEABLE'] = False + c_obj = _make_dl_managed_tensor(ndarray) + handle = NDArrayHandle() + check_call(_LIB.MXNDArrayFromDLPackEx(ctypes.byref(c_obj), True, ctypes.byref(handle))) + return array_cls(handle=handle) + return from_numpy diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 9cc8b8942c1d..fa26dfff9628 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -34,9 +34,11 @@ from functools import reduce # pylint: disable=redefined-builtin import numpy as np from ..base import _LIB, numeric_types, integer_types -from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t -from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64 +from ..base import c_array, c_array_buf, c_handle_array, mx_real_t +from ..base import mx_uint, NDArrayHandle, check_call, mx_int, mx_int64 from ..base import ctypes2buffer +from ..dlpack import ndarray_to_dlpack_for_read, ndarray_to_dlpack_for_write +from ..dlpack import ndarray_from_dlpack, ndarray_from_numpy from ..runtime import Features from ..context import Context, current_context from ..util import is_np_array @@ -70,9 +72,28 @@ np.int8: 5, np.int64: 6, np.bool_: 7, + np.int16: 8, + np.uint16 : 9, + np.uint32 : 10, + np.uint64 : 11, np.dtype([('bfloat16', np.uint16)]): 12, } +def _register_platform_dependent_mx_dtype(): + """Register platform dependent types to the fixed size counterparts.""" + kind_map = {'i': 'int', 'u': 'uint', 'f': 'float'} + for np_type in [ + np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, np.int_, + np.uint, np.longlong, np.ulonglong, np.half, np.float16, np.single, + np.double, np.longdouble]: + dtype = np.dtype(np_type) + kind, size = dtype.kind, dtype.itemsize + bits = size * 8 + fixed_dtype = getattr(np, kind_map[kind]+str(bits)) + if fixed_dtype in _DTYPE_NP_TO_MX: + _DTYPE_NP_TO_MX[np_type] = _DTYPE_NP_TO_MX[fixed_dtype] +_register_platform_dependent_mx_dtype() + _DTYPE_MX_TO_NP = { -1: None, 0: np.float32, @@ -83,6 +104,10 @@ 5: np.int8, 6: np.int64, 7: np.bool_, + 8: np.int16, + 9: np.uint16, + 10: np.uint32, + 11: np.uint64, 12: np.dtype([('bfloat16', np.uint16)]), } @@ -4914,32 +4939,18 @@ def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False): raise ValueError('indices_or_sections must either int or tuple of ints') return _internal._split_v2(ary, indices, axis, squeeze_axis) -PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) -_c_str_dltensor = c_str('dltensor') -_c_str_used_dltensor = c_str('used_dltensor') - -def _dlpack_deleter(pycapsule): - pycapsule = ctypes.c_void_p(pycapsule) - if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): - ptr = ctypes.c_void_p( - ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)) - check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr)) - -_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter) - -def to_dlpack_for_read(data): - """Returns a reference view of NDArray that represents as DLManagedTensor until - all previous write operations on the current array are finished. +from_dlpack = ndarray_from_dlpack(NDArray) +from_dlpack_doc = """Returns a NDArray backed by a dlpack tensor. Parameters ---------- - data: NDArray - input data. + dlpack: PyCapsule (the pointer of DLManagedTensor) + input data Returns ------- - PyCapsule (the pointer of DLManagedTensor) - a reference view of NDArray that represents as DLManagedTensor. + NDArray + a NDArray backed by a dlpack tensor Examples -------- @@ -4948,33 +4959,13 @@ def to_dlpack_for_read(data): >>> type(y) >>> z = mx.nd.from_dlpack(y) + >>> type(z) + >>> z - [[1. 1. 1.] - [1. 1. 1.]] + [[ 1. 1. 1.] + [ 1. 1. 1.]] - """ - data.wait_to_read() - dlpack = DLPackHandle() - check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) - return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) - -def to_dlpack_for_write(data): - """Returns a reference view of NDArray that represents as DLManagedTensor until - all previous read/write operations on the current array are finished. - Parameters - ---------- - data: NDArray - input data. - - Returns - ------- - PyCapsule (the pointer of DLManagedTensor) - a reference view of NDArray that represents as DLManagedTensor. - - Examples - -------- - >>> x = mx.nd.ones((2,3)) >>> w = mx.nd.to_dlpack_for_write(x) >>> type(w) @@ -4985,23 +4976,45 @@ def to_dlpack_for_write(data): [2. 2. 2.]] """ - check_call(_LIB.MXNDArrayWaitToWrite(data.handle)) - dlpack = DLPackHandle() - check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) - return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) +from_dlpack.__doc__ = from_dlpack_doc -def from_dlpack(dlpack): - """Returns a NDArray backed by a dlpack tensor. +from_numpy = ndarray_from_numpy(NDArray, array) +from_numpy_doc = """Returns an MXNet's NDArray backed by numpy's ndarray. + When `zero_copy` is set to be true, + this API consumes numpy's ndarray and produces MXNet's ndarray + without having to copy the content. In this case, we disallow + users to modify the given numpy ndarray, and it is suggested + not to read the numpy ndarray as well for internal correctness. Parameters ---------- - dlpack: PyCapsule (the pointer of DLManagedTensor) + ndarray: NDArray input data + zero_copy: bool + Whether we use DLPack's zero-copy conversion to convert to MXNet's NDArray. + This is only available for c-contiguous arrays, i.e. array.flags[C_CONTIGUOUS] == True. Returns ------- NDArray a NDArray backed by a dlpack tensor +""" +from_numpy.__doc__ = from_numpy_doc + + +to_dlpack_for_read = ndarray_to_dlpack_for_read() +to_dlpack_for_read_doc = """Returns a reference view of NDArray that represents as DLManagedTensor until + all previous write operations on the current array are finished. + + Parameters + ---------- + data: NDArray + input data. + + Returns + ------- + PyCapsule (the pointer of DLManagedTensor) + a reference view of NDArray that represents as DLManagedTensor. Examples -------- @@ -5010,13 +5023,30 @@ def from_dlpack(dlpack): >>> type(y) >>> z = mx.nd.from_dlpack(y) - >>> type(z) - >>> z - [[ 1. 1. 1.] - [ 1. 1. 1.]] + [[1. 1. 1.] + [1. 1. 1.]] +""" +to_dlpack_for_read.__doc__ = to_dlpack_for_read_doc + +to_dlpack_for_write = ndarray_to_dlpack_for_write() +to_dlpack_for_write_doc = """Returns a reference view of NDArray that represents as +DLManagedTensor until all previous read/write operations on the current array are finished. + + Parameters + ---------- + data: NDArray + input data. + + Returns + ------- + PyCapsule (the pointer of DLManagedTensor) + a reference view of NDArray that represents as DLManagedTensor. + Examples + -------- + >>> x = mx.nd.ones((2,3)) >>> w = mx.nd.to_dlpack_for_write(x) >>> type(w) @@ -5026,128 +5056,5 @@ def from_dlpack(dlpack): [[2. 2. 2.] [2. 2. 2.]] - """ - handle = NDArrayHandle() - dlpack = ctypes.py_object(dlpack) - assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( - 'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.') - dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor)) - check_call(_LIB.MXNDArrayFromDLPackEx(dlpack_handle, False, ctypes.byref(handle))) - # Rename PyCapsule (DLPack) - ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor) - # delete the deleter of the old dlpack - ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None) - return NDArray(handle=handle) - -class DLContext(ctypes.Structure): - _fields_ = [("device_type", ctypes.c_int), - ("device_id", ctypes.c_int)] - - -class DLDataType(ctypes.Structure): - _fields_ = [("type_code", ctypes.c_uint8), - ("bits", ctypes.c_uint8), - ("lanes", ctypes.c_uint16)] - TYPE_MAP = { - "int32": (0, 32, 1), - "int64": (0, 64, 1), - "bool": (1, 1, 1), - "uint8": (1, 8, 1), - "uint32": (1, 32, 1), - "uint64": (1, 64, 1), - 'float16': (2, 16, 1), - "float32": (2, 32, 1), - "float64": (2, 64, 1), - } - - -class DLTensor(ctypes.Structure): - _fields_ = [("data", ctypes.c_void_p), - ("ctx", DLContext), - ("ndim", ctypes.c_int), - ("dtype", DLDataType), - ("shape", ctypes.POINTER(ctypes.c_int64)), - ("strides", ctypes.POINTER(ctypes.c_int64)), - ("byte_offset", ctypes.c_uint64)] - -class DLManagedTensor(ctypes.Structure): - pass - - -DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor)) - - -DLManagedTensor._fields_ = [("dl_tensor", DLTensor), # pylint: disable=protected-access - ("manager_ctx", ctypes.c_void_p), - ("deleter", DeleterFunc)] - - -@DeleterFunc -def dl_managed_tensor_deleter(dl_managed_tensor_handle): - void_p = dl_managed_tensor_handle.contents.manager_ctx - pyobj = ctypes.cast(void_p, ctypes.py_object) - ctypes.pythonapi.Py_DecRef(pyobj) - - -def from_numpy(ndarray, zero_copy=True, array_cls=NDArray): - """Returns an MXNet's ndarray backed by numpy's ndarray. - When `zero_copy` is set to be true, - this API consumes numpy's ndarray and produces MXNet's ndarray - without having to copy the content. In this case, we disallow - users to modify the given numpy ndarray, and it is suggested - not to read the numpy ndarray as well for internal correctness. - - Parameters - ---------- - ndarray: numpy.ndarray - input data - zero_copy: bool - Whether we use DLPack's zero-copy conversion to convert to MXNet's NDArray. - This is only available for c-contiguous arrays, i.e. array.flags[C_CONTIGUOUS] == True. - array_cls: ndarray class type - The class type of the output array. - - Returns - ------- - NDArray - a NDArray backed by a dlpack tensor - - """ - - def _make_manager_ctx(obj): - pyobj = ctypes.py_object(obj) - void_p = ctypes.c_void_p.from_buffer(pyobj) - ctypes.pythonapi.Py_IncRef(pyobj) - return void_p - - def _make_dl_tensor(array): - if str(array.dtype) not in DLDataType.TYPE_MAP: - raise ValueError(str(array.dtype) + " is not supported.") - dl_tensor = DLTensor() - dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p) - dl_tensor.ctx = DLContext(1, 0) - dl_tensor.ndim = array.ndim - dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)] - dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64) - dl_tensor.strides = None - dl_tensor.byte_offset = 0 - return dl_tensor - - def _make_dl_managed_tensor(array): - c_obj = DLManagedTensor() - c_obj.dl_tensor = _make_dl_tensor(array) - c_obj.manager_ctx = _make_manager_ctx(array) - c_obj.deleter = dl_managed_tensor_deleter - return c_obj - - if not zero_copy: - return array(ndarray, dtype=ndarray.dtype) - - if not ndarray.flags['C_CONTIGUOUS']: - raise ValueError("Only c-contiguous arrays are supported for zero-copy") - - ndarray.flags['WRITEABLE'] = False - c_obj = _make_dl_managed_tensor(ndarray) - handle = NDArrayHandle() - check_call(_LIB.MXNDArrayFromDLPackEx(ctypes.byref(c_obj), True, ctypes.byref(handle))) - return array_cls(handle=handle) +""" +to_dlpack_for_write.__doc__ = to_dlpack_for_write_doc diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 91fea5f4aeef..14dbf94a1417 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -854,7 +854,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou result array or scalar """ from ...numpy import ndarray - from ..ndarray import from_numpy # pylint: disable=unused-import + from ...numpy_extension import from_numpy # pylint: disable=unused-import if isinstance(lhs, numeric_types): if isinstance(rhs, numeric_types): return fn_scalar(lhs, rhs, out=out) @@ -8049,7 +8049,7 @@ def shares_memory(a, b, max_work=None): the following way(s): - Does not support `max_work`, it is a dummy argument - - Actually it is same as `may_share_memory` in MXNet DeepNumPy + - Actually it is same as `may_share_memory` in MXNet np """ return _api_internal.share_memory(a, b).item() @@ -8090,7 +8090,7 @@ def may_share_memory(a, b, max_work=None): the following way(s): - Does not support `max_work`, it is a dummy argument - - Actually it is same as `shares_memory` in MXNet DeepNumPy + - Actually it is same as `shares_memory` in MXNet np """ return _api_internal.share_memory(a, b).item() diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 41e573c76111..c2a2c0bf78ec 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -390,7 +390,7 @@ def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None): This operator is a little different from the one in official NumPy. The official NumPy operator only accepts 1-D ndarray as mean and 2-D ndarray as cov, - whereas the operator in DeepNumPy supports batch operation and auto-broadcasting. + whereas the operator in MXNet np supports batch operation and auto-broadcasting. Both `mean` and `cov` may have any number of leading dimensions, which correspond to a batch shape. They are not necessarily assumed to have the same batch shape, diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index ddecaea37aa3..5274408e4403 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi -from ..ndarray.ndarray import _storage_type, from_numpy +from ..ndarray.ndarray import _storage_type from .utils import _get_np_op from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback @@ -182,11 +182,11 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name def _as_mx_np_array(object, ctx=None): """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, _np.ndarray): - if not object.flags['C_CONTIGUOUS']: - object = _np.ascontiguousarray(object, dtype=object.dtype) - ret = from_numpy(object, array_cls=ndarray) - return ret if ctx is None else ret.as_in_ctx(ctx=ctx) + if isinstance(object, ndarray): + return object + elif isinstance(object, _np.ndarray): + np_dtype = _np.dtype(object.dtype).type + return array(object, dtype=np_dtype, ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object elif isinstance(object, (list, tuple)): @@ -10171,7 +10171,7 @@ def shares_memory(a, b, max_work=None): the following way(s): - Does not support `max_work`, it is a dummy argument - - Actually it is same as `may_share_memory` in MXNet DeepNumPy + - Actually it is same as `may_share_memory` in MXNet np """ return _mx_nd_np.shares_memory(a, b, max_work) @@ -10212,7 +10212,7 @@ def may_share_memory(a, b, max_work=None): the following way(s): - Does not support `max_work`, it is a dummy argument - - Actually it is same as `shares_memory` in MXNet DeepNumPy + - Actually it is same as `shares_memory` in MXNet np """ return _mx_nd_np.may_share_memory(a, b, max_work) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 127d7d7da1d7..e739036e2c71 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -430,7 +430,7 @@ def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None): This operator is a little different from the one in official NumPy. The official NumPy operator only accepts 1-D ndarray as mean and 2-D ndarray as cov, - whereas the operator in DeepNumPy supports batch operation and auto-broadcasting. + whereas the operator in MXNet np supports batch operation and auto-broadcasting. Both `mean` and `cov` may have any number of leading dimensions, which correspond to a batch shape. They are not necessarily assumed to have the same batch shape, diff --git a/python/mxnet/numpy_extension/utils.py b/python/mxnet/numpy_extension/utils.py index f625439335d5..6d3f25b7f0a8 100644 --- a/python/mxnet/numpy_extension/utils.py +++ b/python/mxnet/numpy_extension/utils.py @@ -20,25 +20,15 @@ import ctypes -from .. util import is_np_array, is_np_shape -from .. base import _LIB, check_call, string_types, c_str_array, DLPackHandle -from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str -from ..numpy import ndarray +from ..util import is_np_array, is_np_shape +from ..base import _LIB, check_call, string_types, c_str_array +from ..base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str +from ..dlpack import ndarray_to_dlpack_for_read, ndarray_to_dlpack_for_write +from ..dlpack import ndarray_from_dlpack, ndarray_from_numpy +from ..numpy import ndarray, array -__all__ = ['save', 'load', 'to_dlpack_for_read', 'to_dlpack_for_write', 'from_dlpack'] - -PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) -_c_str_dltensor = c_str('dltensor') -_c_str_used_dltensor = c_str('used_dltensor') - -def _dlpack_deleter(pycapsule): - pycapsule = ctypes.c_void_p(pycapsule) - if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): - ptr = ctypes.c_void_p( - ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)) - check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr)) - -_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter) +__all__ = ['save', 'load', 'to_dlpack_for_read', 'to_dlpack_for_write', + 'from_dlpack', 'from_numpy'] def save(file, arr): """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file. @@ -132,9 +122,8 @@ def load(file): (py_str(names[i]), ndarray(NDArrayHandle(handles[i]))) for i in range(out_size.value)) - -def from_dlpack(dlpack): - """Returns a np.ndarray backed by a dlpack tensor. +from_dlpack = ndarray_from_dlpack(ndarray) +from_dlpack_doc = """Returns a np.ndarray backed by a dlpack tensor. Parameters ---------- @@ -168,21 +157,36 @@ def from_dlpack(dlpack): array([[2., 2., 2.], [2., 2., 2.]]) """ - handle = NDArrayHandle() - dlpack = ctypes.py_object(dlpack) - assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( - 'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.') - dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor)) - check_call(_LIB.MXNDArrayFromDLPackEx(dlpack_handle, False, ctypes.byref(handle))) - # Rename PyCapsule (DLPack) - ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor) - # delete the deleter of the old dlpack - ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None) - return ndarray(handle=handle) - -def to_dlpack_for_read(data): - """Returns a reference view of np.ndarray that represents as DLManagedTensor until - all previous write operations on the current array are finished. +from_dlpack.__doc__ = from_dlpack_doc + + +from_numpy = ndarray_from_numpy(ndarray, array) +from_numpy_doc = """Returns an MXNet's np.ndarray backed by numpy's ndarray. + When `zero_copy` is set to be true, + this API consumes numpy's ndarray and produces MXNet's np.ndarray + without having to copy the content. In this case, we disallow + users to modify the given numpy ndarray, and it is suggested + not to read the numpy ndarray as well for internal correctness. + + Parameters + ---------- + ndarray: np.ndarray + input data + zero_copy: bool + Whether we use DLPack's zero-copy conversion to convert to MXNet's + np.ndarray. + This is only available for c-contiguous arrays, i.e. array.flags[C_CONTIGUOUS] == True. + + Returns + ------- + np.ndarray + a np.ndarray backed by a dlpack tensor + """ +from_numpy.__doc__ = from_numpy_doc + +to_dlpack_for_read = ndarray_to_dlpack_for_read() +to_dlpack_for_read_doc = """Returns a reference view of np.ndarray that represents +as DLManagedTensor until all previous write operations on the current array are finished. Parameters ---------- @@ -205,14 +209,11 @@ def to_dlpack_for_read(data): array([[1., 1., 1.], [1., 1., 1.]]) """ - data.wait_to_read() - dlpack = DLPackHandle() - check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) - return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) +to_dlpack_for_read.__doc__ = to_dlpack_for_read_doc -def to_dlpack_for_write(data): - """Returns a reference view of ndarray that represents as DLManagedTensor until - all previous read/write operations on the current array are finished. +to_dlpack_for_write = ndarray_to_dlpack_for_write() +to_dlpack_for_write_doc = """Returns a reference view of ndarray that represents +as DLManagedTensor until all previous read/write operations on the current array are finished. Parameters ---------- @@ -236,7 +237,4 @@ def to_dlpack_for_write(data): array([[2., 2., 2.], [2., 2., 2.]]) """ - check_call(_LIB.MXNDArrayWaitToWrite(data.handle)) - dlpack = DLPackHandle() - check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) - return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) +to_dlpack_for_write.__doc__ = to_dlpack_for_write_doc diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 75780df173e9..af834bbeb5d5 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -926,7 +926,7 @@ def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None): This operator is a little different from the one in official NumPy. The official NumPy operator only accepts 1-D ndarray as mean and 2-D ndarray as cov, - whereas the operator in DeepNumPy supports batch operation and auto-broadcasting. + whereas the operator in MXNet np supports batch operation and auto-broadcasting. Both `mean` and `cov` may have any number of leading dimensions, which correspond to a batch shape. They are not necessarily assumed to have the same batch shape, diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py index 74d3f17a645e..2175d7b9a062 100644 --- a/tests/python/unittest/test_base.py +++ b/tests/python/unittest/test_base.py @@ -25,7 +25,9 @@ import logging import os.path as op import platform +import pytest +@pytest.mark.garbage_expected def test_environment(): name1 = 'MXNET_TEST_ENV_VAR_1' name2 = 'MXNET_TEST_ENV_VAR_2' diff --git a/tests/python/unittest/test_gluon_probability_v1.py b/tests/python/unittest/test_gluon_probability_v1.py index 0fece99bb6d7..82395ddf86f5 100644 --- a/tests/python/unittest/test_gluon_probability_v1.py +++ b/tests/python/unittest/test_gluon_probability_v1.py @@ -540,7 +540,7 @@ def hybrid_forward(self, F, n, params, *args): # Test log_prob for shape, hybridize, use_logit in itertools.product(shapes, [True, False], [True, False]): n = np.random.randint(1, 10, size=shape).astype('float32') - prob = np.random.uniform(low=0.1, size=shape) + prob = np.random.uniform(low=0.1, size=shape).astype('float32') sample = np.random.randint(0, 10, size=shape).astype('float32') param = prob if use_logit: @@ -559,7 +559,7 @@ def hybrid_forward(self, F, n, params, *args): for func in ['mean', 'variance']: for use_logit in [True, False]: n = np.random.randint(1, 10, size=shape).astype('float32') - prob = np.random.uniform(low=0.1, size=shape) + prob = np.random.uniform(low=0.1, size=shape).astype('float32') net = TestNegativeBinomial(func, use_logit) param = prob if use_logit: @@ -2015,7 +2015,7 @@ def hybrid_forward(self, F, logit, *args): def test_gluon_kl_v1(): def _test_zero_kl(p, shape): """Check if KL(p || p) = 0 - + Parameters ---------- p : Distribution diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 966b26d7e2d2..f1cd9b38621b 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1369,3 +1369,35 @@ def test_dlpack(dtype, size): same(a_np+1, b) same(a_np+2, c) same(a_np+2, a_copy) + +@use_np +@pytest.mark.parametrize('np_array', [ + # ordinary numpy array + _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32"), + # 0-dim + _np.array((1, )).reshape(()), + # 0-size + _np.array(()).reshape((1, 0, 2)), +]) +@pytest.mark.parametrize('zero_copy', [False, True]) +def test_from_numpy(np_array, zero_copy): + # Test zero_copy + mx_array = mx.npx.from_numpy(np_array, zero_copy=zero_copy) + mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy()) + +def test_from_numpy_exception(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.npx.from_numpy(np_array) + with pytest.raises(ValueError): + np_array[2, 1] = 0 + + mx_array[2, 1] = 100 + mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy()) + np_array = _np.array([[1, 2], [3, 4], [5, 6]]).transpose() + assert not np_array.flags["C_CONTIGUOUS"] + with pytest.raises(ValueError): + mx_array = mx.nd.from_numpy(np_array) + + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.npx.from_numpy(np_array, zero_copy=False) + np_array[2, 1] = 0 # no error diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 24970eaf9e5e..a44ba327b3a1 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -498,11 +498,11 @@ def test_relu(): def frelu(x): return np.maximum(x, 0.0) def frelu_grad(x): - return 1.0 * (x > 0.0) + return np.float32(1.0) * (x > np.float32(0.0)) shape = (3, 4) x = mx.symbol.Variable("x") y = mx.sym.relu(x) - xa = np.random.uniform(low=-1.0,high=1.0,size=shape) + xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype('float32') eps = 1e-4 # Avoid finite difference method inaccuracies due to discontinuous gradient at the origin. # Here we replace small problematic inputs with 1.0. Repro issue with seed 97264195.