From dd2c917b67f08d14028738a8a96a60786cf846ec Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 29 Oct 2019 22:56:55 +0000 Subject: [PATCH] address comments --- python/mxnet/numpy/multiarray.py | 16 ++++++++-------- tests/python/unittest/test_numpy_ndarray.py | 9 ++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b0701f7d5457..97ca52e62993 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1995,15 +1995,15 @@ def array(object, dtype=None, ctx=None): ctx = current_context() if isinstance(object, (ndarray, _np.ndarray)): dtype = object.dtype if dtype is None else dtype + elif isinstance(object, NDArray): + raise ValueError("") else: - dtype = _np.float32 if dtype is None else dtype - if hasattr(object, "dtype"): - dtype = object.dtype - if not isinstance(object, (ndarray, _np.ndarray)): - try: - object = _np.array(object, dtype=dtype) - except Exception as e: - raise TypeError('{}'.format(str(e))) + if dtype is None: + dtype = object.dtype if hasattr(object, "dtype") else _np.float32 + try: + object = _np.array(object, dtype=dtype) + except Exception as e: + raise TypeError('{}'.format(str(e))) ret = empty(object.shape, dtype=dtype, ctx=ctx) if len(object.shape) == 0: ret[()] = object diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index f16e722ff94c..239f300e028e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -96,13 +96,12 @@ def test_np_array_creation(): for src in objects: mx_arr = np.array(src, dtype=dtype) assert mx_arr.ctx == mx.current_context() - np_dtype = _np.float32 if dtype is None else dtype - if dtype is None and isinstance(src, _np.ndarray): - np_dtype = src.dtype + if dtype is None: + dtype = src.dtype if isinstance(src, _np.ndarray) else _np.float32 if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=np_dtype) + np_arr = _np.array(src.asnumpy(), dtype=dtype) else: - np_arr = _np.array(src, dtype=np_dtype) + np_arr = _np.array(src, dtype=dtype) assert mx_arr.dtype == np_arr.dtype assert same(mx_arr.asnumpy(), np_arr)