Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy op identity & fix some typo
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Sep 25, 2019
1 parent 4002c3f commit 883a8b1
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 11 deletions.
47 changes: 44 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril']
'unique', 'lcm', 'tril', 'identity']


@set_module('mxnet.ndarray.numpy')
Expand All @@ -50,7 +50,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):
The shape of the empty array.
dtype : str or numpy.dtype, optional
An optional value type. Default is `numpy.float32`. Note that this
behavior is different from NumPy's `ones` function where `float64`
behavior is different from NumPy's `zeros` function where `float64`
is the default value, because `float32` is considered as the default
data type in deep learning.
order : {'C'}, optional, default: 'C'
Expand Down Expand Up @@ -96,7 +96,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
Returns
-------
out : ndarray
Array of zeros with the given shape, dtype, and ctx.
Array of ones with the given shape, dtype, and ctx.
"""
if order != 'C':
raise NotImplementedError
Expand Down Expand Up @@ -213,6 +213,47 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None):
return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)


@set_module('mxnet.ndarray.numpy')
def identity(n, dtype=None, ctx=None):
"""
Return the identity array.
The identity array is a square array with ones on
the main diagonal.
Parameters
----------
n : int
Number of rows (and columns) in `n` x `n` output.
dtype : data-type, optional
Data-type of the output. Defaults to ``numpy.float32``.
ctx : Context, optional
An optional device context (default is the current default context).
Returns
-------
out : ndarray
`n` x `n` array with its main diagonal set to one,
and all other elements 0.
Examples
--------
>>> np.identity(3)
>>> np.identity(3)
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
"""
if not isinstance(n, int):
raise TypeError("Input 'n' should be an integer")
if n < 0:
raise ValueError("Input 'n' cannot be negative")
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
Expand Down
42 changes: 38 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril']
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -1754,7 +1754,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):
The shape of the empty array.
dtype : str or numpy.dtype, optional
An optional value type (default is `numpy.float32`). Note that this
behavior is different from NumPy's `ones` function where `float64`
behavior is different from NumPy's `zeros` function where `float64`
is the default value, because `float32` is considered as the default
data type in deep learning.
order : {'C'}, optional, default: 'C'
Expand All @@ -1773,7 +1773,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):

@set_module('mxnet.numpy')
def ones(shape, dtype=_np.float32, order='C', ctx=None):
"""Return a new array of given shape and type, filled with zeros.
"""Return a new array of given shape and type, filled with ones.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand All @@ -1795,7 +1795,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
Returns
-------
out : ndarray
Array of zeros with the given shape, dtype, and ctx.
Array of ones with the given shape, dtype, and ctx.
"""
return _mx_nd_np.ones(shape, dtype, order, ctx)

Expand Down Expand Up @@ -1856,6 +1856,40 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
return _mx_nd_np.full(shape, fill_value, order=order, ctx=ctx, dtype=dtype, out=out)


@set_module('mxnet.numpy')
def identity(n, dtype=None, ctx=None):
"""
Return the identity array.
The identity array is a square array with ones on
the main diagonal.
Parameters
----------
n : int
Number of rows (and columns) in `n` x `n` output.
dtype : data-type, optional
Data-type of the output. Defaults to ``numpy.float32``.
ctx : Context, optional
An optional device context (default is the current default context).
Returns
-------
out : ndarray
`n` x `n` array with its main diagonal set to one,
and all other elements 0.
Examples
--------
>>> np.identity(3)
>>> np.identity(3)
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
"""
return _mx_nd_np.identity(n, dtype, ctx)


@set_module('mxnet.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
"""
Expand Down
39 changes: 36 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril']
'unique', 'lcm', 'tril', 'identity']


def _num_outputs(sym):
Expand Down Expand Up @@ -906,7 +906,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):

@set_module('mxnet.symbol.numpy')
def ones(shape, dtype=_np.float32, order='C', ctx=None):
"""Return a new array of given shape and type, filled with zeros.
"""Return a new array of given shape and type, filled with ones.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand All @@ -928,7 +928,7 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
Returns
-------
out : ndarray
Array of zeros with the given shape, dtype, and ctx.
Array of ones with the given shape, dtype, and ctx.
"""
if order != 'C':
raise NotImplementedError
Expand Down Expand Up @@ -993,6 +993,39 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)


