From 51ce493d6f8b1a0bcd566eee4e314bda4b8b3aa5 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 14 Oct 2019 05:17:10 +0000 Subject: [PATCH 1/3] fix dtype bug --- python/mxnet/ndarray/ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 4e3c7efa7be3..9dca471a290b 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2428,7 +2428,7 @@ def dtype(self): mx_dtype = ctypes.c_int() check_call(_LIB.MXNDArrayGetDType( self.handle, ctypes.byref(mx_dtype))) - return _DTYPE_MX_TO_NP[mx_dtype.value] + return np.dtype(_DTYPE_MX_TO_NP[mx_dtype.value]) @property def stype(self): From a27a9b7ca069859bcc479c77cf10f31d58273353 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 14 Oct 2019 07:51:53 +0000 Subject: [PATCH 2/3] override dtype property; add tests --- python/mxnet/ndarray/ndarray.py | 2 +- python/mxnet/numpy/multiarray.py | 20 ++++++++++++++++++ tests/python/unittest/test_numpy_ndarray.py | 23 +++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 9dca471a290b..4e3c7efa7be3 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2428,7 +2428,7 @@ def dtype(self): mx_dtype = ctypes.c_int() check_call(_LIB.MXNDArrayGetDType( self.handle, ctypes.byref(mx_dtype))) - return np.dtype(_DTYPE_MX_TO_NP[mx_dtype.value]) + return _DTYPE_MX_TO_NP[mx_dtype.value] @property def stype(self): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5ee52f14bb16..2478b974ddf6 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1680,6 +1680,26 @@ def size(self): """Number of elements in the array.""" return super(ndarray, self).size + @property + def dtype(self): + """Data-type of the array's elements. + + Returns + ------- + numpy.dtype + This NDArray's data type. + + Examples + -------- + >>> x = np.zeros((2,3)) + >>> x.dtype + dtype('float32') + >>> y = np.zeros((2,3), dtype='int32') + >>> y.dtype + dtype('int32') + """ + return _np.dtype(super(ndarray, self).dtype) + def tostype(self, stype): raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype') diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 53a8076f1303..b6692c009c18 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1075,6 +1075,29 @@ def test_boolean_indexing_autograd(): test_boolean_indexing_autograd() +@with_seed() +@use_np +def test_np_get_dtype(): + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_, + 'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None] + objects = [ + [], + (), + [[1, 2], [3, 4]], + _np.random.uniform(size=rand_shape_nd(3)), + _np.random.uniform(size=(3, 0, 4)) + ] + for dtype in dtypes: + for src in objects: + mx_arr = np.array(src, dtype=dtype) + assert mx_arr.ctx == mx.current_context() + if isinstance(src, mx.nd.NDArray): + np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) + else: + np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) + assert type(mx_arr.dtype) == type(np_arr.dtype) + + if __name__ == '__main__': import nose nose.runmodule() From 791024a267801d16e34a4d258344c3061ac154d0 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 15 Oct 2019 01:50:46 +0000 Subject: [PATCH 3/3] fix display bug --- python/mxnet/numpy/multiarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 2478b974ddf6..4fff4421e5a5 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -706,7 +706,7 @@ def __repr__(self): if dtype == _np.float32: array_str = array_str[:array_str.rindex(',')] + ')' elif dtype not in (_np.float32, _np.bool_): - array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__) + array_str = array_str[:-1] + ', dtype={})'.format(dtype) context = self.context if context.device_type == 'cpu':