Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy eye op
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Sep 10, 2019
1 parent 9675a2d commit 64c5f2b
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 41 deletions.
35 changes: 33 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import
import numpy as _np
from ...base import numeric_types
from ...util import set_module
from ...util import _sanity_check_params, set_module
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
Expand All @@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2363,3 +2363,34 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.ndarray.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
_sanity_check_params('eye', ['order'], kwargs)
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
return _npi.eye(N, M, k, ctx, dtype)
29 changes: 28 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -3808,3 +3808,30 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
return _mx_nd_np.eye(N, M, k, dtype, **kwargs)
35 changes: 33 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as _np
from . import _op as _mx_np_op
from ...base import _LIB, SymbolHandle, numeric_types, mx_uint
from ...util import check_call, set_module
from ...util import check_call, _sanity_check_params, set_module
from ...context import current_context
from ..symbol import Symbol
from .._internal import _set_np_symbol_class
Expand All @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye']


def _num_outputs(sym):
Expand Down Expand Up @@ -2678,4 +2678,35 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
"""
Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the output.
M : int, optional
Number of columns in the output. If None, defaults to N.
k : int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : data-type, optional
Data-type of the returned array.
Returns
-------
I : ndarray of shape (N,M)
An array where all elements are equal to zero,
except for the k-th diagonal, whose values are equal to one.
"""
_sanity_check_params('eye', ['order'], kwargs)
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
return _npi.eye(N, M, k, ctx, dtype)


_set_np_symbol_class(_Symbol)
32 changes: 13 additions & 19 deletions src/operator/numpy/np_init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
* \file np_init_op.cc
* \brief CPU Implementation of numpy init op
*/
#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"
#include "./np_init_op.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyEyeParam);

NNVM_REGISTER_OP(_npi_zeros)
.set_num_inputs(0)
.set_num_outputs(1)
Expand Down Expand Up @@ -84,23 +85,6 @@ NNVM_REGISTER_OP(_np_ones_like)
.add_argument("a", "NDArray-or-Symbol",
"The shape and data-type of a define these same attributes of the returned array.");

bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_shapes->size(), 0U);
CHECK_EQ(out_shapes->size(), 1U);
CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
double out_size = std::ceil((param.stop.value() - param.start) / param.step);
if (out_size < 0) {
out_size = 0;
}
SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
return true;
}

NNVM_REGISTER_OP(_npi_arange)
.set_num_inputs(0)
.set_num_outputs(1)
Expand All @@ -110,5 +94,15 @@ NNVM_REGISTER_OP(_npi_arange)
.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu, RangeParam>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_eye)
.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyEyeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyEyeShape)
.set_attr<nnvm::FInferType>("FInferType", InitType<NumpyEyeParam>)
.set_attr<FCompute>("FCompute<cpu>", NumpyEyeFill<cpu>)
.add_arguments(NumpyEyeParam::__FIELDS__());

} // namespace op
} // namespace mxnet
5 changes: 4 additions & 1 deletion src/operator/numpy/np_init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* \brief GPU Implementation of numpy init op
*/

#include "../tensor/init_op.h"
#include "./np_init_op.h"

namespace mxnet {
namespace op {
Expand All @@ -43,5 +43,8 @@ NNVM_REGISTER_OP(_np_ones_like)
NNVM_REGISTER_OP(_npi_arange)
.set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu, RangeParam>);

NNVM_REGISTER_OP(_npi_eye)
.set_attr<FCompute>("FCompute<gpu>", NumpyEyeFill<gpu>);

} // namespace op
} // namespace mxnet
112 changes: 112 additions & 0 deletions src/operator/numpy/np_init_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_init_op.h
* \brief CPU Implementation of numpy init op
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_

#include <vector>
#include <string>
#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"


namespace mxnet {
namespace op {

struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {
nnvm::dim_t N;
dmlc::optional<nnvm::dim_t> M;
nnvm::dim_t k;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(NumpyEyeParam) {
DMLC_DECLARE_FIELD(N)
.describe("Number of rows in the output.");
DMLC_DECLARE_FIELD(M)
.set_default(dmlc::optional<nnvm::dim_t>())
.describe("Number of columns in the output. If None, defaults to N.");
DMLC_DECLARE_FIELD(k)
.set_default(0)
.describe("Index of the diagonal. 0 (the default) refers to the main diagonal,"
"a positive value refers to an upper diagonal."
"and a negative value to a lower diagonal.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.set_default(mshadow::kFloat32)
MXNET_ADD_ALL_TYPES
.describe("Data-type of the returned array.");
}
};

inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shapes,
mxnet::ShapeVector* out_shapes) {
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_shapes->size(), 0U);
CHECK_EQ(out_shapes->size(), 1U);
CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
double out_size = std::ceil((param.stop.value() - param.start) / param.step);
if (out_size < 0) {
out_size = 0;
}
SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
return true;
}

inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
nnvm::dim_t M = param.M.has_value() ? param.M.value() : param.N;
CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N;
CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M));

return out_attrs->at(0).ndim() != 0U;
}
template<typename xpu>
void NumpyEyeFill(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0) return; // zero-size tensor
const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
const nnvm::dim_t num_cols = param.M.has_value() ? param.M.value() : param.N;
EyeFillImpl<xpu>(outputs[0], ctx, req, num_cols, param.N, param.k);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
40 changes: 24 additions & 16 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu>
inline void EyeFillImpl(const TBlob& out_data,
const OpContext& ctx,
const std::vector<OpReqType>& req,
const nnvm::dim_t num_cols,
const nnvm::dim_t N,
const nnvm::dim_t k) {
using namespace mxnet_op;
const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0);
const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0);
const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) :
std::min(rnnz, num_cols);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Fill(s, out_data, req[0], static_cast<DType>(0));
if (nnz > 0) {
Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
std::max(static_cast<nnvm::dim_t>(0), k), k, num_cols);
}
});
});
}

template<typename xpu>
void EyeFill(const nnvm::NodeAttrs& attrs,
Expand All @@ -493,25 +516,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
const TBlob& out_data = outputs[0];
const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;

const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
std::min(rnnz, num_cols);
using namespace mxnet_op;
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Fill(s, out_data, req[0], static_cast<DType>(0));
if (nnz > 0) {
Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
}
});
});
EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
}


Expand Down
Loading

0 comments on commit 64c5f2b

Please sign in to comment.