diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index ea3459da84de..2b3b53b2af52 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril'] + 'unique', 'lcm', 'tril', 'identity'] @set_module('mxnet.ndarray.numpy') @@ -50,7 +50,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): The shape of the empty array. dtype : str or numpy.dtype, optional An optional value type. Default is `numpy.float32`. Note that this - behavior is different from NumPy's `ones` function where `float64` + behavior is different from NumPy's `zeros` function where `float64` is the default value, because `float32` is considered as the default data type in deep learning. order : {'C'}, optional, default: 'C' @@ -96,7 +96,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ if order != 'C': raise NotImplementedError @@ -213,6 +213,47 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None): return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) +@set_module('mxnet.ndarray.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + + Examples + -------- + >>> np.identity(3) + >>> np.identity(3) + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + """ + if not isinstance(n, int): + raise TypeError("Input 'n' should be an integer") + if n < 0: + raise ValueError("Input 'n' cannot be negative") + if ctx is None: + ctx = current_context() + dtype = _np.float32 if dtype is None else dtype + return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 48753d5724f8..a0c25ab5c525 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,7 +54,7 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril'] + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -1754,7 +1754,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): The shape of the empty array. dtype : str or numpy.dtype, optional An optional value type (default is `numpy.float32`). Note that this - behavior is different from NumPy's `ones` function where `float64` + behavior is different from NumPy's `zeros` function where `float64` is the default value, because `float32` is considered as the default data type in deep learning. order : {'C'}, optional, default: 'C' @@ -1773,7 +1773,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): @set_module('mxnet.numpy') def ones(shape, dtype=_np.float32, order='C', ctx=None): - """Return a new array of given shape and type, filled with zeros. + """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data in row-major (C-style). @@ -1795,7 +1795,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ return _mx_nd_np.ones(shape, dtype, order, ctx) @@ -1856,6 +1856,40 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin return _mx_nd_np.full(shape, fill_value, order=order, ctx=ctx, dtype=dtype, out=out) +@set_module('mxnet.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : ndarray + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + + Examples + -------- + >>> np.identity(3) + >>> np.identity(3) + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + """ + return _mx_nd_np.identity(n, dtype, ctx) + + @set_module('mxnet.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index c4457df42d6f..e3820601e2ec 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril'] + 'unique', 'lcm', 'tril', 'identity'] def _num_outputs(sym): @@ -906,7 +906,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): @set_module('mxnet.symbol.numpy') def ones(shape, dtype=_np.float32, order='C', ctx=None): - """Return a new array of given shape and type, filled with zeros. + """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data in row-major (C-style). @@ -928,7 +928,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): Returns ------- out : ndarray - Array of zeros with the given shape, dtype, and ctx. + Array of ones with the given shape, dtype, and ctx. """ if order != 'C': raise NotImplementedError @@ -993,6 +993,39 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) +@set_module('mxnet.symbol.numpy') +def identity(n, dtype=None, ctx=None): + """ + Return the identity array. + + The identity array is a square array with ones on + the main diagonal. + + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``numpy.float32``. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + out : _Symbol + `n` x `n` array with its main diagonal set to one, + and all other elements 0. + """ + if not isinstance(n, int): + raise TypeError("Input 'n' should be an integer") + if n < 0: + raise ValueError("Input 'n' cannot be negative") + if ctx is None: + ctx = current_context() + dtype = _np.float32 if dtype is None else dtype + return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index 4f031bdaa050..2477573c2413 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -71,6 +71,16 @@ NNVM_REGISTER_OP(_npi_ones) .set_attr("FCompute", FillCompute) .add_arguments(InitOpParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_identity) +.describe("Return a new identity array of given shape, type, and context.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", InitShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", IdentityCompute) +.add_arguments(InitOpParam::__FIELDS__()); + NNVM_REGISTER_OP(_np_zeros_like) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 49f1051735d8..e68dd9ad36a1 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -35,6 +35,9 @@ NNVM_REGISTER_OP(_npi_zeros) NNVM_REGISTER_OP(_npi_ones) .set_attr("FCompute", FillCompute); +NNVM_REGISTER_OP(_npi_identity) +.set_attr("FCompute", IdentityCompute); + NNVM_REGISTER_OP(_np_zeros_like) .set_attr("FCompute", FillCompute); diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 5c41820b57f8..3e1c345d59c3 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -20,8 +20,9 @@ /*! * Copyright (c) 2019 by Contributors * \file np_init_op.h - * \brief CPU Implementation of numpy init op + * \brief Function definition of numpy init op */ + #ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ @@ -65,6 +66,22 @@ struct indices_fwd { } }; +template +struct identity { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const int n) { + using namespace mxnet_op; + + const index_t row_id = i / n; + const index_t col_id = i % n; + if (row_id == col_id) { + KERNEL_ASSIGN(out_data[i], req, static_cast(1)); + } else { + KERNEL_ASSIGN(out_data[i], req, static_cast(0)); + } + } +}; + template void IndicesCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -101,6 +118,28 @@ void IndicesCompute(const nnvm::NodeAttrs& attrs, } } +template +void IdentityCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 0U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + Stream *s = ctx.get_stream(); + const TBlob& out_data = outputs[0]; + int n = out_data.shape_[0]; + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, out_data.Size(), out_data.dptr(), n); + }); + }); +} + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 923ee53fc400..3a5e72b53d58 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -181,6 +181,51 @@ def check_ones_array_creation(shape, dtype): assert type(y[1]) == np.ndarray +@with_seed() +@use_np +def test_identity(): + class TestIdentity(HybridBlock): + def __init__(self, shape, dtype=None): + super(TestIdentity, self).__init__() + self._n = n + self._dtype = dtype + + def hybrid_forward(self, F, x): + return x * F.np.identity(self._n, self._dtype) + + class TestIdentityOutputType(HybridBlock): + def hybrid_forward(self, F, x): + return x, F.np.identity(0) + + def check_identity_array_creation(shape, dtype): + np_out = _np.identity(n=n, dtype=dtype) + mx_out = np.identity(n=n, dtype=dtype) + assert same(mx_out.asnumpy(), np_out) + if dtype is None: + assert mx_out.dtype == _np.float32 + assert np_out.dtype == _np.float64 + + ns = [0, 1, 2, 3, 5, 15, 30, 200] + dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] + for n in ns: + for dtype in dtypes: + check_identity_array_creation(n, dtype) + x = mx.nd.array(_np.random.uniform(size=(n, n)), dtype=dtype).as_np_ndarray() + if dtype is None: + x = x.astype('float32') + for hybridize in [True, False]: + test_identity = TestIdentity(n, dtype) + test_identity_output_type = TestIdentityOutputType() + if hybridize: + test_identity.hybridize() + test_identity_output_type.hybridize() + y = test_identity(x) + assert type(y) == np.ndarray + assert same(x.asnumpy() * _np.identity(n, dtype), y.asnumpy()) + y = test_identity_output_type(x) + assert type(y[1]) == np.ndarray + + @with_seed() def test_np_ndarray_binary_element_wise_ops(): np_op_map = {