From 9223167eff3822cd3ea3ffbb30de2d255ba447e5 Mon Sep 17 00:00:00 2001 From: Alicia1529 Date: Mon, 11 Nov 2019 18:48:05 +0800 Subject: [PATCH] custom op full_like --- python/mxnet/ndarray/numpy/_op.py | 55 ++++++++++++++++++- python/mxnet/numpy/multiarray.py | 47 +++++++++++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/numpy_op_fallback.py | 55 +++++++++++++++++++ python/mxnet/symbol/numpy/_symbol.py | 39 ++++++++++++- .../unittest/test_numpy_interoperability.py | 8 +++ tests/python/unittest/test_numpy_op.py | 47 ++++++++++++++++ 7 files changed, 249 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f9164855dfe9..47779b8aba8b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -39,7 +39,7 @@ 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', - 'nan_to_num'] + 'nan_to_num', 'full_like'] @set_module('mxnet.ndarray.numpy') @@ -5211,6 +5211,59 @@ def resize(a, new_shape): return _npi.resize_fallback(a, new_shape=new_shape) +@set_module('mxnet.ndarray.numpy') +def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments + """ + Return a full array with the same shape and type as a given array. + + Parameters + ---------- + a : ndarray + The shape and data-type of `a` define these same attributes of + the returned array. + fill_value : scalar + Fill value. + dtype : data-type, optional + Overrides the data type of the result. + Temporarily do not support boolean type. + + Returns + ------- + out : ndarray + Array of `fill_value` with the same shape and type as `a`. + + See Also + -------- + empty_like : Return an empty array with shape and type of input. + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full : Return a new array of given shape filled with value. + + Examples + -------- + >>> x = np.arange(6, dtype=int) + >>> np.full_like(x, 1) + array([1, 1, 1, 1, 1, 1], dtype=int64) + >>> np.full_like(x, 0.1) + array([0, 0, 0, 0, 0, 0], dtype=int64) + >>> np.full_like(x, 0.1, dtype=np.float64) + array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64) + >>> np.full_like(x, np.nan, dtype=np.double) + array([nan, nan, nan, nan, nan, nan], dtype=float64) + >>> y = np.arange(6, dtype=np.float32) + >>> np.full_like(y, 0.1) + array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + """ + dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32', + _np.int64:'int64', _np.float16:'float16', _np.float32:'float32', + _np.float64:'float64', _np.bool_:'bool'} + try: + dtype = dtype if isinstance(dtype, str) else dtype_list[dtype] + except: + raise NotImplementedError("Do not support this dtype at this moment") + return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype) + + @set_module('mxnet.ndarray.numpy') def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b6816d75a98e..209f9368dcda 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -57,7 +57,7 @@ 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'full_like'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -7204,6 +7204,51 @@ def resize(a, new_shape): return _mx_nd_np.resize(a, new_shape) +@set_module('mxnet.numpy') +def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments + """ + Return a full array with the same shape and type as a given array. + + Parameters + ---------- + a : ndarray + The shape and data-type of `a` define these same attributes of + the returned array. + fill_value : scalar + Fill value. + dtype : data-type, optional + Overrides the data type of the result. + + Returns + ------- + out : ndarray + Array of `fill_value` with the same shape and type as `a`. + + See Also + -------- + empty_like : Return an empty array with shape and type of input. + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full : Return a new array of given shape filled with value. + + Examples + -------- + >>> x = np.arange(6, dtype=int) + >>> np.full_like(x, 1) + array([1, 1, 1, 1, 1, 1], dtype=int64) + >>> np.full_like(x, 0.1) + array([0, 0, 0, 0, 0, 0], dtype=int64) + >>> np.full_like(x, 0.1, dtype=np.float64) + array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64) + >>> np.full_like(x, np.nan, dtype=np.double) + array([nan, nan, nan, nan, nan, nan], dtype=float64) + >>> y = np.arange(6, dtype=np.float32) + >>> np.full_like(y, 0.1) + array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + """ + return _mx_nd_np.full_like(a, fill_value=fill_value, dtype=dtype) + + @set_module('mxnet.numpy') def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): """ diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 025982cfc7a5..a46b918acd80 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -136,6 +136,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'may_share_memory', 'diff', 'resize', + 'full_like', ] diff --git a/python/mxnet/numpy_op_fallback.py b/python/mxnet/numpy_op_fallback.py index b98a211c7169..aaf29c1edc70 100644 --- a/python/mxnet/numpy_op_fallback.py +++ b/python/mxnet/numpy_op_fallback.py @@ -49,6 +49,61 @@ def _register_helper(prop_cls): return _register_helper +@use_np # enforce np shape and array semantics for all the methods in this class +class FullLike(operator.CustomOp): + """Fallback to NumPy full_like operator.""" + def __init__(self, fill_value, dtype): + super(FullLike, self).__init__() + self._fill_value = fill_value + self._dtype = dtype + + def forward(self, is_train, req, in_data, out_data, aux): + out = np.full_like(in_data[0].asnumpy(), self._fill_value, dtype=self._dtype) + self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out.dtype, ctx=out_data[0].ctx)) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + raise NotImplementedError('Operator full_like does not support gradient computation') + + +@register('full_like_fallback') +class FullLikeProp(operator.CustomOpProp): + """Fallback full_like operator properties.""" + def __init__(self, fill_value, dtype): + super(FullLikeProp, self).__init__(need_top_grad=True) + self._fill_value = self.get_fill_value(fill_value) + self._dtype = None if dtype == 'None' else dtype + + @staticmethod + def get_fill_value(fill_value): + """Convert fill_value to corresponding data type""" + if fill_value == 'nan': + return np.nan + elif fill_value == 'inf': + return np.inf + elif fill_value == '-inf': + return -np.inf + try: + return ast.literal_eval(fill_value) + except: + raise NotImplementedError("Do not support fill_value %s at this moment"%(fill_value)) + + def list_arguments(self): + return ['a'] + + def infer_type(self, in_type): + if self._dtype is None: + return (in_type[0],), (in_type[0],), () + else: + out_dtype = eval('np.'+self._dtype) # pylint: disable=W0123 + return (in_type[0],), (out_dtype,), () + + def infer_shape(self, in_shape): + return (in_shape[0],), (in_shape[0],), () + + def create_operator(self, ctx, in_shapes, in_dtypes): + return FullLike(self._fill_value, self._dtype) + + @use_np # enforce np shape and array semantics for all the methods in this class class Resize(operator.CustomOp): """Fallback to NumPy resize operator.""" diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0d7303865b92..b24f9e136d79 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -41,7 +41,7 @@ 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', - 'resize', 'nan_to_num'] + 'resize', 'nan_to_num', 'full_like'] def _num_outputs(sym): @@ -4824,6 +4824,43 @@ def resize(a, new_shape): return _npi.resize_fallback(a, new_shape=new_shape) +@set_module('mxnet.symbol.numpy') +def full_like(a, fill_value=0, dtype=None): # pylint: disable=too-many-arguments + """ + Return a full array with the same shape and type as a given array. + + Parameters + ---------- + a : _Symbol + The shape and data-type of `a` define these same attributes of + the returned array. + fill_value : scalar + Fill value. + dtype : data-type, optional + Overrides the data type of the result. + + Returns + ------- + out : _Symbol + Array of `fill_value` with the same shape and type as `a`. + + See Also + -------- + empty_like : Return an empty array with shape and type of input. + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full : Return a new array of given shape filled with value. + """ + dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32', + _np.int64:'int64', _np.float16:'float16', _np.float32:'float32', + _np.float64:'float64', _np.bool_:'bool'} + try: + dtype = dtype if isinstance(dtype, str) else dtype_list[dtype] + except: + raise NotImplementedError("Do not support this dtype at this moment") + return _npi.full_like_fallback(a, fill_value=fill_value, dtype=dtype) + + @set_module('mxnet.symbol.numpy') def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): """ diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6d9c63f9f857..41c185d995e8 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1173,6 +1173,13 @@ def _add_workload_resize(): OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 100)) +def _add_workload_full_like(): + OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64'), 1) + OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3,1)), 2, np.int64) + OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(9,3)), np.nan) + OpArgMngr.add_workload('full_like', np.random.uniform(low=0, high=100, size=(0,3)), 0, np.float32) + + @use_np def _prepare_workloads(): array_pool = { @@ -1284,6 +1291,7 @@ def _prepare_workloads(): _add_workload_less_equal(array_pool) _add_workload_diff() _add_workload_resize() + _add_workload_full_like() _prepare_workloads() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index fb022e06158c..a14564fc35f3 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -17,6 +17,7 @@ # pylint: skip-file from __future__ import absolute_import +from distutils.version import StrictVersion import sys import unittest import itertools @@ -4463,6 +4464,52 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) +@with_seed() +@use_np +def test_np_full_like(): + class TestFullLike(HybridBlock): + def __init__(self, fill_value, dtype): + super(TestFullLike, self).__init__() + self._fill_value = fill_value + self._dtype = dtype + + def hybrid_forward(self, F, x, *args, **kwargs): + return F.np.full_like(x, self._fill_value, self._dtype) + + if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): + return + + dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8'] + shapes = [ + (), + (1,), + (4, 3), + (4, 5), + (0, 1), + (6, 5, 6), + (4, 2, 1, 2), + (5, 0, 3, 3), + (3, 3, 0, 0), + ] + fill_values = [0, 1, 2, 3, 4, 5, 6] + flags = [True, False] + _np_version = _np.version.version + for fill_value, dtype, shape, hybridize in itertools.product( + fill_values, dtypes, shapes, flags): + param_dtype= _np.random.choice(dtypes) + a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) + test = TestFullLike(fill_value, param_dtype) + expected_ret = _np.full_like(a, fill_value=fill_value, dtype=param_dtype) + if hybridize: + test.hybridize() + ret = test(a) + assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + + # check imperative again + ret = np.full_like(a, fill_value, param_dtype) + assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule()