diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5ee52f14bb16..4fff4421e5a5 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -706,7 +706,7 @@ def __repr__(self): if dtype == _np.float32: array_str = array_str[:array_str.rindex(',')] + ')' elif dtype not in (_np.float32, _np.bool_): - array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__) + array_str = array_str[:-1] + ', dtype={})'.format(dtype) context = self.context if context.device_type == 'cpu': @@ -1680,6 +1680,26 @@ def size(self): """Number of elements in the array.""" return super(ndarray, self).size + @property + def dtype(self): + """Data-type of the array's elements. + + Returns + ------- + numpy.dtype + This NDArray's data type. + + Examples + -------- + >>> x = np.zeros((2,3)) + >>> x.dtype + dtype('float32') + >>> y = np.zeros((2,3), dtype='int32') + >>> y.dtype + dtype('int32') + """ + return _np.dtype(super(ndarray, self).dtype) + def tostype(self, stype): raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype') diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 53a8076f1303..b6692c009c18 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1075,6 +1075,29 @@ def test_boolean_indexing_autograd(): test_boolean_indexing_autograd() +@with_seed() +@use_np +def test_np_get_dtype(): + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_, + 'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None] + objects = [ + [], + (), + [[1, 2], [3, 4]], + _np.random.uniform(size=rand_shape_nd(3)), + _np.random.uniform(size=(3, 0, 4)) + ] + for dtype in dtypes: + for src in objects: + mx_arr = np.array(src, dtype=dtype) + assert mx_arr.ctx == mx.current_context() + if isinstance(src, mx.nd.NDArray): + np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) + else: + np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) + assert type(mx_arr.dtype) == type(np_arr.dtype) + + if __name__ == '__main__': import nose nose.runmodule()