Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Oct 29, 2019
1 parent 7d82ee3 commit dd2c917
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
16 changes: 8 additions & 8 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1995,15 +1995,15 @@ def array(object, dtype=None, ctx=None):
ctx = current_context()
if isinstance(object, (ndarray, _np.ndarray)):
dtype = object.dtype if dtype is None else dtype
elif isinstance(object, NDArray):
raise ValueError("")
else:
dtype = _np.float32 if dtype is None else dtype
if hasattr(object, "dtype"):
dtype = object.dtype
if not isinstance(object, (ndarray, _np.ndarray)):
try:
object = _np.array(object, dtype=dtype)
except Exception as e:
raise TypeError('{}'.format(str(e)))
if dtype is None:
dtype = object.dtype if hasattr(object, "dtype") else _np.float32
try:
object = _np.array(object, dtype=dtype)
except Exception as e:
raise TypeError('{}'.format(str(e)))
ret = empty(object.shape, dtype=dtype, ctx=ctx)
if len(object.shape) == 0:
ret[()] = object
Expand Down
9 changes: 4 additions & 5 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,12 @@ def test_np_array_creation():
for src in objects:
mx_arr = np.array(src, dtype=dtype)
assert mx_arr.ctx == mx.current_context()
np_dtype = _np.float32 if dtype is None else dtype
if dtype is None and isinstance(src, _np.ndarray):
np_dtype = src.dtype
if dtype is None:
dtype = src.dtype if isinstance(src, _np.ndarray) else _np.float32
if isinstance(src, mx.nd.NDArray):
np_arr = _np.array(src.asnumpy(), dtype=np_dtype)
np_arr = _np.array(src.asnumpy(), dtype=dtype)
else:
np_arr = _np.array(src, dtype=np_dtype)
np_arr = _np.array(src, dtype=dtype)
assert mx_arr.dtype == np_arr.dtype
assert same(mx_arr.asnumpy(), np_arr)

Expand Down

0 comments on commit dd2c917

Please sign in to comment.