Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
handle array_like fill_value for np.full; add unit test coverage (#17245
Browse files Browse the repository at this point in the history
)
  • Loading branch information
haojin2 authored and reminisce committed Jan 14, 2020
1 parent 1a61a86 commit 2938684
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 5 deletions.
16 changes: 15 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,12 @@ def broadcast_to(array, shape):
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
"""
Return a new array of given shape and type, filled with `fill_value`.
Parameters
----------
shape : int or sequence of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
fill_value : scalar
fill_value : scalar or ndarray
Fill value.
dtype : data-type, optional
The desired data-type for the array. The default, `None`, means
Expand All @@ -311,10 +312,14 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Array of `fill_value` with the given shape, dtype, and order.
If `fill_value` is an ndarray, out will have the same context as `fill_value`
regardless of the provided `ctx`.
Notes
-----
This function differs from the original `numpy.full
Expand All @@ -323,11 +328,13 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
- Have an additional `ctx` argument to specify the device
- Have an additional `out` argument
- Currently does not support `order` selection
See Also
--------
empty : Return a new uninitialized array.
ones : Return a new array setting values to one.
zeros : Return a new array setting values to zero.
Examples
--------
>>> np.full((2, 2), 10)
Expand All @@ -336,11 +343,18 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
>>> np.full((2, 2), 2, dtype=np.int32, ctx=mx.cpu(0))
array([[2, 2],
[2, 2]], dtype=int32)
"""
if order != 'C':
raise NotImplementedError
if ctx is None:
ctx = current_context()
if isinstance(fill_value, NDArray):
if dtype is None:
ret = broadcast_to(fill_value, shape)
else:
ret = broadcast_to(fill_value, shape).astype(dtype)
return ret
dtype = _np.float32 if dtype is None else dtype
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)
# pylint: enable=too-many-arguments, redefined-outer-name
Expand Down
4 changes: 3 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
----------
shape : int or sequence of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
fill_value : scalar
fill_value : scalar or ndarray
Fill value.
dtype : data-type, optional
The desired data-type for the array. The default, `None`, means
Expand All @@ -2339,6 +2339,8 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
-------
out : ndarray
Array of `fill_value` with the given shape, dtype, and order.
If `fill_value` is an ndarray, out will have the same context as `fill_value`
regardless of the provided `ctx`.
Notes
-----
Expand Down
12 changes: 9 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
except ImportError:
from builtins import slice as py_slice

__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'empty_like', 'bitwise_not', 'invert', 'delete',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'bitwise_not', 'invert',
'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
----------
shape : int or sequence of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
fill_value : scalar
fill_value : scalar or _Symbol
Fill value.
dtype : data-type, optional
The desired data-type for the array. The default, `None`, means
Expand Down Expand Up @@ -1215,6 +1215,12 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
raise NotImplementedError
if ctx is None:
ctx = current_context()
if isinstance(fill_value, Symbol):
if dtype is None:
ret = broadcast_to(fill_value, shape)
else:
ret = broadcast_to(fill_value, shape).astype(dtype)
return ret
dtype = _np.float32 if dtype is None else dtype
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)

Expand Down
49 changes: 49 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4636,6 +4636,55 @@ def g(data):
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_full():
class TestFull(HybridBlock):
def __init__(self, shape, dtype=None):
super(TestFull, self).__init__()
self._shape = shape
self._dtype = dtype

def hybrid_forward(self, F, a):
return F.np.full(self._shape, a, dtype=self._dtype)

configs = [
((3, 4), 2.0),
((0, 3), 2.0),
((3, 4), np.array(2.0)),
((0, 3), np.array(2.0)),
((2, 3), np.array([1, 2, 3], dtype=np.float32)),
((2, 3), np.array([1, 2, 3], dtype=np.int64)),
((0, 3), np.array([1, 2, 3], dtype=np.float32)),
((0, 3), np.array([1, 2, 3], dtype=np.int64)),
]

rtol, atol = 1e-3, 1e-5
dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
for shape, fill_value in configs:
for hybridize in [True, False]:
for dtype in dtypes:
if isinstance(fill_value, np.ndarray):
test_full = TestFull(shape, dtype=dtype)
if hybridize:
test_full.hybridize()
mx_out = test_full(fill_value)
expected_np = _np.full(shape, fill_value.asnumpy(), dtype=dtype)
assert mx_out.shape == expected_np.shape
assert mx_out.dtype == expected_np.dtype
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test imperative once again
mx_out = np.full(shape, fill_value, dtype=dtype)
if isinstance(fill_value, np.ndarray):
expected_np = _np.full(shape, fill_value.asnumpy(), dtype=dtype)
else:
expected_np = _np.full(shape, fill_value, dtype=dtype)
assert mx_out.shape == expected_np.shape
assert mx_out.dtype == expected_np.dtype
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_full_like():
Expand Down

0 comments on commit 2938684

Please sign in to comment.