diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index de2ad692adfc..2e39dc6f9889 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -869,7 +869,7 @@ def _sync_copyfrom(self, source_array): source_array = np.ascontiguousarray(source_array, dtype=self.dtype) if source_array.shape != self.shape: raise ValueError('Shape inconsistent: expected %s vs got %s'%( - str(self.shape), str(source_array.shape))) + str(source_array.shape), str(self.shape))) check_call(_LIB.MXNDArraySyncCopyFromCPU( self.handle, source_array.ctypes.data_as(ctypes.c_void_p), @@ -2479,6 +2479,8 @@ def array(source_array, ctx=None, dtype=None): if isinstance(source_array, NDArray): dtype = source_array.dtype if dtype is None else dtype else: + if isinstance(source_array, (float, int)): + source_array = [float(source_array)] dtype = mx_real_t if dtype is None else dtype if not isinstance(source_array, np.ndarray): try: diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index a1c178f8234e..2e3c17141d7d 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1471,6 +1471,24 @@ def test_dlpack(): mx.test_utils.assert_almost_equal(a_np, d_np) mx.test_utils.assert_almost_equal(a_np, e_np) +@with_seed() +def test_ndarray_constant_init(): + # a=mx.nd.array([1]) + a=mx.nd.array(9) + assert(isinstance(a,mx.nd.NDArray)) + +@with_seed() +def test_symbol_constant_init(): + # a=mx.nd.array([1]) + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a * b + d = c.eval(a=mx.nd.array(2),b=mx.nd.array(3)) + assert(isinstance(d[0],mx.nd.NDArray)) + e = c.bind(ctx=mx.cpu(),args=[mx.nd.array(2),mx.nd.array(3)]) + f = e.forward() + assert(isinstance(f[0],mx.nd.NDArray)) + if __name__ == '__main__': import nose nose.runmodule()