Skip to content

Commit

Permalink
Fix dtype bug (apache#16467)
Browse files Browse the repository at this point in the history
* fix dtype bug

* override dtype property; add tests

* fix display bug
  • Loading branch information
xidulu authored and aaronmarkham committed Oct 16, 2019
1 parent 7d0753b commit 4cb190b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
22 changes: 21 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def __repr__(self):
if dtype == _np.float32:
array_str = array_str[:array_str.rindex(',')] + ')'
elif dtype not in (_np.float32, _np.bool_):
array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__)
array_str = array_str[:-1] + ', dtype={})'.format(dtype)

context = self.context
if context.device_type == 'cpu':
Expand Down Expand Up @@ -1680,6 +1680,26 @@ def size(self):
"""Number of elements in the array."""
return super(ndarray, self).size

@property
def dtype(self):
"""Data-type of the array's elements.
Returns
-------
numpy.dtype
This NDArray's data type.
Examples
--------
>>> x = np.zeros((2,3))
>>> x.dtype
dtype('float32')
>>> y = np.zeros((2,3), dtype='int32')
>>> y.dtype
dtype('int32')
"""
return _np.dtype(super(ndarray, self).dtype)

def tostype(self, stype):
raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')

Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,29 @@ def test_boolean_indexing_autograd():
test_boolean_indexing_autograd()


@with_seed()
@use_np
def test_np_get_dtype():
dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, _np.bool, _np.bool_,
'int8', 'int32', 'float16', 'float32', 'float64', 'bool', None]
objects = [
[],
(),
[[1, 2], [3, 4]],
_np.random.uniform(size=rand_shape_nd(3)),
_np.random.uniform(size=(3, 0, 4))
]
for dtype in dtypes:
for src in objects:
mx_arr = np.array(src, dtype=dtype)
assert mx_arr.ctx == mx.current_context()
if isinstance(src, mx.nd.NDArray):
np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32)
else:
np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32)
assert type(mx_arr.dtype) == type(np_arr.dtype)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 4cb190b

Please sign in to comment.