diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8b8a319b2125..f4b3ab0db43a 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -151,7 +151,8 @@ def _as_mx_np_array(object, ctx=None): if isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): - return array(object, dtype=object.dtype, ctx=ctx) + np_dtype = _np.dtype(object.dtype).type + return array(object, dtype=np_dtype, ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object elif isinstance(object, (list, tuple)):