From c60ad8df532ace7315e91a76e6a43cae677860ce Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 28 Oct 2019 23:33:42 +0000 Subject: [PATCH 1/7] fix behavior of np.array when given official numpy ndarray --- Makefile | 2 +- python/mxnet/numpy/multiarray.py | 6 ++++-- tests/python/unittest/test_numpy_ndarray.py | 8 ++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 63a978d01d8a..911ded6a8fa2 100644 --- a/Makefile +++ b/Makefile @@ -431,7 +431,7 @@ endif # be JIT-compiled by the updated driver from the included PTX. ifeq ($(USE_CUDA), 1) ifeq ($(CUDA_ARCH),) - KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75 + KNOWN_CUDA_ARCHS := 52 # Run nvcc on a zero-length file to check architecture-level support. # Create args to include SASS in the fat binary for supported levels. CUDA_ARCH := $(foreach arch,$(KNOWN_CUDA_ARCHS), \ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a6d90881da4f..909ff1d23cfe 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -577,7 +577,7 @@ def __setitem__(self, key, value): if not isinstance(key, tuple) or len(key) != 0: raise IndexError('scalar tensor can only accept `()` as index') if isinstance(value, numeric_types): - self.full(value) + self._full(value) elif isinstance(value, ndarray) and value.size == 1: if value.shape != self.shape: value = value.reshape(self.shape) @@ -1993,10 +1993,12 @@ def array(object, dtype=None, ctx=None): """ if ctx is None: ctx = current_context() - if isinstance(object, ndarray): + if isinstance(object, (ndarray, _np.ndarray)): dtype = object.dtype if dtype is None else dtype else: dtype = _np.float32 if dtype is None else dtype + if hasattr(object, "dtype"): + dtype = object.dtype if not isinstance(object, (ndarray, _np.ndarray)): try: object = _np.array(object, dtype=dtype) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 6077f4df13ae..dc479b3e17a3 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -87,6 +87,7 @@ def test_np_array_creation(): [], (), [[1, 2], [3, 4]], + _np.random.randint(-10, 10, size=rand_shape_nd(3)), _np.random.uniform(size=rand_shape_nd(3)), _np.random.uniform(size=(3, 0, 4)) ] @@ -94,10 +95,13 @@ def test_np_array_creation(): for src in objects: mx_arr = np.array(src, dtype=dtype) assert mx_arr.ctx == mx.current_context() + np_dtype = _np.float32 if dtype is None else dtype + if dtype is None and isinstance(src, _np.ndarray): + np_dtype = src.dtype if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src.asnumpy(), dtype=np_dtype) else: - np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src, dtype=np_dtype) assert mx_arr.dtype == np_arr.dtype assert same(mx_arr.asnumpy(), np_arr) From b83d805f91f6dceb4e67056089c3a70685f8577b Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 29 Oct 2019 09:10:30 +0000 Subject: [PATCH 2/7] bool for expand_dims and cast --- src/operator/mxnet_op.h | 2 +- src/operator/numpy/np_init_op.h | 2 +- src/operator/tensor/elemwise_unary_op.h | 2 +- tests/python/unittest/test_numpy_interoperability.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 463c71b5b0eb..91478660a123 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -671,7 +671,7 @@ template MSHADOW_CINLINE void copy(mshadow::Stream *s, const TBlob& to, const TBlob& from) { CHECK_EQ(from.Size(), to.Size()); CHECK_EQ(from.dev_mask(), to.dev_mask()); - MSHADOW_TYPE_SWITCH(to.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to.type_flag_, DType, { if (to.type_flag_ == from.type_flag_) { mshadow::Copy(to.FlatTo1D(s), from.FlatTo1D(s), s); } else { diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 69999ae8710e..df30d611aa02 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -205,7 +205,7 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); const TBlob& out_data = outputs[0]; int n = out_data.shape_[0]; - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { Kernel, xpu>::Launch( s, out_data.Size(), out_data.dptr(), n); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index b7625fccf258..188ccd68a340 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -451,7 +451,7 @@ void CastCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, { Tensor out = outputs[0].FlatTo1D(s); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, { Tensor data = inputs[0].FlatTo1D(s); if (outputs[0].type_flag_ != inputs[0].type_flag_ || req[0] != kWriteInplace) { diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 103f2c117ea6..5d6e8af7fa47 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -93,6 +93,7 @@ def _add_workload_copy(): def _add_workload_expand_dims(): OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)), -1) + OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)) > 0.5, -1) for axis in range(-5, 4): OpArgMngr.add_workload('expand_dims', np.empty((2, 3, 4, 5)), axis) @@ -852,8 +853,8 @@ def _signs(dt): # test_float_remainder_corner_cases # Check remainder magnitude. for ct in _FLOAT_DTYPES: - b = _np.array(1.0) - a = np.array(_np.nextafter(_np.array(0.0), -b), dtype=ct) + b = _np.array(1.0, dtype=ct) + a = np.array(_np.nextafter(_np.array(0.0, dtype=ct), -b), dtype=ct) b = np.array(b, dtype=ct) OpArgMngr.add_workload('remainder', a, b) OpArgMngr.add_workload('remainder', -a, -b) From ce820628ad8a42a4fd8f351255ffee3bde238847 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 29 Oct 2019 09:18:06 +0000 Subject: [PATCH 3/7] recover original Makefile --- Makefile | 2 +- tests/python/unittest/test_numpy_ndarray.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 911ded6a8fa2..63a978d01d8a 100644 --- a/Makefile +++ b/Makefile @@ -431,7 +431,7 @@ endif # be JIT-compiled by the updated driver from the included PTX. ifeq ($(USE_CUDA), 1) ifeq ($(CUDA_ARCH),) - KNOWN_CUDA_ARCHS := 52 + KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75 # Run nvcc on a zero-length file to check architecture-level support. # Create args to include SASS in the fat binary for supported levels. CUDA_ARCH := $(foreach arch,$(KNOWN_CUDA_ARCHS), \ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index dc479b3e17a3..f16e722ff94c 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -18,6 +18,7 @@ # pylint: skip-file from __future__ import absolute_import from __future__ import division +import itertools import os import unittest import numpy as _np @@ -475,9 +476,6 @@ def test_np_grad_ndarray_type(): @with_seed() @use_np def test_np_ndarray_astype(): - mx_data = np.array([2, 3, 4, 5], dtype=_np.int32) - np_data = mx_data.asnumpy() - class TestAstype(HybridBlock): def __init__(self, dtype, copy): super(TestAstype, self).__init__() @@ -487,24 +485,29 @@ def __init__(self, dtype, copy): def hybrid_forward(self, F, x): return x.astype(dtype=self._dtype, copy=self._copy) - def check_astype_equal(dtype, copy, expect_zero_copy=False, hybridize=False): - test_astype = TestAstype(dtype, copy) + def check_astype_equal(itype, otype, copy, expect_zero_copy=False, hybridize=False): + expect_zero_copy = copy is False and itype == otype + mx_data = np.array([2, 3, 4, 5], dtype=itype) + np_data = mx_data.asnumpy() + test_astype = TestAstype(otype, copy) if hybridize: test_astype.hybridize() mx_ret = test_astype(mx_data) assert type(mx_ret) is np.ndarray - np_ret = np_data.astype(dtype=dtype, copy=copy) + np_ret = np_data.astype(dtype=otype, copy=copy) assert mx_ret.dtype == np_ret.dtype assert same(mx_ret.asnumpy(), np_ret) if expect_zero_copy and not hybridize: assert id(mx_ret) == id(mx_data) assert id(np_ret) == id(np_data) - for dtype in [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_, - 'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']: + dtypes = [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_, + 'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool'] + + for itype, otype in itertools.product(dtypes, dtypes): for copy in [True, False]: for hybridize in [True, False]: - check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype, hybridize) + check_astype_equal(itype, otype, copy, hybridize) @with_seed() From 0b8c2af25d16586507043f423613cd7a76dd7291 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 29 Oct 2019 22:56:55 +0000 Subject: [PATCH 4/7] address comments --- python/mxnet/numpy/multiarray.py | 16 ++++++++-------- tests/python/unittest/test_numpy_ndarray.py | 9 ++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 909ff1d23cfe..4af51fc5bf3a 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1995,15 +1995,15 @@ def array(object, dtype=None, ctx=None): ctx = current_context() if isinstance(object, (ndarray, _np.ndarray)): dtype = object.dtype if dtype is None else dtype + elif isinstance(object, NDArray): + raise ValueError("") else: - dtype = _np.float32 if dtype is None else dtype - if hasattr(object, "dtype"): - dtype = object.dtype - if not isinstance(object, (ndarray, _np.ndarray)): - try: - object = _np.array(object, dtype=dtype) - except Exception as e: - raise TypeError('{}'.format(str(e))) + if dtype is None: + dtype = object.dtype if hasattr(object, "dtype") else _np.float32 + try: + object = _np.array(object, dtype=dtype) + except Exception as e: + raise TypeError('{}'.format(str(e))) ret = empty(object.shape, dtype=dtype, ctx=ctx) if len(object.shape) == 0: ret[()] = object diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index f16e722ff94c..239f300e028e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -96,13 +96,12 @@ def test_np_array_creation(): for src in objects: mx_arr = np.array(src, dtype=dtype) assert mx_arr.ctx == mx.current_context() - np_dtype = _np.float32 if dtype is None else dtype - if dtype is None and isinstance(src, _np.ndarray): - np_dtype = src.dtype + if dtype is None: + dtype = src.dtype if isinstance(src, _np.ndarray) else _np.float32 if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=np_dtype) + np_arr = _np.array(src.asnumpy(), dtype=dtype) else: - np_arr = _np.array(src, dtype=np_dtype) + np_arr = _np.array(src, dtype=dtype) assert mx_arr.dtype == np_arr.dtype assert same(mx_arr.asnumpy(), np_arr) From 558c15146a9433738fa6880b7a4602c83848df2d Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 30 Oct 2019 09:33:29 +0000 Subject: [PATCH 5/7] add boolean support for cumsum --- src/operator/numpy/np_cumsum-inl.h | 4 ++-- src/operator/numpy/np_cumsum.cc | 3 +++ tests/python/unittest/test_numpy_op.py | 10 ++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h index 6c6b56d46e76..375d83b2240f 100644 --- a/src/operator/numpy/np_cumsum-inl.h +++ b/src/operator/numpy/np_cumsum-inl.h @@ -98,7 +98,7 @@ void CumsumForwardImpl(const OpContext& ctx, } Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(in.type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, IType, { MSHADOW_TYPE_SWITCH(out.type_flag_, OType, { Kernel::Launch( s, out.Size() / middle, out.dptr(), @@ -157,7 +157,7 @@ void CumsumBackwardImpl(const OpContext& ctx, } } Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(igrad.type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(igrad.type_flag_, IType, { MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, { Kernel::Launch( s, igrad.Size() / middle, igrad.dptr(), diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc index 0ddbf521186c..2d5dbb99f90a 100644 --- a/src/operator/numpy/np_cumsum.cc +++ b/src/operator/numpy/np_cumsum.cc @@ -55,6 +55,9 @@ inline bool CumsumType(const nnvm::NodeAttrs& attrs, } else { TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + if (out_attrs->at(0) == mshadow::kBool) { + (*out_attrs)[0] = mshadow::kInt64; + } } return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 67c1ede6cc1a..1a8ec9f60928 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2390,6 +2390,16 @@ def hybrid_forward(self, F, a): np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + for shape in shapes: + for axis in [None] + [i for i in range(0, len(shape))]: + for otype in [None, _np.int32, _np.int64]: + for itype in [_np.bool, _np.int8, _np.int32, _np.int64]: + x = rand_ndarray(shape).astype(itype).as_np_ndarray() + np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype) + mx_out = np.cumsum(x, axis=axis, dtype=otype) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + @with_seed() @use_np From 82dca7bfea667274fb00231d954af494d46991c1 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 30 Oct 2019 09:33:48 +0000 Subject: [PATCH 6/7] add gpu cast boolean support --- src/ndarray/ndarray_function.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 2a1461cc8c48..da7b60db7f13 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -76,7 +76,7 @@ void Copy(const TBlob &from, TBlob *to, from.FlatTo1D(s), s); } else { - MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, { to->FlatTo1D(s) = mshadow::expr::tcast(from.FlatTo1D(s)); }) From e35a2e355cdf522de9dee26957762f34d2ef6ce9 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 31 Oct 2019 00:08:27 +0000 Subject: [PATCH 7/7] add error message --- python/mxnet/numpy/multiarray.py | 5 ++++- tests/python/unittest/test_numpy_op.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 4af51fc5bf3a..bc4b409d5be7 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1996,13 +1996,16 @@ def array(object, dtype=None, ctx=None): if isinstance(object, (ndarray, _np.ndarray)): dtype = object.dtype if dtype is None else dtype elif isinstance(object, NDArray): - raise ValueError("") + raise ValueError("If you're trying to create a mxnet.numpy.ndarray " + "from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.") else: if dtype is None: dtype = object.dtype if hasattr(object, "dtype") else _np.float32 try: object = _np.array(object, dtype=dtype) except Exception as e: + # printing out the error raised by official NumPy's array function + # for transparency on users' side raise TypeError('{}'.format(str(e))) ret = empty(object.shape, dtype=dtype, ctx=ctx) if len(object.shape) == 0: diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1a8ec9f60928..0b15c7ea0d2d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -442,7 +442,7 @@ def is_int(dtype): for axis in ([i for i in range(in_data_dim)] + [(), None]): for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']: for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: - if (is_int(dtype) and not is_int(itype))\ + if (is_int(dtype) and not is_int(itype)) or (is_windows and is_int(itype))\ or (itype == 'bool' and\ (dtype not in ('float32', 'float64', 'int32', 'int64') or is_windows)): continue