diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f04f94fa38b1..eb0ab66fe2b0 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -293,11 +293,12 @@ def broadcast_to(array, shape): def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments """ Return a new array of given shape and type, filled with `fill_value`. + Parameters ---------- shape : int or sequence of ints Shape of the new array, e.g., ``(2, 3)`` or ``2``. - fill_value : scalar + fill_value : scalar or ndarray Fill value. dtype : data-type, optional The desired data-type for the array. The default, `None`, means @@ -310,10 +311,14 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin A location into which the result is stored. If provided, it must have the same shape and dtype as input ndarray. If not provided or `None`, a freshly-allocated array is returned. + Returns ------- out : ndarray Array of `fill_value` with the given shape, dtype, and order. + If `fill_value` is an ndarray, out will have the same context as `fill_value` + regardless of the provided `ctx`. + Notes ----- This function differs from the original `numpy.full @@ -322,11 +327,13 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin - Have an additional `ctx` argument to specify the device - Have an additional `out` argument - Currently does not support `order` selection + See Also -------- empty : Return a new uninitialized array. ones : Return a new array setting values to one. zeros : Return a new array setting values to zero. + Examples -------- >>> np.full((2, 2), 10) @@ -335,11 +342,18 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin >>> np.full((2, 2), 2, dtype=np.int32, ctx=mx.cpu(0)) array([[2, 2], [2, 2]], dtype=int32) + """ if order != 'C': raise NotImplementedError if ctx is None: ctx = current_context() + if isinstance(fill_value, NDArray): + if dtype is None: + ret = broadcast_to(fill_value, shape) + else: + ret = broadcast_to(fill_value, shape).astype(dtype) + return ret dtype = _np.float32 if dtype is None else dtype return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) # pylint: enable=too-many-arguments, redefined-outer-name diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 7fd8c1b4c86f..caef155bb99b 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -2321,7 +2321,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): ---------- shape : int or sequence of ints Shape of the new array, e.g., ``(2, 3)`` or ``2``. - fill_value : scalar + fill_value : scalar or ndarray Fill value. dtype : data-type, optional The desired data-type for the array. The default, `None`, means @@ -2339,6 +2339,8 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): ------- out : ndarray Array of `fill_value` with the given shape, dtype, and order. + If `fill_value` is an ndarray, out will have the same context as `fill_value` + regardless of the provided `ctx`. Notes ----- diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 147fba2eaa30..e4eff9b46571 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -36,8 +36,8 @@ except ImportError: from builtins import slice as py_slice -__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'empty_like', 'bitwise_not', 'invert', 'delete', - 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', +__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'bitwise_not', 'invert', + 'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', @@ -1171,7 +1171,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin ---------- shape : int or sequence of ints Shape of the new array, e.g., ``(2, 3)`` or ``2``. - fill_value : scalar + fill_value : scalar or _Symbol Fill value. dtype : data-type, optional The desired data-type for the array. The default, `None`, means @@ -1214,6 +1214,12 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin raise NotImplementedError if ctx is None: ctx = current_context() + if isinstance(fill_value, Symbol): + if dtype is None: + ret = broadcast_to(fill_value, shape) + else: + ret = broadcast_to(fill_value, shape).astype(dtype) + return ret dtype = _np.float32 if dtype is None else dtype return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 274f37124262..579a74fa3ef2 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4572,6 +4572,55 @@ def g(data): assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_full(): + class TestFull(HybridBlock): + def __init__(self, shape, dtype=None): + super(TestFull, self).__init__() + self._shape = shape + self._dtype = dtype + + def hybrid_forward(self, F, a): + return F.np.full(self._shape, a, dtype=self._dtype) + + configs = [ + ((3, 4), 2.0), + ((0, 3), 2.0), + ((3, 4), np.array(2.0)), + ((0, 3), np.array(2.0)), + ((2, 3), np.array([1, 2, 3], dtype=np.float32)), + ((2, 3), np.array([1, 2, 3], dtype=np.int64)), + ((0, 3), np.array([1, 2, 3], dtype=np.float32)), + ((0, 3), np.array([1, 2, 3], dtype=np.int64)), + ] + + rtol, atol = 1e-3, 1e-5 + dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64'] + for shape, fill_value in configs: + for hybridize in [True, False]: + for dtype in dtypes: + if isinstance(fill_value, np.ndarray): + test_full = TestFull(shape, dtype=dtype) + if hybridize: + test_full.hybridize() + mx_out = test_full(fill_value) + expected_np = _np.full(shape, fill_value.asnumpy(), dtype=dtype) + assert mx_out.shape == expected_np.shape + assert mx_out.dtype == expected_np.dtype + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.full(shape, fill_value, dtype=dtype) + if isinstance(fill_value, np.ndarray): + expected_np = _np.full(shape, fill_value.asnumpy(), dtype=dtype) + else: + expected_np = _np.full(shape, fill_value, dtype=dtype) + assert mx_out.shape == expected_np.shape + assert mx_out.dtype == expected_np.dtype + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_full_like():