@set_module('mxnet.symbol.numpy')
def identity(n, dtype=None, ctx=None):
"""
Return the identity array.
The identity array is a square array with ones on
the main diagonal.
Parameters
----------
n : int
Number of rows (and columns) in `n` x `n` output.
dtype : data-type, optional
Data-type of the output. Defaults to ``numpy.float32``.
ctx : Context, optional
An optional device context (default is the current default context).
Returns
-------
out : _Symbol
`n` x `n` array with its main diagonal set to one,
and all other elements 0.
"""
if not isinstance(n, int):
raise TypeError("Input 'n' should be an integer")
if n < 0:
raise ValueError("Input 'n' cannot be negative")
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
Expand Down
10 changes: 10 additions & 0 deletions src/operator/numpy/np_init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ NNVM_REGISTER_OP(_npi_ones)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)
.add_arguments(InitOpParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_identity)
.describe("Return a new identity array of given shape, type, and context.")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InitOpParam>)
.set_attr<mxnet::FInferShape>("FInferShape", InitShape<InitOpParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.add_arguments(InitOpParam::__FIELDS__());

NNVM_REGISTER_OP(_np_zeros_like)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/numpy/np_init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ NNVM_REGISTER_OP(_npi_zeros)
NNVM_REGISTER_OP(_npi_ones)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);

NNVM_REGISTER_OP(_npi_identity)
.set_attr<FCompute>("FCompute<gpu>", IdentityCompute<gpu>);

NNVM_REGISTER_OP(_np_zeros_like)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);

Expand Down
41 changes: 40 additions & 1 deletion src/operator/numpy/np_init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
/*!
* Copyright (c) 2019 by Contributors
* \file np_init_op.h
* \brief CPU Implementation of numpy init op
* \brief Function definition of numpy init op
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_

Expand Down Expand Up @@ -65,6 +66,22 @@ struct indices_fwd {
}
};

template<int req>
struct identity {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const int n) {
using namespace mxnet_op;

const index_t row_id = i / n;
const index_t col_id = i % n;
if (row_id == col_id) {
KERNEL_ASSIGN(out_data[i], req, static_cast<DType>(1));
} else {
KERNEL_ASSIGN(out_data[i], req, static_cast<DType>(0));
}
}
};

template<typename xpu>
void IndicesCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -101,6 +118,28 @@ void IndicesCompute(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu>
void IdentityCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& out_data = outputs[0];
int n = out_data.shape_[0];
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<identity<req_type>, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(), n);
});
});
}

} // namespace op
} // namespace mxnet

Expand Down
45 changes: 45 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,51 @@ def check_ones_array_creation(shape, dtype):
assert type(y[1]) == np.ndarray


@with_seed()
@use_np
def test_identity():
class TestIdentity(HybridBlock):
def __init__(self, shape, dtype=None):
super(TestIdentity, self).__init__()
self._n = n
self._dtype = dtype

def hybrid_forward(self, F, x):
return x * F.np.identity(self._n, self._dtype)

class TestIdentityOutputType(HybridBlock):
def hybrid_forward(self, F, x):
return x, F.np.identity(0)

def check_identity_array_creation(shape, dtype):
np_out = _np.identity(n=n, dtype=dtype)
mx_out = np.identity(n=n, dtype=dtype)
assert same(mx_out.asnumpy(), np_out)
if dtype is None:
assert mx_out.dtype == _np.float32
assert np_out.dtype == _np.float64

ns = [0, 1, 2, 3, 5, 15, 30, 200]
dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
for n in ns:
for dtype in dtypes:
check_identity_array_creation(n, dtype)
x = mx.nd.array(_np.random.uniform(size=(n, n)), dtype=dtype).as_np_ndarray()
if dtype is None:
x = x.astype('float32')
for hybridize in [True, False]:
test_identity = TestIdentity(n, dtype)
test_identity_output_type = TestIdentityOutputType()
if hybridize:
test_identity.hybridize()
test_identity_output_type.hybridize()
y = test_identity(x)
assert type(y) == np.ndarray
assert same(x.asnumpy() * _np.identity(n, dtype), y.asnumpy())
y = test_identity_output_type(x)
assert type(y[1]) == np.ndarray


@with_seed()
def test_np_ndarray_binary_element_wise_ops():
np_op_map = {
Expand Down

0 comments on commit 883a8b1

Please sign in to comment.