From 30858871968d222ed713e3f3b3ba61e24eea2d82 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 26 Sep 2018 16:10:28 -0700 Subject: [PATCH 1/3] fix constant init --- python/mxnet/ndarray/ndarray.py | 4 +++- tests/python/unittest/test_ndarray.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index de2ad692adfc..98e92d42c6ef 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -869,7 +869,7 @@ def _sync_copyfrom(self, source_array): source_array = np.ascontiguousarray(source_array, dtype=self.dtype) if source_array.shape != self.shape: raise ValueError('Shape inconsistent: expected %s vs got %s'%( - str(self.shape), str(source_array.shape))) + str(source_array.shape), str(self.shape))) check_call(_LIB.MXNDArraySyncCopyFromCPU( self.handle, source_array.ctypes.data_as(ctypes.c_void_p), @@ -2479,6 +2479,8 @@ def array(source_array, ctx=None, dtype=None): if isinstance(source_array, NDArray): dtype = source_array.dtype if dtype is None else dtype else: + if isinstance(source_array,(float,int)): + source_array=[float(source_array)] dtype = mx_real_t if dtype is None else dtype if not isinstance(source_array, np.ndarray): try: diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index a1c178f8234e..ac7be4735f1c 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1471,6 +1471,24 @@ def test_dlpack(): mx.test_utils.assert_almost_equal(a_np, d_np) mx.test_utils.assert_almost_equal(a_np, e_np) +@with_seed() +def test_ndarray_constant_init(): + # a=mx.nd.array([1]) + a=mx.nd.array(9) + assert(isinstance(a,mx.nd.NDArray)) + +@with_seed() +def test_symbol_constant_init(): + # a=mx.nd.array([1]) + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a * b + d = c.eval(a=mx.nd.array(2),b=mx.nd.array(3)) + assert(isinstance(d[0],mx.nd.NDArray)) + e = c.bind(mx.cpu(),args=[mx.nd.array(2),mx.nd.array(3)]) + f = e.forward() + assert(isinstance(f[0],mx.nd.NDArray)) + if __name__ == '__main__': import nose nose.runmodule() From 6c5e1f68ea0526b40f0a3028b7e833639fe3e43d Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 26 Sep 2018 17:12:49 -0700 Subject: [PATCH 2/3] lint --- python/mxnet/ndarray/ndarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 98e92d42c6ef..2e39dc6f9889 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2479,8 +2479,8 @@ def array(source_array, ctx=None, dtype=None): if isinstance(source_array, NDArray): dtype = source_array.dtype if dtype is None else dtype else: - if isinstance(source_array,(float,int)): - source_array=[float(source_array)] + if isinstance(source_array, (float, int)): + source_array = [float(source_array)] dtype = mx_real_t if dtype is None else dtype if not isinstance(source_array, np.ndarray): try: From b54ec4b0cc0ecf4d5f85d443b122966e7d6d1677 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 26 Sep 2018 19:01:41 -0700 Subject: [PATCH 3/3] ctx error --- tests/python/unittest/test_ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index ac7be4735f1c..2e3c17141d7d 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1485,7 +1485,7 @@ def test_symbol_constant_init(): c = a * b d = c.eval(a=mx.nd.array(2),b=mx.nd.array(3)) assert(isinstance(d[0],mx.nd.NDArray)) - e = c.bind(mx.cpu(),args=[mx.nd.array(2),mx.nd.array(3)]) + e = c.bind(ctx=mx.cpu(),args=[mx.nd.array(2),mx.nd.array(3)]) f = e.forward() assert(isinstance(f[0],mx.nd.NDArray))