diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index cd4b3b8fae5c..87bfa7a06355 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -182,7 +182,7 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name def _as_mx_np_array(object, ctx=None): """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, ndarray): + if object is None or isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): np_dtype = _np.dtype(object.dtype).type