Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy eye op (#16132)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 authored and haojin2 committed Oct 14, 2019
1 parent 974327e commit 80982ec
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 37 deletions.
35 changes: 33 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import
import numpy as _np
from ...base import numeric_types
from ...util import set_module
from ...util import _sanity_check_params, set_module
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
Expand All @@ -31,7 +31,7 @@
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -788,6 +788,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
raise ValueError("np.histogram fails with", locals())


@set_module('mxnet.ndarray.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
_sanity_check_params('eye', ['order'], kwargs)
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
return _npi.eye(N, M, k, ctx, dtype)


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
29 changes: 28 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
Expand Down Expand Up @@ -3645,6 +3645,33 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
return _mx_nd_np.histogram(a, bins=bins, range=range, normed=normed, weights=weights, density=density)


@set_module('mxnet.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
return _mx_nd_np.eye(N, M, k, dtype, **kwargs)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
33 changes: 32 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -1305,6 +1305,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
raise ValueError("histogram fails with", locals())


@set_module('mxnet.symbol.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
_sanity_check_params('eye', ['order'], kwargs)
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
return _npi.eye(N, M, k, ctx, dtype)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
30 changes: 13 additions & 17 deletions src/operator/numpy/np_init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
* \file np_init_op.cc
* \brief CPU Implementation of numpy init op
*/

#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"
#include "./np_init_op.h"

namespace mxnet {
namespace op {


DMLC_REGISTER_PARAMETER(NumpyEyeParam);
DMLC_REGISTER_PARAMETER(IndicesOpParam);

inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -117,23 +120,6 @@ NNVM_REGISTER_OP(_np_ones_like)
.add_argument("a", "NDArray-or-Symbol",
"The shape and data-type of a define these same attributes of the returned array.");

bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_shapes->size(), 0U);
CHECK_EQ(out_shapes->size(), 1U);
CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
double out_size = std::ceil((param.stop.value() - param.start) / param.step);
if (out_size < 0) {
out_size = 0;
}
SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
return true;
}

NNVM_REGISTER_OP(_npi_arange)
.set_num_inputs(0)
.set_num_outputs(1)
Expand All @@ -143,6 +129,16 @@ NNVM_REGISTER_OP(_npi_arange)
.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu, RangeParam>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_eye)
.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyEyeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyEyeShape)
.set_attr<nnvm::FInferType>("FInferType", InitType<NumpyEyeParam>)
.set_attr<FCompute>("FCompute<cpu>", NumpyEyeFill<cpu>)
.add_arguments(NumpyEyeParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_indices)
.describe("Return an array representing the indices of a grid.")
.set_num_inputs(0)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/numpy/np_init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ NNVM_REGISTER_OP(_np_ones_like)
NNVM_REGISTER_OP(_npi_arange)
.set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu, RangeParam>);

NNVM_REGISTER_OP(_npi_eye)
.set_attr<FCompute>("FCompute<gpu>", NumpyEyeFill<gpu>);

NNVM_REGISTER_OP(_npi_indices)
.set_attr<FCompute>("FCompute<gpu>", IndicesCompute<gpu>);

Expand Down
72 changes: 72 additions & 0 deletions src/operator/numpy/np_init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,34 @@
namespace mxnet {
namespace op {

struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {
nnvm::dim_t N;
dmlc::optional<nnvm::dim_t> M;
nnvm::dim_t k;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(NumpyEyeParam) {
DMLC_DECLARE_FIELD(N)
.describe("Number of rows in the output.");
DMLC_DECLARE_FIELD(M)
.set_default(dmlc::optional<nnvm::dim_t>())
.describe("Number of columns in the output. If None, defaults to N.");
DMLC_DECLARE_FIELD(k)
.set_default(0)
.describe("Index of the diagonal. 0 (the default) refers to the main diagonal,"
"a positive value refers to an upper diagonal."
"and a negative value to a lower diagonal.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.set_default(mshadow::kFloat32)
MXNET_ADD_ALL_TYPES
.describe("Data-type of the returned array.");
}
};

struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
mxnet::TShape dimensions;
int dtype;
Expand All @@ -52,6 +80,50 @@ struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
}
};

inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_shapes->size(), 0U);
CHECK_EQ(out_shapes->size(), 1U);
CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
double out_size = std::ceil((param.stop.value() - param.start) / param.step);
if (out_size < 0) {
out_size = 0;
}
SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
return true;
}

inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
nnvm::dim_t M = param.M.has_value() ? param.M.value() : param.N;
CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N;
CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M));

return out_attrs->at(0).ndim() != 0U;
}
template<typename xpu>
void NumpyEyeFill(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0) return; // zero-size tensor
const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
const nnvm::dim_t num_cols = param.M.has_value() ? param.M.value() : param.N;
EyeFillImpl<xpu>(outputs[0], ctx, req, num_cols, param.N, param.k);
}

template<int req>
struct indices_fwd {
template<typename DType>
Expand Down
40 changes: 24 additions & 16 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu>
inline void EyeFillImpl(const TBlob& out_data,
const OpContext& ctx,
const std::vector<OpReqType>& req,
const nnvm::dim_t num_cols,
const nnvm::dim_t N,
const nnvm::dim_t k) {
using namespace mxnet_op;
const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0);
const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0);
const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) :
std::min(rnnz, num_cols);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Fill(s, out_data, req[0], static_cast<DType>(0));
if (nnz > 0) {
Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
std::max(static_cast<nnvm::dim_t>(0), k), k, num_cols);
}
});
});
}

template<typename xpu>
void EyeFill(const nnvm::NodeAttrs& attrs,
Expand All @@ -493,25 +516,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
const TBlob& out_data = outputs[0];
const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;

const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
std::min(rnnz, num_cols);
using namespace mxnet_op;
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Fill(s, out_data, req[0], static_cast<DType>(0));
if (nnz > 0) {
Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
}
});
});
EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
}


Expand Down
Loading

0 comments on commit 80982ec

Please sign in to comment.