Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

custom op full_like #16778

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num']
'nan_to_num', 'full_like']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5211,6 +5211,59 @@ def resize(a, new_shape):
return _npi.resize_fallback(a, new_shape=new_shape)


@set_module('mxnet.ndarray.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.

Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.
Temporarily do not support boolean type.

Returns
-------
out : ndarray
Array of `fill_value` with the same shape and type as `a`.

See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.

Examples
--------
>>> x = np.arange(6, dtype=int)
>>> np.full_like(x, 1)
array([1, 1, 1, 1, 1, 1], dtype=int64)
>>> np.full_like(x, 0.1)
array([0, 0, 0, 0, 0, 0], dtype=int64)
>>> np.full_like(x, 0.1, dtype=np.float64)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64)
>>> np.full_like(x, np.nan, dtype=np.double)
array([nan, nan, nan, nan, nan, nan], dtype=float64)
>>> y = np.arange(6, dtype=np.float32)
>>> np.full_like(y, 0.1)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.ndarray.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
47 changes: 46 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'full_like']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -7204,6 +7204,51 @@ def resize(a, new_shape):
return _mx_nd_np.resize(a, new_shape)


@set_module('mxnet.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.

Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.

Returns
-------
out : ndarray
Array of `fill_value` with the same shape and type as `a`.

See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.

Examples
--------
>>> x = np.arange(6, dtype=int)
>>> np.full_like(x, 1)
array([1, 1, 1, 1, 1, 1], dtype=int64)
>>> np.full_like(x, 0.1)
array([0, 0, 0, 0, 0, 0], dtype=int64)
>>> np.full_like(x, 0.1, dtype=np.float64)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64)
>>> np.full_like(x, np.nan, dtype=np.double)
array([nan, nan, nan, nan, nan, nan], dtype=float64)
>>> y = np.arange(6, dtype=np.float32)
>>> np.full_like(y, 0.1)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
"""
return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'may_share_memory',
'diff',
'resize',
'full_like',
]


Expand Down
55 changes: 55 additions & 0 deletions python/mxnet/numpy_op_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,61 @@ def _register_helper(prop_cls):
return _register_helper


@use_np # enforce np shape and array semantics for all the methods in this class
class FullLike(operator.CustomOp):
"""Fallback to NumPy full_like operator."""
def __init__(self, fill_value, dtype):
super(FullLike, self).__init__()
self._fill_value = fill_value
self._dtype = dtype

def forward(self, is_train, req, in_data, out_data, aux):
out = np.full_like(in_data[0].asnumpy(), self._fill_value, dtype=self._dtype)
self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out.dtype, ctx=out_data[0].ctx))

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
raise NotImplementedError('Operator full_like does not support gradient computation')


@register('full_like_fallback')
class FullLikeProp(operator.CustomOpProp):
"""Fallback full_like operator properties."""
def __init__(self, fill_value, dtype):
super(FullLikeProp, self).__init__(need_top_grad=True)
self._fill_value = self.get_fill_value(fill_value)
self._dtype = None if dtype == 'None' else dtype

@staticmethod
def get_fill_value(fill_value):
"""Convert fill_value to corresponding data type"""
if fill_value == 'nan':
return np.nan
elif fill_value == 'inf':
return np.inf
elif fill_value == '-inf':
return -np.inf
try:
return ast.literal_eval(fill_value)
except:
raise NotImplementedError("Do not support fill_value %s at this moment"%(fill_value))

def list_arguments(self):
return ['a']

def infer_type(self, in_type):
if self._dtype is None:
return (in_type[0],), (in_type[0],), ()
else:
out_dtype = eval('np.'+self._dtype) # pylint: disable=W0123
return (in_type[0],), (out_dtype,), ()

def infer_shape(self, in_shape):
return (in_shape[0],), (in_shape[0],), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return FullLike(self._fill_value, self._dtype)


@use_np # enforce np shape and array semantics for all the methods in this class
class Resize(operator.CustomOp):
"""Fallback to NumPy resize operator."""
Expand Down
39 changes: 38 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num']
'resize', 'nan_to_num', 'full_like']


def _num_outputs(sym):
Expand Down Expand Up @@ -4824,6 +4824,43 @@ def resize(a, new_shape):
return _npi.resize_fallback(a, new_shape=new_shape)


@set_module('mxnet.symbol.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.

Parameters
----------
a : _Symbol
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.

Returns
-------
out : _Symbol
Array of `fill_value` with the same shape and type as `a`.

See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.symbol.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,13 @@ def _add_workload_resize():
OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 100))


def _add_workload_full_like():
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64'), 1)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3,1)), 2, np.int64)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3)), np.nan)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(0,3)), 0, np.float32)


@use_np
def _prepare_workloads():
array_pool = {
Expand Down Expand Up @@ -1284,6 +1291,7 @@ def _prepare_workloads():
_add_workload_less_equal(array_pool)
_add_workload_diff()
_add_workload_resize()
_add_workload_full_like()


_prepare_workloads()
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# pylint: skip-file
from __future__ import absolute_import
from distutils.version import StrictVersion
import sys
import unittest
import itertools
Expand Down Expand Up @@ -4463,6 +4464,52 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_full_like():
class TestFullLike(HybridBlock):
def __init__(self, fill_value, dtype):
super(TestFullLike, self).__init__()
self._fill_value = fill_value
self._dtype = dtype

def hybrid_forward(self, F, x, *args, **kwargs):
return F.np.full_like(x, self._fill_value, self._dtype)

if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can decorate the test function with this condition as the following.

@unittest.skipIf(StrictVersion(platform.python_version()) < StrictVersion('3.0.0'), "reason to skip")
def test_np_full_like():
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

return

dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
shapes = [
(),
(1,),
(4, 3),
(4, 5),
(0, 1),
(6, 5, 6),
(4, 2, 1, 2),
(5, 0, 3, 3),
(3, 3, 0, 0),
]
fill_values = [0, 1, 2, 3, 4, 5, 6]
flags = [True, False]
_np_version = _np.version.version
for fill_value, dtype, shape, hybridize in itertools.product(
fill_values, dtypes, shapes, flags):
param_dtype= _np.random.choice(dtypes)
a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype)
test = TestFullLike(fill_value, param_dtype)
expected_ret = _np.full_like(a, fill_value=fill_value, dtype=param_dtype)
if hybridize:
test.hybridize()
ret = test(a)
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

# check imperative again
ret = np.full_like(a, fill_value, param_dtype)
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()