Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
custom op full_like
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 committed Nov 13, 2019
1 parent 177fd3e commit 9223167
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 3 deletions.
55 changes: 54 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num']
'nan_to_num', 'full_like']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5211,6 +5211,59 @@ def resize(a, new_shape):
return _npi.resize_fallback(a, new_shape=new_shape)


@set_module('mxnet.ndarray.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.
Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.
Temporarily do not support boolean type.
Returns
-------
out : ndarray
Array of `fill_value` with the same shape and type as `a`.
See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.
Examples
--------
>>> x = np.arange(6, dtype=int)
>>> np.full_like(x, 1)
array([1, 1, 1, 1, 1, 1], dtype=int64)
>>> np.full_like(x, 0.1)
array([0, 0, 0, 0, 0, 0], dtype=int64)
>>> np.full_like(x, 0.1, dtype=np.float64)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64)
>>> np.full_like(x, np.nan, dtype=np.double)
array([nan, nan, nan, nan, nan, nan], dtype=float64)
>>> y = np.arange(6, dtype=np.float32)
>>> np.full_like(y, 0.1)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.ndarray.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
47 changes: 46 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'full_like']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -7204,6 +7204,51 @@ def resize(a, new_shape):
return _mx_nd_np.resize(a, new_shape)


@set_module('mxnet.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.
Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.
Returns
-------
out : ndarray
Array of `fill_value` with the same shape and type as `a`.
See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.
Examples
--------
>>> x = np.arange(6, dtype=int)
>>> np.full_like(x, 1)
array([1, 1, 1, 1, 1, 1], dtype=int64)
>>> np.full_like(x, 0.1)
array([0, 0, 0, 0, 0, 0], dtype=int64)
>>> np.full_like(x, 0.1, dtype=np.float64)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64)
>>> np.full_like(x, np.nan, dtype=np.double)
array([nan, nan, nan, nan, nan, nan], dtype=float64)
>>> y = np.arange(6, dtype=np.float32)
>>> np.full_like(y, 0.1)
array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
"""
return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'may_share_memory',
'diff',
'resize',
'full_like',
]


Expand Down
55 changes: 55 additions & 0 deletions python/mxnet/numpy_op_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,61 @@ def _register_helper(prop_cls):
return _register_helper


@use_np # enforce np shape and array semantics for all the methods in this class
class FullLike(operator.CustomOp):
"""Fallback to NumPy full_like operator."""
def __init__(self, fill_value, dtype):
super(FullLike, self).__init__()
self._fill_value = fill_value
self._dtype = dtype

def forward(self, is_train, req, in_data, out_data, aux):
out = np.full_like(in_data[0].asnumpy(), self._fill_value, dtype=self._dtype)
self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out.dtype, ctx=out_data[0].ctx))

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
raise NotImplementedError('Operator full_like does not support gradient computation')


@register('full_like_fallback')
class FullLikeProp(operator.CustomOpProp):
"""Fallback full_like operator properties."""
def __init__(self, fill_value, dtype):
super(FullLikeProp, self).__init__(need_top_grad=True)
self._fill_value = self.get_fill_value(fill_value)
self._dtype = None if dtype == 'None' else dtype

@staticmethod
def get_fill_value(fill_value):
"""Convert fill_value to corresponding data type"""
if fill_value == 'nan':
return np.nan
elif fill_value == 'inf':
return np.inf
elif fill_value == '-inf':
return -np.inf
try:
return ast.literal_eval(fill_value)
except:
raise NotImplementedError("Do not support fill_value %s at this moment"%(fill_value))

def list_arguments(self):
return ['a']

def infer_type(self, in_type):
if self._dtype is None:
return (in_type[0],), (in_type[0],), ()
else:
out_dtype = eval('np.'+self._dtype) # pylint: disable=W0123
return (in_type[0],), (out_dtype,), ()

def infer_shape(self, in_shape):
return (in_shape[0],), (in_shape[0],), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return FullLike(self._fill_value, self._dtype)


@use_np # enforce np shape and array semantics for all the methods in this class
class Resize(operator.CustomOp):
"""Fallback to NumPy resize operator."""
Expand Down
39 changes: 38 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num']
'resize', 'nan_to_num', 'full_like']


def _num_outputs(sym):
Expand Down Expand Up @@ -4824,6 +4824,43 @@ def resize(a, new_shape):
return _npi.resize_fallback(a, new_shape=new_shape)


@set_module('mxnet.symbol.numpy')
def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments
"""
Return a full array with the same shape and type as a given array.
Parameters
----------
a : _Symbol
The shape and data-type of `a` define these same attributes of
the returned array.
fill_value : scalar
Fill value.
dtype : data-type, optional
Overrides the data type of the result.
Returns
-------
out : _Symbol
Array of `fill_value` with the same shape and type as `a`.
See Also
--------
empty_like : Return an empty array with shape and type of input.
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full : Return a new array of given shape filled with value.
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype)


@set_module('mxnet.symbol.numpy')
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,13 @@ def _add_workload_resize():
OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 100))


def _add_workload_full_like():
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64'), 1)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3,1)), 2, np.int64)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3)), np.nan)
OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(0,3)), 0, np.float32)


@use_np
def _prepare_workloads():
array_pool = {
Expand Down Expand Up @@ -1284,6 +1291,7 @@ def _prepare_workloads():
_add_workload_less_equal(array_pool)
_add_workload_diff()
_add_workload_resize()
_add_workload_full_like()


_prepare_workloads()
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# pylint: skip-file
from __future__ import absolute_import
from distutils.version import StrictVersion
import sys
import unittest
import itertools
Expand Down Expand Up @@ -4463,6 +4464,52 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_full_like():
class TestFullLike(HybridBlock):
def __init__(self, fill_value, dtype):
super(TestFullLike, self).__init__()
self._fill_value = fill_value
self._dtype = dtype

def hybrid_forward(self, F, x, *args, **kwargs):
return F.np.full_like(x, self._fill_value, self._dtype)

if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
return

dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
shapes = [
(),
(1,),
(4, 3),
(4, 5),
(0, 1),
(6, 5, 6),
(4, 2, 1, 2),
(5, 0, 3, 3),
(3, 3, 0, 0),
]
fill_values = [0, 1, 2, 3, 4, 5, 6]
flags = [True, False]
_np_version = _np.version.version
for fill_value, dtype, shape, hybridize in itertools.product(
fill_values, dtypes, shapes, flags):
param_dtype= _np.random.choice(dtypes)
a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype)
test = TestFullLike(fill_value, param_dtype)
expected_ret = _np.full_like(a, fill_value=fill_value, dtype=param_dtype)
if hybridize:
test.hybridize()
ret = test(a)
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

# check imperative again
ret = np.full_like(a, fill_value, param_dtype)
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 9223167

Please sign in to comment.