From 64c5f2ba143dbb6ff4cdfb24c4df6f7e6015748b Mon Sep 17 00:00:00 2001 From: stu1130 Date: Mon, 9 Sep 2019 17:53:13 -0700 Subject: [PATCH] numpy eye op --- python/mxnet/ndarray/numpy/_op.py | 35 +++++++- python/mxnet/numpy/multiarray.py | 29 ++++++- python/mxnet/symbol/numpy/_symbol.py | 35 +++++++- src/operator/numpy/np_init_op.cc | 32 +++---- src/operator/numpy/np_init_op.cu | 5 +- src/operator/numpy/np_init_op.h | 112 +++++++++++++++++++++++++ src/operator/tensor/init_op.h | 40 +++++---- tests/python/unittest/test_numpy_op.py | 68 +++++++++++++++ 8 files changed, 315 insertions(+), 41 deletions(-) create mode 100644 src/operator/numpy/np_init_op.h diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 964d533b2387..65c40363dc33 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -22,7 +22,7 @@ from __future__ import absolute_import import numpy as _np from ...base import numeric_types -from ...util import set_module +from ...util import _sanity_check_params, set_module from ...context import current_context from . import _internal as _npi from ..ndarray import NDArray @@ -33,7 +33,7 @@ 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye'] @set_module('mxnet.ndarray.numpy') @@ -2363,3 +2363,34 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: 0.2025 """ return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + + +@set_module('mxnet.ndarray.numpy') +def eye(N, M=None, k=0, dtype=_np.float32, **kwargs): + """ + Return a 2-D array with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + N : int + Number of rows in the output. + M : int, optional + Number of columns in the output. If None, defaults to N. + k : int, optional + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, + and a negative value to a lower diagonal. + dtype : data-type, optional + Data-type of the returned array. + + Returns + ------- + I : ndarray of shape (N,M) + An array where all elements are equal to zero, + except for the k-th diagonal, whose values are equal to one. + """ + _sanity_check_params('eye', ['order'], kwargs) + ctx = kwargs.pop('ctx', current_context()) + if ctx is None: + ctx = current_context() + return _npi.eye(N, M, k, ctx, dtype) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 7d6e81a2d4a5..c5be78b30fe8 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -52,7 +52,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -3808,3 +3808,30 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): 0.2025 """ return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + + +@set_module('mxnet.numpy') +def eye(N, M=None, k=0, dtype=_np.float32, **kwargs): + """ + Return a 2-D array with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + N : int + Number of rows in the output. + M : int, optional + Number of columns in the output. If None, defaults to N. + k : int, optional + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, + and a negative value to a lower diagonal. + dtype : data-type, optional + Data-type of the returned array. + + Returns + ------- + I : ndarray of shape (N,M) + An array where all elements are equal to zero, + except for the k-th diagonal, whose values are equal to one. + """ + return _mx_nd_np.eye(N, M, k, dtype, **kwargs) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index ec4f6a4dd741..b183c0c1b864 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -23,7 +23,7 @@ import numpy as _np from . import _op as _mx_np_op from ...base import _LIB, SymbolHandle, numeric_types, mx_uint -from ...util import check_call, set_module +from ...util import check_call, _sanity_check_params, set_module from ...context import current_context from ..symbol import Symbol from .._internal import _set_np_symbol_class @@ -35,7 +35,7 @@ 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'eye'] def _num_outputs(sym): @@ -2678,4 +2678,35 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) +@set_module('mxnet.symbol.numpy') +def eye(N, M=None, k=0, dtype=_np.float32, **kwargs): + """ + Return a 2-D array with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + N : int + Number of rows in the output. + M : int, optional + Number of columns in the output. If None, defaults to N. + k : int, optional + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, + and a negative value to a lower diagonal. + dtype : data-type, optional + Data-type of the returned array. + + Returns + ------- + I : ndarray of shape (N,M) + An array where all elements are equal to zero, + except for the k-th diagonal, whose values are equal to one. + """ + _sanity_check_params('eye', ['order'], kwargs) + ctx = kwargs.pop('ctx', current_context()) + if ctx is None: + ctx = current_context() + return _npi.eye(N, M, k, ctx, dtype) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index b85a92f97683..def525a0f384 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -22,12 +22,13 @@ * \file np_init_op.cc * \brief CPU Implementation of numpy init op */ -#include "../tensor/init_op.h" -#include "../tensor/elemwise_unary_op.h" +#include "./np_init_op.h" namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(NumpyEyeParam); + NNVM_REGISTER_OP(_npi_zeros) .set_num_inputs(0) .set_num_outputs(1) @@ -84,23 +85,6 @@ NNVM_REGISTER_OP(_np_ones_like) .add_argument("a", "NDArray-or-Symbol", "The shape and data-type of a define these same attributes of the returned array."); -bool NumpyRangeShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_shapes, - mxnet::ShapeVector* out_shapes) { - const RangeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_shapes->size(), 0U); - CHECK_EQ(out_shapes->size(), 1U); - CHECK_NE(param.step, 0) << "_npi_arange does not support step=0"; - CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat; - CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value"; - double out_size = std::ceil((param.stop.value() - param.start) / param.step); - if (out_size < 0) { - out_size = 0; - } - SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast(out_size)})); - return true; -} - NNVM_REGISTER_OP(_npi_arange) .set_num_inputs(0) .set_num_outputs(1) @@ -110,5 +94,15 @@ NNVM_REGISTER_OP(_npi_arange) .set_attr("FCompute", RangeCompute) .add_arguments(RangeParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_eye) +.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyEyeShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", NumpyEyeFill) +.add_arguments(NumpyEyeParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index fe631f388a19..7f0d587a55de 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -23,7 +23,7 @@ * \brief GPU Implementation of numpy init op */ -#include "../tensor/init_op.h" +#include "./np_init_op.h" namespace mxnet { namespace op { @@ -43,5 +43,8 @@ NNVM_REGISTER_OP(_np_ones_like) NNVM_REGISTER_OP(_npi_arange) .set_attr("FCompute", RangeCompute); +NNVM_REGISTER_OP(_npi_eye) +.set_attr("FCompute", NumpyEyeFill); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h new file mode 100644 index 000000000000..6c0cc669291c --- /dev/null +++ b/src/operator/numpy/np_init_op.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_init_op.h + * \brief CPU Implementation of numpy init op + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ + +#include +#include +#include "../tensor/init_op.h" +#include "../tensor/elemwise_unary_op.h" + + +namespace mxnet { +namespace op { + +struct NumpyEyeParam : public dmlc::Parameter { + nnvm::dim_t N; + dmlc::optional M; + nnvm::dim_t k; + std::string ctx; + int dtype; + DMLC_DECLARE_PARAMETER(NumpyEyeParam) { + DMLC_DECLARE_FIELD(N) + .describe("Number of rows in the output."); + DMLC_DECLARE_FIELD(M) + .set_default(dmlc::optional()) + .describe("Number of columns in the output. If None, defaults to N."); + DMLC_DECLARE_FIELD(k) + .set_default(0) + .describe("Index of the diagonal. 0 (the default) refers to the main diagonal," + "a positive value refers to an upper diagonal." + "and a negative value to a lower diagonal."); + DMLC_DECLARE_FIELD(ctx) + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .set_default(mshadow::kFloat32) + MXNET_ADD_ALL_TYPES + .describe("Data-type of the returned array."); + } +}; + +inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const RangeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shapes->size(), 0U); + CHECK_EQ(out_shapes->size(), 1U); + CHECK_NE(param.step, 0) << "_npi_arange does not support step=0"; + CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat; + CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value"; + double out_size = std::ceil((param.stop.value() - param.start) / param.step); + if (out_size < 0) { + out_size = 0; + } + SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast(out_size)})); + return true; +} + +inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const NumpyEyeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 0U); + CHECK_EQ(out_attrs->size(), 1U); + nnvm::dim_t M = param.M.has_value() ? param.M.value() : param.N; + CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N; + CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M)); + + return out_attrs->at(0).ndim() != 0U; +} +template +void NumpyEyeFill(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 0U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0) return; // zero-size tensor + const NumpyEyeParam& param = nnvm::get(attrs.parsed); + const nnvm::dim_t num_cols = param.M.has_value() ? param.M.value() : param.N; + EyeFillImpl(outputs[0], ctx, req, num_cols, param.N, param.k); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index f3c405d7103c..9ed4f515b712 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -483,6 +483,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, } } +template +inline void EyeFillImpl(const TBlob& out_data, + const OpContext& ctx, + const std::vector& req, + const nnvm::dim_t num_cols, + const nnvm::dim_t N, + const nnvm::dim_t k) { + using namespace mxnet_op; + const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0); + const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0); + const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) : + std::min(rnnz, num_cols); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Fill(s, out_data, req[0], static_cast(0)); + if (nnz > 0) { + Kernel, xpu>::Launch(s, nnz, out_data.dptr(), + std::max(static_cast(0), k), k, num_cols); + } + }); + }); +} template void EyeFill(const nnvm::NodeAttrs& attrs, @@ -493,25 +516,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 0U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - mshadow::Stream *s = ctx.get_stream(); const EyeParam& param = nnvm::get(attrs.parsed); const TBlob& out_data = outputs[0]; const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N; - - const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0); - const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0); - const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) : - std::min(rnnz, num_cols); - using namespace mxnet_op; - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Fill(s, out_data, req[0], static_cast(0)); - if (nnz > 0) { - Kernel, xpu>::Launch(s, nnz, out_data.dptr(), - std::max(static_cast(0), param.k), param.k, num_cols); - } - }); - }); + EyeFillImpl(out_data, ctx, req, num_cols, param.N, param.k); } diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1c9ebbb32866..82de98cf0258 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1744,6 +1744,74 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) +@with_seed() +@use_np +def test_np_eye(): + configs = [ + 4, + 1000, + (4, 3), + (5, None), + (4, None, 1), + (2, 2, 1), + (4, 6, 1), + (7, 3, -3), + (3, 2, -2), + (4, 0), + (0, 0), + (0, 3), + (0, 0, -2) + ] + exception_configs = [ + -1, + -1000, + (-2, None), + (1, -1) + ] + dtypes = ['int32', 'float16', 'float32', 'float64', None] + for config in configs: + for dtype in dtypes: + if isinstance(config, tuple): + mx_ret = np.eye(*config, dtype=dtype) + np_ret = _np.eye(*config, dtype=dtype) + else: + mx_ret = np.eye(config, dtype=dtype) + np_ret = _np.eye(config, dtype=dtype) + assert same(mx_ret.asnumpy(), np_ret) + # check for exception input + for config in exception_configs: + if isinstance(config, tuple): + assertRaises(MXNetError, np.eye, *config) + else: + assertRaises(MXNetError, np.eye, config) + + class TestEye(HybridBlock): + def __init__(self, N, M=None, k=0, dtype=None): + super(TestEye, self).__init__() + self._N = N + self._M = M + self._k = k + self._dtype = dtype + + def hybrid_forward(self, F, x): + return x + F.np.eye(self._N, self._M, self._k, dtype=self._dtype) + + for dtype in dtypes: + x = np.zeros(shape=(), dtype=dtype) + for config in configs: + for hybridize in [False, True]: + if isinstance(config, tuple): + net = TestEye(*config, dtype=dtype) + np_out = _np.eye(*config, dtype=dtype) + else: + net = TestEye(config, dtype=dtype) + np_out = _np.eye(config, dtype=dtype) + if hybridize: + net.hybridize() + mx_out = net(x) + assert same(mx_out.asnumpy(), np_out) + + if __name__ == '__main__': import nose nose.runmodule()