From e1435a3de140e60dbd5f2a1965ea7d9c94399124 Mon Sep 17 00:00:00 2001 From: vexilligera Date: Fri, 24 Jan 2020 18:27:25 +0000 Subject: [PATCH] [NumPy] Add NumPy support for norm (#17014) * norm * full test * add default behaviour * add col norm forward * add matrix col row norms forward * row col norm backward * improve tests * beautify cpp * update broadcast_op * C1002 * billie holiday even told it even better * probing for windows unittest numpy error * update test * update test * fix style * retrigger unix ci * update according to reviews * fix backward set_num_input * fix CI --- 3rdparty/mshadow/mshadow/base.h | 10 + python/mxnet/ndarray/numpy/linalg.py | 161 +++- python/mxnet/symbol/numpy/linalg.py | 164 +++- .../broadcast_reduce_customized-inl.cuh | 416 +++++++++ .../linalg/broadcast_reduce_customized-inl.h | 181 ++++ .../linalg/broadcast_reduce_op_customized.h | 168 ++++ src/operator/numpy/linalg/np_gesvd-inl.h | 13 +- src/operator/numpy/linalg/np_norm-inl.h | 836 ++++++++++++++++++ src/operator/numpy/linalg/np_norm.cc | 204 +++++ src/operator/numpy/linalg/np_norm_backward.cc | 43 + src/operator/numpy/linalg/np_norm_backward.cu | 33 + src/operator/numpy/linalg/np_norm_forward.cc | 49 + src/operator/numpy/linalg/np_norm_forward.cu | 33 + .../unittest/test_numpy_interoperability.py | 16 +- tests/python/unittest/test_numpy_op.py | 120 ++- 15 files changed, 2376 insertions(+), 71 deletions(-) create mode 100644 src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh create mode 100644 src/operator/numpy/linalg/broadcast_reduce_customized-inl.h create mode 100644 src/operator/numpy/linalg/broadcast_reduce_op_customized.h create mode 100644 src/operator/numpy/linalg/np_norm-inl.h create mode 100644 src/operator/numpy/linalg/np_norm.cc create mode 100644 src/operator/numpy/linalg/np_norm_backward.cc create mode 100644 src/operator/numpy/linalg/np_norm_backward.cu create mode 100644 src/operator/numpy/linalg/np_norm_forward.cc create mode 100644 src/operator/numpy/linalg/np_norm_forward.cu diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index ec5197317221..e7b86832e408 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -713,6 +713,11 @@ template<> MSHADOW_XINLINE bool MinValue(void) { return false; } +/*! \brief minimum value of unsigned int */ +template<> +MSHADOW_XINLINE unsigned int MinValue(void) { + return 0; +} /*! * \brief negative infinity of certain types @@ -785,6 +790,11 @@ template<> MSHADOW_XINLINE bool MaxValue(void) { return true; } +/*! \brief maximum value of uint32_t */ +template<> +MSHADOW_XINLINE uint32_t MaxValue(void) { + return -1; +} /*! * \brief positive infinity of certain types diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 51be85851a9b..3fa37c87a594 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -96,49 +96,166 @@ def pinv(a, rcond=1e-15, hermitian=False): return _npi.pinv(a, rcond, hermitian) +# pylint: disable=too-many-return-statements def norm(x, ord=None, axis=None, keepdims=False): r"""Matrix or vector norm. - - This function can only support Frobenius norm for now. - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - + This function is able to return one of eight different matrix norms, + or one of an infinite number of vector norms (described below), depending + on the value of the ``ord`` parameter. 
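A quick usage sketch of the dispatch this PR adds, mirroring the Examples further down; the `flag` values noted in the comments are the ones this file's `norm` passes to `_npi.norm`, and the matrix 1-norm case is handled entirely in the Python front end via `sum`/`max`:

```python
from mxnet import np

a = np.arange(9) - 4
x = a.reshape((3, 3))

np.linalg.norm(x)                  # ord=None, axis=None -> _npi.norm(..., flag=-2)
np.linalg.norm(x, 'fro')           # Frobenius           -> _npi.norm(..., flag=1)
np.linalg.norm(x, 1, axis=(0, 1))  # matrix 1-norm: abs(x).sum(axis=0).max(...) in Python
np.linalg.norm(x, 3, axis=0)       # generic vector Lp   -> _npi.norm(ord=3, ..., flag=-1)
```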
Parameters ---------- x : ndarray - Input array. - ord : {'fro'}, optional - Order of the norm. + Input array. If `axis` is None, `x` must be 1-D or 2-D. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means numpy's + `inf` object. axis : {int, 2-tuple of ints, None}, optional If `axis` is an integer, it specifies the axis of `x` along which to compute the vector norms. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None, the norm of the whole ndarray is - returned. - + are computed. If `axis` is None then either a vector norm (when `x` + is 1-D) or a matrix norm (when `x` is 2-D) is returned. keepdims : bool, optional If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original `x`. - Returns ------- - n : float or ndarray + n : ndarray Norm of the matrix or vector(s). - + Notes + ----- + For values of ``ord <= 0``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + The following norms can be calculated: + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + 'nuc' -- -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 -- as below + -2 -- as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + The Frobenius norm is given by [1]_: + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + The nuclear norm is the sum of the singular values. + When you want to operate norm for matrices,if you ord is (-1, 1, inf, -inf), + you must give you axis, it is not support default axis. References ---------- .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + Examples + -------- + >>> from mxnet import np + >>> a = np.arange(9) - 4 + >>> a + array([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> b = a.reshape((3, 3)) + >>> b + array([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + >>> np.linalg.norm(a) + array(7.745967) + >>> np.linalg.norm(b) + array(7.745967) + >>> np.linalg.norm(b, 'fro') + array(7.745967) + >>> np.linalg.norm(a, 'inf') + array(4.) + >>> np.linalg.norm(b, 'inf', axis=(0, 1)) + array(9.) + >>> np.linalg.norm(a, '-inf') + array(0.) + >>> np.linalg.norm(b, '-inf', axis=(0, 1)) + array(2.) + >>> np.linalg.norm(a, 1) + array(20.) + >>> np.linalg.norm(b, 1, axis=(0, 1)) + array(7.) + >>> np.linalg.norm(a, -1) + array(0.) + >>> np.linalg.norm(b, -1, axis=(0, 1)) + array(6.) + >>> np.linalg.norm(a, 2) + array(7.745967) + >>> np.linalg.norm(a, -2) + array(0.) + >>> np.linalg.norm(a, 3) + array(5.8480353) + >>> np.linalg.norm(a, -3) + array(0.) + Using the `axis` argument to compute vector norms: + >>> c = np.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> np.linalg.norm(c, axis=0) + array([1.4142135, 2.236068 , 5. 
]) + >>> np.linalg.norm(c, axis=1) + array([3.7416573, 4.2426405]) + >>> np.linalg.norm(c, ord=1, axis=1) + array([6., 6.]) + Using the `axis` argument to compute matrix norms: + >>> m = np.arange(8).reshape(2,2,2) + >>> np.linalg.norm(m, axis=(1,2)) + array([ 3.7416573, 11.224973 ]) + >>> np.linalg.norm(m[0, :, :]), np.linalg.norm(m[1, :, :]) + (array(3.7416573), array(11.224973)) """ - if ord is not None and ord != 'fro': - raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) - if isinstance(axis, tuple) and len(axis) > 2: - raise ValueError('Improper number of dimensions to norm') - if ord == 'fro' and x.ndim > 2 and axis is None: - raise ValueError('Improper number of dimensions to norm') - return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims)) + if axis is None and ord is None: + return _npi.norm(x, ord=2, axis=None, keepdims=keepdims, flag=-2) + if axis is None or isinstance(axis, (int, tuple)): # pylint: disable=too-many-nested-blocks + if axis is not None: + if isinstance(axis, int): + axis = (axis, ) + if len(axis) == 2: + if ord in ['inf', '-inf']: + row_axis, col_axis = axis + if not keepdims: + if row_axis > col_axis: + row_axis -= 1 + if ord == 'inf': + return _mx_nd_np.sum(_mx_nd_np.abs(x), axis=col_axis, keepdims=keepdims).max(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + else: + return _mx_nd_np.sum(_mx_nd_np.abs(x), axis=col_axis, keepdims=keepdims).min(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + if ord in [1, -1]: + row_axis, col_axis = axis + if not keepdims: + if row_axis < col_axis: + col_axis -= 1 + if ord == 1: + return _mx_nd_np.sum(_mx_nd_np.abs(x), axis=row_axis, keepdims=keepdims).max(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + elif ord == -1: + return _mx_nd_np.sum(_mx_nd_np.abs(x), axis=row_axis, keepdims=keepdims).min(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + if ord in [2, -2]: + return _npi.norm(x, ord=ord, axis=axis, keepdims=keepdims, flag=0) + if ord is None: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + if ord == 'inf': + return _mx_nd_np.max(_mx_nd_np.abs(x), axis=axis, keepdims=keepdims) + elif ord == '-inf': + return _mx_nd_np.min(_mx_nd_np.abs(x), axis=axis, keepdims=keepdims) + elif ord is None: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + elif ord == 2: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=-1) + elif ord == 'nuc': + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=2) + elif ord in ['fro', 'f']: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + else: + return _npi.norm(x, ord=ord, axis=axis, keepdims=keepdims, flag=-1) + else: + raise TypeError("'axis' must be None, an integer or a tuple of integers.") +# pylint: enable=too-many-return-statements def svd(a): diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 979742001aa8..d4bf8dcb4c25 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -96,48 +96,168 @@ def pinv(a, rcond=1e-15, hermitian=False): return _npi.pinv(a, rcond, hermitian) +# pylint: disable=too-many-return-statements def norm(x, ord=None, axis=None, keepdims=False): r"""Matrix or vector norm. - - This function can only support Frobenius norm for now. 
- The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - + This function is able to return one of eight different matrix norms, + or one of an infinite number of vector norms (described below), depending + on the value of the ``ord`` parameter. Parameters ---------- - x : ndarray - Input array. - ord : {'fro'}, optional - Order of the norm. + x : _Symbol + Input array. If `axis` is None, `x` must be 1-D or 2-D. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means numpy's + `inf` object. axis : {int, 2-tuple of ints, None}, optional If `axis` is an integer, it specifies the axis of `x` along which to compute the vector norms. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None, the norm of the whole ndarray is - returned. - + are computed. If `axis` is None then either a vector norm (when `x` + is 1-D) or a matrix norm (when `x` is 2-D) is returned. keepdims : bool, optional If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original `x`. - Returns ------- - n : float or ndarray + n : _Symbol Norm of the matrix or vector(s). - + Notes + ----- + For values of ``ord <= 0``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + The following norms can be calculated: + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + 'nuc' -- -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 -- as below + -2 -- as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + The Frobenius norm is given by [1]_: + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + The nuclear norm is the sum of the singular values. + When you want to operate norm for matrices,if you ord is (-1, 1, inf, -inf), + you must give you axis, it is not support default axis. References ---------- .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + Examples + -------- + >>> from mxnet import np + >>> a = np.arange(9) - 4 + >>> a + array([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> b = a.reshape((3, 3)) + >>> b + array([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + >>> np.linalg.norm(a) + array(7.745967) + >>> np.linalg.norm(b) + array(7.745967) + >>> np.linalg.norm(b, 'fro') + array(7.745967) + >>> np.linalg.norm(a, 'inf') + array(4.) + >>> np.linalg.norm(b, 'inf', axis=(0, 1)) + array(9.) + >>> np.linalg.norm(a, '-inf') + array(0.) + >>> np.linalg.norm(b, '-inf', axis=(0, 1)) + array(2.) + >>> np.linalg.norm(a, 1) + array(20.) + >>> np.linalg.norm(b, 1, axis=(0, 1)) + array(7.) + >>> np.linalg.norm(a, -1) + array(0.) + >>> np.linalg.norm(b, -1, axis=(0, 1)) + array(6.) + >>> np.linalg.norm(a, 2) + array(7.745967) + >>> np.linalg.norm(a, -2) + array(0.) + >>> np.linalg.norm(a, 3) + array(5.8480353) + >>> np.linalg.norm(a, -3) + array(0.) 
+ Using the `axis` argument to compute vector norms: + >>> c = np.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> np.linalg.norm(c, axis=0) + array([1.4142135, 2.236068 , 5. ]) + >>> np.linalg.norm(c, axis=1) + array([3.7416573, 4.2426405]) + >>> np.linalg.norm(c, ord=1, axis=1) + array([6., 6.]) + Using the `axis` argument to compute matrix norms: + >>> m = np.arange(8).reshape(2,2,2) + >>> np.linalg.norm(m, axis=(1,2)) + array([ 3.7416573, 11.224973 ]) + >>> np.linalg.norm(m[0, :, :]), np.linalg.norm(m[1, :, :]) + (array(3.7416573), array(11.224973)) """ - if ord is not None and ord != 'fro': - raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord))) - if isinstance(axis, tuple) and len(axis) > 2: - raise ValueError('Improper number of dimensions to norm') - # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception - return _symbol.sqrt(_mx_sym_np.sum(x * x, axis=axis, keepdims=keepdims)) + if axis is None and ord is None: + return _npi.norm(x, ord=2, axis=None, keepdims=keepdims, flag=-2) + if axis is None or isinstance(axis, (int, tuple)): # pylint: disable=too-many-nested-blocks + if axis is not None: + if isinstance(axis, int): + axis = (axis, ) + if len(axis) == 2: + if ord in ['inf', '-inf']: + row_axis, col_axis = axis + if not keepdims: + if row_axis > col_axis: + row_axis -= 1 + if ord == 'inf': + return _mx_sym_np.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).max(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + else: + return _mx_sym_np.sum(_symbol.abs(x), axis=col_axis, keepdims=keepdims).min(axis=row_axis, keepdims=keepdims) # pylint: disable=line-too-long + if ord in [1, -1]: + row_axis, col_axis = axis + if not keepdims: + if row_axis < col_axis: + col_axis -= 1 + if ord == 1: + return _mx_sym_np.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).max(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + elif ord == -1: + return _mx_sym_np.sum(_symbol.abs(x), axis=row_axis, keepdims=keepdims).min(axis=col_axis, keepdims=keepdims) # pylint: disable=line-too-long + if ord in [2, -2]: + return _npi.norm(x, ord=ord, axis=axis, keepdims=keepdims, flag=0) + if ord is None: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + if ord == 'inf': + return _mx_sym_np.max(_symbol.abs(x), axis=axis, keepdims=keepdims) + #return _npi.norm(x, ord=float('inf'), axis=axis, keepdims=keepdims, flag=3) + elif ord == '-inf': + return _mx_sym_np.min(_symbol.abs(x), axis=axis, keepdims=keepdims) + #return _npi.norm(x, ord=-float('inf'), axis=axis, keepdims=keepdims, flag=4) + elif ord is None: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + elif ord == 2: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=-1) + elif ord == 'nuc': + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=2) + elif ord in ['fro', 'f']: + return _npi.norm(x, ord=2, axis=axis, keepdims=keepdims, flag=1) + else: + return _npi.norm(x, ord=ord, axis=axis, keepdims=keepdims, flag=-1) + else: + raise TypeError("'axis' must be None, an integer or a tuple of integers.") +# pylint: enable=too-many-return-statements def svd(a): diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh new file mode 100644 index 000000000000..d5258819a561 --- /dev/null +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
+ * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015-2020 by Contributors + * \file broadcast_reduce_customized-inl.cuh + * \brief Customized CUDA implementations for binary broadcast and reduce + * \author MXNet contributors +*/ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ +#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ + +#include "../../tensor/broadcast_reduce-inl.cuh" + +using namespace mshadow::cuda; + +template +__launch_bounds__(nthread_reduce) +__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, + const DType* __restrict big, OType *small, + const Shape big_shape0, const Shape small_shape, + const Shape big_shape, const Shape big_stride, + const int Mnext, const bool do_transpose, + Reducer* reducer) { + extern __shared__ char shTileChar[]; + AType* shTile = (AType*)(shTileChar); + const int tid = threadIdx.x + threadIdx.y*blockDim.x; + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + const int tidx = (do_transpose) ? tid / by : threadIdx.x; + const int tidy = (do_transpose) ? tid % by : threadIdx.y; + // bool need_clean = !reducer; + // reducer = reducer ? reducer : new Reducer(); + for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { + // This TB handles M range [Mstart, ...., Mend - 1] + const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); + const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); + for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { + int idx = idx0 + tidx; + Shape coord = unravel(idx, small_shape); + int idx_big0 = ravel(coord, big_shape0); + + AType val, residual; + reducer->SetInitValue(val, residual); + if (idx < N) { + for (int k = tidy + Mstart; k < Mend; k += by*unroll) { + int idx_big[unroll]; + #pragma unroll + for (int u=0;u < unroll;u++) { + idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); + } + DType tmp[unroll]; + #pragma unroll + for (int u=0;u < unroll;u++) { + if (k + u*by < Mend) { + tmp[u] = OP::Map(big[idx_big[u]]); + } + } + #pragma unroll + for (int u=0;u < unroll;u++) { + if (k + u*by < Mend) reducer->Reduce(val, AType(tmp[u]), residual); + } + } + } + + // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 + if (by > 1) { + // Fix bx to avoid bank conflicts. Assumes warpSize number of banks + const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; + const int it0 = tidx + tidy*fbx; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; + __syncthreads(); + for (int t=1;t < by;t <<= 1) { + AType tmp, tmp_residual; + reducer->SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } + __syncthreads(); + reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + __syncthreads(); + } + if (idx < N && tidy == 0) { + reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); + } + } else { + if (idx < N) { + reducer->Finalize(val, residual); + assign(&small[idx + m0*N], addto, OType(val)); + } + } + } + } + // if (need_clean) { + // delete reducer; + // } +} + +template +__launch_bounds__(nthread_reduce) +__global__ void reduce_kernel_wr(const int N, const int M, const bool addto, + const DType* __restrict big, const DType* __restrict lhs, + const DType* __restrict rhs, DType *small, + const Shape big_shape0, const Shape lhs_shape0, + const Shape rhs_shape0, const Shape small_shape, + const Shape big_shape, const Shape lhs_shape, + const Shape rhs_shape, const Shape big_stride, + const Shape lhs_stride, const Shape rhs_stride, + const int Mnext, const bool do_transpose, + Reducer* reducer) { + extern __shared__ char shTileChar[]; + DType* shTile = (DType*)(shTileChar); + const int tid = threadIdx.x + threadIdx.y*blockDim.x; + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + const int tidx = (do_transpose) ? tid / by : threadIdx.x; + const int tidy = (do_transpose) ? tid % by : threadIdx.y; + // bool need_clean = !reducer; + // reducer = reducer ? reducer : new Reducer(); + for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { + // This TB handles M range [Mstart, ...., Mend - 1] + const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext); + const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); + for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { + int idx = idx0 + tidx; + Shape coord = unravel(idx, small_shape); + int idx_big0 = ravel(coord, big_shape0); + int idx_lhs0 = ravel(coord, lhs_shape0); + int idx_rhs0 = ravel(coord, rhs_shape0); + + DType val, residual; + reducer->SetInitValue(val, residual); + if (idx < N) { + for (int k = tidy + Mstart; k < Mend; k += by*unroll) { + int idx_big[unroll]; + int idx_lhs[unroll]; + int idx_rhs[unroll]; + #pragma unroll + for (int u=0;u < unroll;u++) { + idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); + idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride); + idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride); + } + DType tmp[unroll]; + #pragma unroll + for (int u=0;u < unroll;u++) { + if (k + u*by < Mend) { + tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]])); + } + } + #pragma unroll + for (int u=0;u < unroll;u++) { + if (k + u*by < Mend) reducer->Reduce(val, tmp[u], residual); + } + } + } + + // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 + if (by > 1) { + // Fix bx to avoid bank conflicts. Assumes warpSize number of banks + const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? 
(bx + 1) : bx; + const int it0 = tidx + tidy*fbx; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; + __syncthreads(); + for (int t=1;t < by;t <<= 1) { + DType tmp, tmp_residual; + reducer->SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } + __syncthreads(); + reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + __syncthreads(); + } + if (idx < N && tidy == 0) { + reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + assign(&small[idx + m0*N], addto, shTile[tidx * 2]); + } + } else { + if (idx < N) { + reducer->Finalize(val, residual); + assign(&small[idx + m0*N], addto, val); + } + } + } + } + // if (need_clean) { + // delete reducer; + // } +} + +// Simple reduction of lines when M is small +template +__launch_bounds__(kMaxThreadsPerBlock) +__global__ void reduce_lines_kernel_wr(const int N, const int M, const bool addto, + const int small_in_stride, const DType* __restrict small_in, DType *small_out, + Reducer* reducer) { + for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + + DType val, residual; + reducer->SetInitValue(val, residual); + for (int k = 0; k < M; k++) { + reducer->Reduce(val, small_in[idx + k*small_in_stride], residual); + } + + if (idx < N) { + reducer->Finalize(val, residual); + assign(&small_out[idx], addto, val); + } + + } +} + +template +__launch_bounds__(kMaxThreadsPerBlock) +__global__ void reduce_kernel_M1_wr(const int N, const bool addto, + const DType* __restrict big, OType *small, const Shape bshape, + const Shape sshape, Reducer* reducer) { + for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + Shape coord = unravel(idx, sshape); + int j = ravel(coord, bshape); + AType val, residual; + reducer->SetInitValue(val, residual); + reducer->Reduce(val, AType(OP::Map(big[j])), residual); + reducer->Finalize(val, residual); + assign(&small[idx], addto, OType(val)); + } +} + +template +__launch_bounds__(kMaxThreadsPerBlock) +__global__ void reduce_kernel_M1_wr(const int N, const bool addto, + const DType* __restrict big, + const DType* __restrict lhs, + const DType* __restrict rhs, + DType *small, + const Shape big_shape, + const Shape lhs_shape, + const Shape rhs_shape, + const Shape small_shape, + Reducer* reducer) { + for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + Shape coord = unravel(idx, small_shape); + int idx_big = ravel(coord, big_shape); + int idx_lhs = ravel(coord, lhs_shape); + int idx_rhs = ravel(coord, rhs_shape); + DType val, residual; + reducer->SetInitValue(val, residual); + reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); + reducer->Finalize(val, residual); + assign(&small[idx], addto, val); + } +} + +#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \ + if (do_unroll) { \ + const int unrollVar = unrollAmount; \ + {__VA_ARGS__} \ + } else { \ + const int unrollVar = 1; \ + {__VA_ARGS__} \ + } + +template +void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req, + const TBlob& big, const Tensor& workspace, + const ReduceImplConfig& config, + Reducer* reducer = nullptr) { + bool need_clean = !reducer; + reducer = reducer ? 
reducer : new Reducer(); + if (config.M == 1) { + reduce_kernel_M1_wr + <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( + config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), + small.shape_.get(), reducer); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); + } else { + OType* small_dptr = small.dptr(); + bool addto = (req == kAddTo); + if (config.Mnext > 1) { + // small_dptr[] is N*Mnext*sizeof(DType) bytes + small_dptr = reinterpret_cast(workspace.dptr_); + addto = false; + // Check that the workspace is contigiuous + CHECK_EQ(workspace.CheckContiguous(), true); + // Check that we have enough storage + CHECK_GE(workspace.size(0), config.workspace_size); + } + + const int by = (config.kernel_1.do_transpose) ? + config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; + const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + reduce_kernel_wr + <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( + config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), + small.shape_.get(), config.rshape, config.rstride, config.Mnext, + config.kernel_1.do_transpose, reducer); + }); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); + + if (config.Mnext > 1) { + reduce_lines_kernel_wr + <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> + (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); + } + } + if (need_clean) { + delete reducer; + } +} + +template +void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, + const OpReqType req, const TBlob& big, const Tensor& workspace, + const ReduceImplConfig& config, Reducer* reducer = nullptr) { + bool need_clean = !reducer; + reducer = reducer ? reducer : new Reducer(); + if (config.M == 1) { + reduce_kernel_M1_wr + <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( + config.N, req == kAddTo, big.dptr(), lhs.dptr(), rhs.dptr(), + small.dptr(), big.shape_.get(), lhs.shape_.get(), + rhs.shape_.get(), small.shape_.get()); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr); + } else { + DType* small_dptr = small.dptr(); + bool addto = (req == kAddTo); + if (config.Mnext > 1) { + // small_dptr[] is N*Mnext*sizeof(DType) bytes + small_dptr = reinterpret_cast(workspace.dptr_); + addto = false; + // Check that the workspace is contigiuous + CHECK_EQ(workspace.CheckContiguous(), true); + // Check that we have enough storage + CHECK_GE(workspace.size(0), config.workspace_size); + } + + const int by = (config.kernel_1.do_transpose) ? 
+ config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; + const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + reduce_kernel_wr + <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( + config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), + small_dptr, big.shape_.get(), lhs.shape_.get(), + rhs.shape_.get(), small.shape_.get(), config.rshape, config.lhs_shape, + config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext, + config.kernel_1.do_transpose, reducer); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); + }); + + if (config.Mnext > 1) { + reduce_lines_kernel_wr + <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> + (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr(), reducer); + MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr); + } + } + if (need_clean) { + delete reducer; + } +} + +#undef KERNEL_UNROLL_SWITCH + +template +void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, + const Tensor& workspace, const TBlob& big, Reducer* reducer = nullptr) { + if (req == kNullOp) return; + cudaStream_t stream = Stream::GetStream(s); + bool need_clean = !reducer; + reducer = reducer ? reducer : new Reducer(); + ReduceImplConfig config = + ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); + if (safe_acc) { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + config = ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); + ReduceImplWithReducer( + stream, small, req, big, workspace, config, reducer); + }); + }); + } else { + ReduceImplWithReducer(stream, small, req, big, workspace, config, reducer); + } + if (need_clean) { + delete reducer; + } +} + +#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_ diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h new file mode 100644 index 000000000000..2b5970d4f4ae --- /dev/null +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2015-2017 by Contributors + * \file broadcast_reduce_customized-inl.h + * \brief CPU-specific Function definition of broadcast and reduce operators + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_CUSTOMIZED_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_CUSTOMIZED_INL_H_ + +#include "../../tensor/broadcast_reduce-inl.h" + +namespace mxnet { +namespace op { +namespace broadcast { +using namespace mshadow; + +template +MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, const bool addto, + const DType* __restrict big, OType *small, + const Shape& bshape, const Shape& sshape, + const Shape& rshape, const Shape& rstride, + Reducer* reducer) { + Shape coord = unravel(idx, sshape); + index_t j = ravel(coord, bshape); + AType val, residual; + reducer->SetInitValue(val, residual); + for (size_t k = 0; k < M; ++k) { + coord = unravel(k, rshape); + reducer->Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual); + } + reducer->Finalize(val, residual); + assign(&small[idx], addto, OType(val)); +} + +#ifdef __CUDACC__ +#include "broadcast_reduce_customized-inl.cuh" +#include "../../tensor/broadcast_reduce-inl.cuh" + +#else + +template +void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto, + const DType *big, OType *small, const Shape bshape, + const Shape sshape, const Shape rshape, + const Shape rstride, + Reducer* reducer) { + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (index_t idx = 0; idx < static_cast(N); ++idx) { + seq_reduce_assign_wr(idx, M, addto, big, small, + bshape, sshape, rshape, rstride, reducer); + } +} + +template +void ReduceWithReducer(Stream* s, const TBlob& small, const OpReqType req, + const Tensor& workspace, const TBlob& big, + Reducer* reducer) { + if (req == kNullOp) return; + Shape rshape, rstride; + diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); + size_t N = small.shape_.Size(), M = rshape.Size(); + if (!safe_acc) { + seq_reduce_compute_wr( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride, reducer); + } else { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + seq_reduce_compute_wr( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride, reducer); + }); + }); + } +} + +template +MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, const bool addto, + const DType* __restrict big, const DType* __restrict lhs, + const DType* __restrict rhs, DType *small, + const Shape& big_shape, + const Shape& lhs_shape0, + const Shape& rhs_shape0, + const Shape& small_shape, const Shape& rshape, + const Shape& lhs_shape, + const Shape& rhs_shape, + const Shape& rstride, const Shape& lhs_stride, + const Shape& rhs_stride, + Reducer* reducer) { + Shape coord = unravel(idx, small_shape); + const index_t idx_big0 = ravel(coord, big_shape); + const index_t idx_lhs0 = ravel(coord, lhs_shape0); + const index_t idx_rhs0 = ravel(coord, rhs_shape0); + DType val, residual; + reducer->SetInitValue(val, residual); + for (size_t k = 0; k < M; ++k) { + Shape coord_big = unravel(k, rshape); + index_t idx_big = idx_big0 + dot(coord_big, rstride); + + Shape coord_lhs = unravel(k, lhs_shape); + index_t idx_lhs = idx_lhs0 + 
dot(coord_lhs, lhs_stride); + + Shape coord_rhs = unravel(k, rhs_shape); + index_t idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride); + + reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); + } + reducer->Finalize(val, residual); + assign(&small[idx], addto, val); +} + +template +void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto, + const DType *big, const DType *lhs, const DType *rhs, DType *small, + const Shape big_shape, const Shape small_shape, + const Shape rshape, const Shape rstride, + const Shape lhs_shape, const Shape lhs_stride, + const Shape rhs_shape, const Shape rhs_stride, + const Shape& lhs_shape0, const Shape& rhs_shape0, + Reducer* reducer) { + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (index_t idx = 0; idx < static_cast(N); ++idx) { + seq_reduce_assign_wr(idx, M, addto, big, lhs, rhs, small, + big_shape, lhs_shape0, rhs_shape0, small_shape, rshape, lhs_shape, rhs_shape, rstride, + lhs_stride, rhs_stride, reducer); + } +} + +template +void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, + const Tensor& workspace, const TBlob& big, const TBlob& lhs, + const TBlob& rhs, Reducer* reducer) { + if (req == kNullOp) return; + Shape rshape, rstride; + diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); + size_t N = small.shape_.Size(); + size_t M = rshape.Size(); + + Shape lhs_shape, lhs_stride; + diff(small.shape_.get(), lhs.shape_.get(), &lhs_shape, &lhs_stride); + + Shape rhs_shape, rhs_stride; + diff(small.shape_.get(), rhs.shape_.get(), &rhs_shape, &rhs_stride); + + seq_reduce_compute_wr( + N, M, req == kAddTo, + big.dptr(), lhs.dptr(), rhs.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), + rshape, rstride, + lhs_shape, lhs_stride, + rhs_shape, rhs_stride, + lhs.shape_.get(), rhs.shape_.get(), + reducer); +} + +#endif +} // namespace broadcast +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_CUSTOMIZED_INL_H_ diff --git a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h new file mode 100644 index 000000000000..25f66d04f663 --- /dev/null +++ b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2015 by Contributors + * \file broadcast_reduce_op_customized.h + * \brief Function definition of broadcast and reduce operators + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_OP_CUSTOMIZED_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_OP_CUSTOMIZED_H_ + +#include "../../tensor/broadcast_reduce_op.h" +#include "./broadcast_reduce_customized-inl.h" +#include + +namespace mxnet { +namespace op { + +template +void ReduceAxesComputeImplWithReducer(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& small, + Reducer* reducer = nullptr) { + using namespace mshadow; + using namespace mshadow::expr; + + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + broadcast::ReduceWithReducer( + s, out_data, req[0], workspace, in_data, reducer); + // no normalization + }); + }); + }); +} + +template +struct reduce_axes_backward_broadcast_wm { + template + MSHADOW_XINLINE static void Map(index_t i, + DType *data, + OType *out, + DType *igrad, + OType *ograd, + mshadow::Shape in_shape, + mshadow::Shape out_shape, + const uint32_t ndim, + Mapper* OP = nullptr) { + size_t in_stride = 1; + size_t out_stride = 1; + index_t idx = i; + index_t out_idx = i; + bool need_clean = !OP; + for (int iter = ndim - 1; iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + out_idx -= dim_idx * in_stride; + if (out_shape[iter] != 1) { + out_idx += dim_idx * out_stride; + } + idx /= in_shape[iter]; + in_stride *= in_shape[iter]; + out_stride *= out_shape[iter]; + } + OP = OP ? 
OP : new Mapper(); + KERNEL_ASSIGN(igrad[i], req, DType(ograd[out_idx]) * OP->Map(data[i], DType(out[out_idx]))); + if (need_clean) { + delete OP; + } + } +}; + +template +void ReduceAxesBackwardUseInOutImplWithMapper(const OpContext& ctx, + const mxnet::TShape &small, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + Mapper* OP = nullptr) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); + Stream *s = ctx.get_stream(); + + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Shape in_shape; + mshadow::Shape out_shape; + for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; + } else { + in_shape[i] = 1; + out_shape[i] = 1; + } + } + if (dst_shape.ndim() == 2) { + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim(), OP); + }); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim(), OP); + }); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + } + }); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_OP_CUSTOMIZED_H_ diff --git a/src/operator/numpy/linalg/np_gesvd-inl.h b/src/operator/numpy/linalg/np_gesvd-inl.h index 7ce3078dae66..56815679dcfa 100644 --- a/src/operator/numpy/linalg/np_gesvd-inl.h +++ b/src/operator/numpy/linalg/np_gesvd-inl.h @@ -69,14 +69,21 @@ struct gesvd { const Tensor& L, const Tensor& V, const OpContext& ctx, - const nnvm::NodeAttrs& attrs) { + const nnvm::NodeAttrs& attrs, + TBlob* workspace = nullptr) { Stream *s = ctx.get_stream(); if (A.dptr_ != V.dptr_) Copy(V, A, s); // From here on, we work on V only // Reserve workspace (size determined by query) int lwork(linalg_gesvd_workspace_query(UT[0], L[0], V[0], s)); - Tensor work = ctx.requested[0] - .get_space_typed(Shape1(lwork), s); + Tensor work; + if (!workspace) { + work = ctx.requested[0] + .get_space_typed(Shape1(lwork), s); + } else { + work = workspace + ->get_with_shape(Shape1(lwork), s); + } // Loop over items in batch for (index_t i = 0; i < UT.size(0); ++i) { linalg_gesvd(UT[i], L[i], V[i], work, s); diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h new file mode 100644 index 000000000000..9de4a76f950d --- /dev/null +++ b/src/operator/numpy/linalg/np_norm-inl.h @@ -0,0 +1,836 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_norm-inl.h + * \brief norm + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_NORM_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_NP_NORM_INL_H_ + +#include +#include +#include +#include +#include "../../tensor/la_op.h" +#include "../../tensor/la_op-inl.h" +#include "../../tensor/init_op.h" +#include "./broadcast_reduce_op_customized.h" +#include "./np_gesvd-inl.h" +#include "../np_matrix_op-inl.h" + +namespace mxnet { +namespace op { + +namespace mshadow_op { +/*! \brief Lp-norm power reducer */ + +struct nrmlp { + double lp; + MSHADOW_XINLINE nrmlp(): lp(2) {} + MSHADOW_XINLINE nrmlp(double l): lp(l) {} + + /* \brief power for Lp norm */ + MSHADOW_XINLINE static double lp_power(volatile double src, volatile double p) { + if (p != 0.0) { + if (src == 0.0) { + return src; + } else { + return power::Map(src, p); + } + } else { // 0-norm, sparsity + return static_cast(src != 0); + } + } + + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE void Reduce(volatile AType& sum_of_powers, volatile DType src) { // NOLINT(*) + if (src != 0) { + sum_of_powers += AType(lp_power(static_cast(src), lp)); + } + } + + /*! \brief do stable reduction into dst */ + template + MSHADOW_XINLINE void Reduce(volatile AType& sum_of_powers, volatile DType src, volatile DType& scale) { // NOLINT(*) + if (src != 0) { + DType src_abs = abs::Map(src); + if (scale < src_abs) { + sum_of_powers = sum_of_powers * AType(lp_power(static_cast(scale / src_abs), lp)); + sum_of_powers = sum_of_powers + 1; + scale = src_abs; + } else { + sum_of_powers = sum_of_powers + AType(lp_power(static_cast(src_abs / scale), lp)); + } + } + } + + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + dst_val += src_val; + } + + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*) + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast(src_scale / dst_scale), 2)); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast(dst_scale / src_scale), 2)); + dst_scale = src_scale; + } + } + + /*! \brief finalize reduction result */ + template + MSHADOW_XINLINE void Finalize(volatile DType& sum_of_powers) { // NOLINT(*) + if (lp != 0.0) { + sum_of_powers = DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! 
\brief finalize reduction result */ + template + MSHADOW_XINLINE void Finalize(volatile DType& sum_of_powers, volatile DType& scale) { // NOLINT(*) + if (lp != 0.0) { + sum_of_powers = scale * DType(lp_power(static_cast(sum_of_powers), 1.0 / lp)); + } + } + + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &sum_of_powers) { // NOLINT(*) + sum_of_powers = 0; + } + + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &sum_of_powers, DType &scale) { // NOLINT(*) + SetInitValue(sum_of_powers); + scale = 0; + } +}; + +/*! \brief Elementwise gradient of Lp-norm, does not handle p = 1 */ +struct nrmlp_grad : public mxnet_op::tunable { + double lp; + MSHADOW_XINLINE nrmlp_grad(): lp(2) {} + MSHADOW_XINLINE nrmlp_grad(double l): lp(l) {} + + /* \brief elementwise gradient of lp norm */ + template + MSHADOW_XINLINE DType Map(DType a, DType b) { + DType ret; + if (lp != 0.0) { // dx_i = (|x_i| / y) ^ (p - 1) * sgn(x_i) + DType abs_a = DType(abs::Map(a)); + DType sgn_a = DType(sign::Map(a)); + ret = power::Map(DType(abs_a / b), DType(lp - 1)) * sgn_a; + } else { // L0 norm is elementwise constant and non-differentiable + ret = 0; + } + return ret; + } +}; + +/*! \brief Gradient for abs-min/max */ +struct abs_grad : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + DType sgn = DType(sign::Map(a)); + DType grad = DType(abs::Map(a)) == DType(abs::Map(b)) ? + DType(1.0) : DType(0.0); + return sgn * grad; + } +}; + +/*! \brief Sign */ +struct abs_sign : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return DType(sign::Map(a)); + } +}; + +} // namespace mshadow_op + +inline bool NumpyLpNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs); + +inline bool NumpyMatrixNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs); + +inline void assign_svd_empty(mxnet::ShapeVector *out_attrs); + +bool NumpyNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs); + +bool NumpyNormType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs); + +TShape swapMatDims(const TShape &shape, const TShape &axis); + +TShape inverseTranspose(const TShape &axes); + +struct NumpyNormParam : public dmlc::Parameter { + double ord; + dmlc::optional axis; + bool keepdims; + int flag; + DMLC_DECLARE_PARAMETER(NumpyNormParam) { + DMLC_DECLARE_FIELD(ord).set_default(2) + .describe("Order of the norm. inf means numpy’s inf object."); + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) + .describe(R"code(If axis is an integer, it specifies the axis of x along + which to compute the vector norms. If axis is a 2-tuple, + it specifies the axes that hold 2-D matrices, and the matrix norms of + these matrices are computed. If axis is None then either a vector norm (when x is 1-D) + or a matrix norm (when x is 2-D) is returned.If axis is an integer, + it specifies the axis of x along which to compute the vector norms. + If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, + and the matrix norms of these matrices are computed. 
If axis is None then either a + vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned.)code"); + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("If this is set to `True`, the reduced axis is left " + "in the result as dimension with size one."); + DMLC_DECLARE_FIELD(flag).set_default(-1) + .describe("Mapping relations between ord and flag." + "ord: None, 'fro', 'nuc', 'inf' '-inf'." + "flag: 0 , 1, 2, 3, 4. "); + } +}; + +template +void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const NumpyNormParam& param = nnvm::get(attrs.parsed); + double ord = param.ord; + + if (req[0] == kNullOp) return; + + mxnet::TShape small; + mxnet::TShape out_shape = outputs[0].shape_; + if (param.keepdims) { + small = outputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); + const_cast&>(outputs)[0] = outputs[0].reshape(small); + } + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) { + common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. " + "See https://mxnet.apache.org/api/faq/env_var " + "for more details."); + } + if (param.axis.value().ndim() != 2) { // elementwise Lp-norm + if (ord == -std::numeric_limits::infinity()) { // -inf norm + LOG(FATAL) << "-inf norm handled in front-end."; + } else if (param.ord == std::numeric_limits::infinity()) { // inf norm + LOG(FATAL) << "inf norm handled in front-end."; + } else { + mshadow_op::nrmlp host_reducer(param.ord); + mshadow_op::nrmlp *reducer_instance = nullptr; +#ifdef __CUDACC__ + Stream *s = ctx.get_stream(); + cudaStream_t copy_stream = mshadow::Stream::GetStream(s); + cudaMalloc(reinterpret_cast(&reducer_instance), sizeof(mshadow_op::nrmlp)); + cudaMemcpyAsync(reducer_instance, &host_reducer, sizeof(mshadow_op::nrmlp), + cudaMemcpyHostToDevice, copy_stream); + cudaStreamSynchronize(copy_stream); +#else + reducer_instance = &host_reducer; +#endif + if (safe_acc) { + ReduceAxesComputeImplWithReducer( + ctx, inputs, req, outputs, small, reducer_instance); + } else { + ReduceAxesComputeImplWithReducer( + ctx, inputs, req, outputs, small, reducer_instance); + } +#ifdef __CUDACC__ + cudaFree(reducer_instance); +#endif + } + } + if (!param.keepdims) { + const_cast&>(outputs)[0] = outputs[0].reshape(out_shape); + } +} + +template +void NumpyLpNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + + const NumpyNormParam& param = nnvm::get(attrs.parsed); + double ord = param.ord; + mxnet::TShape small; + + Stream *s = ctx.get_stream(); + if (param.keepdims) { + small = inputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false); + } + + if (param.axis.value().ndim() != 2) { // Elementwise Lp norm + if (ord == -std::numeric_limits::infinity()) { // -inf norm + LOG(FATAL) << "-inf norm handled in front-end."; + } else if (ord == std::numeric_limits::infinity()) { // inf norm + LOG(FATAL) << "inf norm handled in front-end."; + } else if (ord == 1) { // nrmlp_grad does not handle p = 1, legacy code from tensor + mxnet::TShape src_shape, dst_shape; + BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); + 
mshadow::Shape in_shape; + mshadow::Shape out_shape; + for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; + } else { + in_shape[i] = 1; + out_shape[i] = 1; + } + } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, OType, { + if (dst_shape.ndim() == 2) { + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } + }); + }); + } else { // Elementwise Lp + mshadow_op::nrmlp_grad host_mapper(ord); + mshadow_op::nrmlp_grad *mapper_instance = nullptr; +#ifdef __CUDACC__ + cudaStream_t copy_stream = mshadow::Stream::GetStream(s); + cudaMalloc(reinterpret_cast(&mapper_instance), sizeof(mshadow_op::nrmlp_grad)); + cudaMemcpyAsync(mapper_instance, &host_mapper, sizeof(mshadow_op::nrmlp_grad), + cudaMemcpyHostToDevice, copy_stream); + cudaStreamSynchronize(copy_stream); +#else + mapper_instance = &host_mapper; +#endif + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if (req[0] == kAddTo) { + TBlob workspace = TBlob(ctx.requested[0].get_space_typed( + Shape1(outputs[0].shape_.Size()), s)); + std::vector temp({workspace.reshape(outputs[0].shape_)}); + ReduceAxesBackwardUseInOutImplWithMapper( + ctx, small, inputs, req, temp, mapper_instance); + Tensor out = outputs[0].FlatTo1D(s); + out += workspace.FlatTo1D(s); + } else { + ReduceAxesBackwardUseInOutImplWithMapper( + ctx, small, inputs, req, outputs, mapper_instance); + } + }); +#ifdef __CUDACC__ + cudaFree(mapper_instance); +#endif + } + } else { // matrix norm should switch to matrix norm op + LOG(FATAL) << "Case handled in matrix norm compute."; + } +} + +template +void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + + if (req[0] == kNullOp) return; + + Stream *s = ctx.get_stream(); + const NumpyNormParam& param = nnvm::get(attrs.parsed); + double ord = param.ord; + + TShape reduced_shape; + if (param.keepdims) { + reduced_shape = outputs[0].shape_; + } else { + reduced_shape = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); + } + + if (param.flag == 1) { // Frobenius norm + ReduceAxesComputeImplWithReducer( + ctx, inputs, req, outputs, reduced_shape); + return; + } + + TShape mat_axis = param.axis.value(); + + if (param.ord != 2 && param.ord != -2) { // row norm or col norm + TShape sum_shape = inputs[0].shape_; + sum_shape[mat_axis[!(param.ord == 1 || param.ord == -1)]] = 1; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + TBlob temp = outputs[1].reshape(sum_shape); + std::vector sum_output({temp}); + ReduceAxesComputeImpl( + ctx, 
inputs, req, sum_output, sum_shape); + if (param.ord > 0) { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } else { + ReduceAxesComputeImpl( + ctx, sum_output, req, outputs, reduced_shape); + } + }); + return; + } + + if (inputs[0].type_flag_ == mshadow::kFloat16) { + LOG(FATAL) << "Matrix +/- 2-norm does not support float 16 due to SVD implementation."; + } + + // spectral norms + TShape old_shape = inputs[0].shape_; + TShape svd_in_shape = inputs[0].shape_; + TShape axes(old_shape.ndim(), 1); + for (int i = 0; i < old_shape.ndim(); ++i) { + axes[i] = i; + } + + svd_in_shape = swapMatDims(svd_in_shape, mat_axis); + axes = swapMatDims(axes, mat_axis); + TShape reduce_axes = inverseTranspose(axes); + + int row_dim = svd_in_shape[svd_in_shape.ndim() - 2]; + int col_dim = svd_in_shape[svd_in_shape.ndim() - 1]; + int svd_dim = row_dim <= col_dim ? row_dim : col_dim; + int batch_dim = svd_in_shape.ProdShape(0, svd_in_shape.ndim() - 2); + + TShape L_shape = svd_in_shape; + TShape L_trans = inputs[0].shape_; + if (row_dim > col_dim) { + L_shape[L_shape.ndim() - 2] = 1; + L_trans[mat_axis[0]] = 1; + } else { + L_shape[L_shape.ndim() - 1] = 1; + L_trans[mat_axis[1]] = 1; + } + + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor UT = + outputs[1].get_with_shape(Shape3(batch_dim, row_dim, row_dim), s); + Tensor L = + outputs[2].get_with_shape(Shape2(batch_dim, svd_dim), s); + Tensor V = + outputs[3].get_with_shape(Shape3(batch_dim, row_dim, col_dim), s); + + size_t svd_space = linalg_gesvd_workspace_query(UT[0], L[0], V[0], s); + size_t space = svd_in_shape.Size() + svd_space; + space += space & 1; + size_t offset = svd_in_shape.Size() + (1 & svd_in_shape.Size()); + + TBlob temp = TBlob(ctx.requested[0].get_space_typed( + Shape1(space), s)); + TBlob workspace(reinterpret_cast(temp.dptr_), svd_in_shape, + temp.dev_mask(), temp.dev_id()); + TBlob svd_workspace(reinterpret_cast(temp.dptr_) + offset, TShape(1, svd_space), + temp.dev_mask(), temp.dev_id()); + TransposeImpl(ctx.run_ctx, inputs[0], workspace, axes); + Tensor svd_input = + workspace.get_with_shape(Shape3(batch_dim, row_dim, col_dim), s); + gesvd::op(svd_input, UT, L, V, ctx, attrs, &svd_workspace); + + TBlob workspace0(reinterpret_cast(temp.dptr_), L_trans, + temp.dev_mask(), temp.dev_id()); + TransposeImpl(ctx.run_ctx, TBlob(L).reshape(L_shape), workspace0, reduce_axes); + std::vector eigen({ workspace0 }); + if (param.flag == 2) { // nuclear norm + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } + } else { + if (ord == 2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } else if (ord == -2) { + ReduceAxesComputeImpl( + ctx, eigen, req, outputs, reduced_shape); + } + } + }); +} + +template +void NumpyMatrixNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + + Stream *s = ctx.get_stream(); + if (req[0] == kNullOp) return; + + const NumpyNormParam& param = nnvm::get(attrs.parsed); + + TShape reduced_shape; + TShape old_shape_ = inputs[0].shape_; + if (param.keepdims) { + reduced_shape = inputs[0].shape_; + } else { + 
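+    // keepdims was false in the forward pass, so rebuild the kept-dims reduced shape from the
+    // full input shape (outputs[0] here is the gradient w.r.t. the input) and the norm axes.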
reduced_shape = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false); + } + + std::vector map_inputs; + std::vector map_outputs; + + if (param.flag == 1) { // frob norm + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + map_inputs = std::vector({inputs[0], inputs[4], inputs[5]}); + if (req[0] == kAddTo) { + TBlob workspace = TBlob(ctx.requested[0].get_space_typed( + Shape1(outputs[0].shape_.Size()), s)); + std::vector temp({workspace.reshape(outputs[0].shape_)}); + ReduceAxesBackwardUseInOutImpl( + ctx, reduced_shape, map_inputs, req, temp); + Tensor out = outputs[0].FlatTo1D(s); + out += workspace.FlatTo1D(s); + } else { + ReduceAxesBackwardUseInOutImpl( + ctx, reduced_shape, map_inputs, req, outputs); + } + }); + return; + } + + TShape mat_axis = param.axis.value(); + + if (param.ord != 2 && param.ord != -2) { // row norm or col norm + TShape sum_shape = outputs[0].shape_; + TShape out_shape = outputs[0].shape_; + int sum_dim = mat_axis[!(param.ord == 1 || param.ord == -1)]; + sum_shape[sum_dim] = 1; + TShape small(3, 1), squeezed(3, outputs[0].shape_[sum_dim]); + squeezed[0] = small[0] = sum_shape.ProdShape(0, sum_dim); + squeezed[2] = small[2] = sum_shape.ProdShape(sum_dim + 1, sum_shape.ndim()); + map_inputs = std::vector({ inputs[0], inputs[6], inputs[5] }); + + size_t sum_size = sum_shape.Size(); + size_t ws_offset = sum_size + (sum_size & 1); + size_t ws_size = ws_offset + (req[0] == kAddTo ? outputs[0].shape_.Size() : 0); + ws_size += ws_size & 1; + + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + TBlob workspace = TBlob(ctx.requested[0].get_space_typed( + Shape1(ws_size), s)); + TBlob temp0 = TBlob(reinterpret_cast(workspace.dptr_), + sum_shape, workspace.dev_mask(), workspace.dev_id()); + std::vector map_outputs({ temp0 }); + ReduceAxesBackwardUseInOutImpl( + ctx, reduced_shape, map_inputs, req, map_outputs); + temp0 = temp0.reshape(small); + map_outputs = std::vector({temp0, inputs[4], inputs[6]}); + if (req[0] == kAddTo) { + TBlob out_temp = TBlob(reinterpret_cast(workspace.dptr_) + ws_offset, + outputs[0].shape_, workspace.dev_mask(), workspace.dev_id()); + std::vector tmp_outputs({ out_temp }); + ReduceAxesBackwardUseInOutImpl( + ctx, sum_shape, map_outputs, req, tmp_outputs); + out_temp = out_temp.reshape(squeezed); + Tensor tmp_out = + out_temp.get_with_shape(Shape3(squeezed[0], squeezed[1], squeezed[2]), s); + Tensor mask = + temp0.get_with_shape(Shape3(small[0], small[1], small[2]), s); + tmp_out = tmp_out * broadcast_to(mask, squeezed); + TBlob final_output = outputs[0].reshape(squeezed); + Tensor out = + final_output.get_with_shape( + Shape3(squeezed[0], squeezed[1], squeezed[2]), s); + out += tmp_out; + } else { + ReduceAxesBackwardUseInOutImpl( + ctx, sum_shape, map_outputs, req, outputs); + TBlob final_output = outputs[0].reshape(squeezed); + Tensor out = + final_output.get_with_shape( + Shape3(squeezed[0], squeezed[1], squeezed[2]), s); + Tensor mask = + temp0.get_with_shape(Shape3(small[0], small[1], small[2]), s); + out = out * broadcast_to(mask, squeezed); + } + }); + return; + } + + if (!param.keepdims) { + const_cast&>(inputs)[0] = inputs[0].reshape(reduced_shape); + const_cast&>(inputs)[5] = inputs[5].reshape(reduced_shape); + } + + map_inputs.push_back(inputs[0]); + TBlob L_reduced = inputs[5]; + TBlob L_irreduced = inputs[7]; + + TShape old_shape = inputs[4].shape_; + TShape svd_in_shape = old_shape; + TShape axes(old_shape.ndim(), 1); + for (int i = 0; i < old_shape.ndim(); ++i) { + axes[i] = i; + } + svd_in_shape = 
swapMatDims(svd_in_shape, mat_axis); + axes = swapMatDims(axes, mat_axis); + TShape reduce_axes = inverseTranspose(axes); + + int row_dim = svd_in_shape[svd_in_shape.ndim() - 2]; + int col_dim = svd_in_shape[svd_in_shape.ndim() - 1]; + int batch_dim = svd_in_shape.ProdShape(0, svd_in_shape.ndim() - 2); + + TShape L_shape = svd_in_shape; + TShape L_trans = old_shape; + if (row_dim > col_dim) { + L_shape[L_shape.ndim() - 2] = 1; + L_trans[mat_axis[0]] = 1; + } else { + L_shape[L_shape.ndim() - 1] = 1; + L_trans[mat_axis[1]] = 1; + } + L_irreduced = L_irreduced.reshape(L_shape); + int kmn = outputs[0].shape_.Size(); + int kmm = inputs[1].shape_.Size(); + int km = inputs[2].shape_.Size(); + size_t workspace_size = svd_in_shape.ProdShape(0, svd_in_shape.ndim()) * 2 + + km + kmn + 5; + workspace_size += req[0] == kAddTo? kmn : kmm; + size_t workspace_offset1 = svd_in_shape.ProdShape(0, svd_in_shape.ndim()); + workspace_offset1 += workspace_offset1 & 1; + size_t workspace_offset2 = workspace_offset1 * 2; + size_t workspace_offset3 = workspace_offset2; + if (req[0] == kAddTo) { + workspace_offset3 += kmn + (kmn & 1); + } else { + workspace_offset3 += kmm + (kmm & 1); + } + size_t workspace_offset4 = workspace_offset3 + km + (km & 1); + + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + TBlob workspace = TBlob(ctx.requested[0].get_space_typed( + Shape1(workspace_size), s)); + TBlob workspace0(reinterpret_cast(workspace.dptr_), L_trans, + workspace.dev_mask(), workspace.dev_id()); + TBlob workspace1(reinterpret_cast(workspace.dptr_) + workspace_offset1, L_trans, + workspace.dev_mask(), workspace.dev_id()); + TBlob tempM(reinterpret_cast(workspace.dptr_) + workspace_offset2, inputs[1].shape_, + workspace.dev_mask(), workspace.dev_id()); + TBlob tempMd(reinterpret_cast(workspace.dptr_) + workspace_offset3, inputs[2].shape_, + workspace.dev_mask(), workspace.dev_id()); + TBlob temp(reinterpret_cast(workspace.dptr_) + workspace_offset4, inputs[3].shape_, + workspace.dev_mask(), workspace.dev_id()); + TransposeImpl(ctx.run_ctx, L_irreduced.reshape(L_shape), workspace0, reduce_axes); + map_inputs.push_back(workspace0); + map_inputs.push_back(L_reduced); + if (param.flag == 2) { // nuclear norm + mxnet::op::Fill(s, workspace1, req[0], DType(1.0)); + } else { + std::vector reduce_output({ workspace1 }); + ReduceAxesBackwardUseInOutImpl( + ctx, reduced_shape, map_inputs, req, reduce_output); + } + workspace1 = workspace1.reshape(L_shape); + gesvd_backward::op(inputs[1].FlatToKD(s), + workspace1.reshape(inputs[2].shape_).FlatToKD(s), + inputs[3].FlatToKD(s), + inputs[6].FlatToKD(s), + inputs[7].FlatToKD(s), + inputs[8].FlatToKD(s), + temp.get_with_shape(Shape3(batch_dim, row_dim, col_dim)), + tempM.FlatToKD(s), + tempMd.FlatToKD(s), + s, attrs); + Tensor temp_flat = temp.FlatToKD(s); + TBlob in_grad_trans(reinterpret_cast(workspace0.dptr_), + swapMatDims(inputs[0].shape_, mat_axis), + workspace.dev_mask(), workspace.dev_id()); + TransposeImpl(ctx.run_ctx, inputs[0], in_grad_trans, axes); + Tensor trans_in_grad = in_grad_trans.FlatToKD(s); + temp_flat = temp_flat * broadcast_to(trans_in_grad, temp.shape_); + if (req[0] == kAddTo) { + TBlob ograd(reinterpret_cast(tempM.dptr_), outputs[0].shape_, + workspace.dev_mask(), workspace.dev_id()); + TransposeImpl(ctx.run_ctx, temp.reshape(svd_in_shape), ograd, reduce_axes); + Tensor out = outputs[0].FlatTo1D(s); + out += ograd.FlatTo1D(s); + } else { + TransposeImpl(ctx.run_ctx, temp.reshape(svd_in_shape), outputs[0], reduce_axes); + } + }); + if 
(!param.keepdims) {
+    const_cast<std::vector<TBlob>&>(inputs)[0] = inputs[0].reshape(old_shape_);
+    const_cast<std::vector<TBlob>&>(inputs)[5] = inputs[5].reshape(old_shape_);
+  }
+}
+
+template<typename xpu>
+void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (inputs[0].shape_.Size() == 0U) {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      mxnet::op::Fill(s, outputs[0], req[0], DType(0.0));
+    });
+    return;
+  }
+  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
+
+  if (param.flag == -2) {  // flattened L2 norm
+    std::vector<TBlob> flat_inputs({
+      inputs[0].reshape(TShape(1, inputs[0].shape_.Size()))
+    });
+    std::vector<TBlob> flat_outputs({
+      outputs[0].reshape(TShape(1, 1))
+    });
+    ReduceAxesComputeImplWithReducer(
+        ctx, flat_inputs, req, flat_outputs, TShape(1, 1));
+    return;
+  }
+
+  if (param.axis.value().ndim() == 2) {
+    NumpyMatrixNormCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  } else {
+    NumpyLpNormCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  }
+}
+
+template<typename xpu>
+void NumpyNormComputeBackward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (inputs[0].shape_.Size() == 0U) {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      mxnet::op::Fill(s, outputs[0], req[0], DType(0.0));
+    });
+    return;
+  }
+  if (!common::is_float(outputs[0].type_flag_)) {
+    LOG(FATAL) << "Computing gradient for integer inputs is undefined behavior.";
+  }
+  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
+
+  if (param.flag == -2) {  // flattened L2 norm
+    std::vector<TBlob> flat_inputs({
+      inputs[0].reshape(TShape(1, 1)),
+      inputs[4].reshape(TShape(1, outputs[0].shape_.Size())),
+      inputs[5].reshape(TShape(1, 1))
+    });
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      if (req[0] == kAddTo) {
+        TBlob workspace = TBlob(ctx.requested[0].get_space_typed<xpu, 1, DType>(
+            Shape1(outputs[0].shape_.Size()), s));
+        std::vector<TBlob> temp({ workspace });
+        ReduceAxesBackwardUseInOutImpl(
+            ctx, TShape(1, 1), flat_inputs, req, temp);
+        Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
+        out += workspace.FlatTo1D<xpu, DType>(s);
+      } else {
+        std::vector<TBlob> flat_outputs({
+          outputs[0].reshape(TShape(1, outputs[0].shape_.Size()))
+        });
+        ReduceAxesBackwardUseInOutImpl(
+            ctx, TShape(1, 1), flat_inputs, req, flat_outputs);
+      }
+    });
+    return;
+  }
+
+  // need to infer shape again in backward
+  std::vector<mxnet::TShape> in_attrs({
+    inputs.size() == 9 ? inputs[4].shape_ : inputs[1].shape_
+  });
+  std::vector<mxnet::TShape> out_attrs({
+    inputs.size() == 9 ? inputs[5].shape_ : inputs[2].shape_,
+    TShape(), TShape(), TShape()
+  });
+  NumpyNormShape(attrs, &in_attrs, &out_attrs);
+
+  if (param.axis.value().ndim() == 2) {
+    NumpyMatrixNormGradCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  } else {
+    std::vector<TBlob> grad_inputs({inputs[0], inputs[4], inputs[5]});
+    NumpyLpNormGradCompute<xpu>(attrs, ctx, grad_inputs, req, outputs);
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_NORM_INL_H_
diff --git a/src/operator/numpy/linalg/np_norm.cc b/src/operator/numpy/linalg/np_norm.cc
new file mode 100644
index 000000000000..c284f4a80ea1
--- /dev/null
+++ b/src/operator/numpy/linalg/np_norm.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_norm.cc
+ * \brief CPU registration of np.linalg.norm
+ */
+
+#include "./np_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyNormParam);
+
+inline bool NumpyLpNormShape(const nnvm::NodeAttrs& attrs,
+                             mxnet::ShapeVector *in_attrs,
+                             mxnet::ShapeVector *out_attrs) {
+  if (!shape_is_known((*in_attrs)[0])) return false;
+  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
+  const int ndim = (*in_attrs)[0].ndim();
+  if ((!param.axis.has_value() && param.flag != 0 && ndim > 2) ||
+      (param.axis.has_value() && param.axis.value().ndim() > 2))
+    LOG(FATAL) << "Improper number of dimensions to norm.";
+  if (!param.axis.has_value()) {
+    if ((ndim == 0 && param.flag != 0) ||  // for scalar
+        (ndim == 1 && (param.flag == 2)) ||
+        (ndim >= 2 && (param.ord == 0 || param.ord > 2 || param.ord < -2))) {
+      LOG(FATAL) << "Invalid norm order for inputs.";
+    }
+  } else {
+    if ((param.axis.value().ndim() == 0 && param.flag != 0) ||  // for scalar
+        (param.axis.value().ndim() == 1 && (param.flag == 2)) ||
+        (param.axis.value().ndim() == 2 && (param.ord == 0 || param.ord > 2 || param.ord < -2))) {
+      LOG(FATAL) << "Invalid norm order for inputs.";
+    }
+  }
+  if (!param.keepdims && (*in_attrs)[0].ndim() == 1) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1));
+  } else {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                       ReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims, false));
+  }
+  return true;
+}
+
+inline bool NumpyMatrixNormShape(const nnvm::NodeAttrs& attrs,
+                                 mxnet::ShapeVector *in_attrs,
+                                 mxnet::ShapeVector *out_attrs) {
+  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
+  const int ndim = (*in_attrs)[0].ndim();
+  auto shape = swapMatDims((*in_attrs)[0], param.axis.value());
+  if (param.axis.value().ndim() == 2) {
+    int batch_dim = 1;
+    int row_dim = (*in_attrs)[0][param.axis.value()[0]];
+    int col_dim = (*in_attrs)[0][param.axis.value()[1]];
+    TShape out_shape(ndim - (param.keepdims ? 0 : 2), 1);
+    for (int i = 0; i < ndim - 2; ++i) {
+      batch_dim *= shape[i];
+    }
+    if (param.keepdims) {
+      out_shape = (*in_attrs)[0];
+      out_shape[param.axis.value()[0]] = 1;
+      out_shape[param.axis.value()[1]] = 1;
+    } else {
+      for (int i = 0; i < ndim - 2; ++i) {
+        out_shape[i] = shape[i];
+      }
+    }
+    int svd_dim = row_dim < col_dim ?
row_dim : col_dim; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + if (param.ord == 2 || param.ord == -2) { + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({ batch_dim, row_dim, row_dim })); // UT + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ batch_dim, svd_dim })); // L + SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ batch_dim, row_dim, col_dim })); // V + } else { + TShape sum_shape = (*in_attrs)[0]; + TShape mat_axis = param.axis.value(); + int sum_dim = mat_axis[!(param.ord == 1 || param.ord == -1)]; + TShape small(3, 1); + sum_shape[sum_dim] = 1; + small[0] = sum_shape.ProdShape(0, sum_dim); + small[2] = sum_shape.ProdShape(sum_dim + 1, sum_shape.ndim()); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, small); // sum + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ 0, 0 })); // L + SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ 0, 0, 0 })); // V + } + } else { + LOG(FATAL) << "Invalid norm or ord arguments."; + } + return true; +} + +inline void assign_svd_empty(mxnet::ShapeVector *out_attrs) { + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({ 0, 0, 0 })); // UT + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ 0, 0 })); // L + SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ 0, 0, 0 })); // V +} + +bool NumpyNormType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 4U); + int in_type = in_attrs->at(0); + int out_type; + if (!common::is_float(in_type)) { + out_type = in_type; + LOG(WARNING) << "WARNING: Integer input to norm. This will result in integer " + "output which is different from standard NumPy behavior and " + "breaks gradient compute in backward. Please cast the input " + "to floating point types first."; + } else { + out_type = in_type; + } + for (int i = 0; i < 4; ++i) { + TYPE_ASSIGN_CHECK(*out_attrs, i, out_type); + } + return out_attrs->at(0) != -1; +} + +bool NumpyNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 4U); // reduced, UT, S, V + const NumpyNormParam& param = nnvm::get(attrs.parsed); + if (!param.axis.has_value()) { + if (param.flag == -2) { + int ndim = param.keepdims ? (*in_attrs)[0].ndim() : 0; + int sz = param.keepdims ? 1 : -1; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(ndim, sz)); + assign_svd_empty(out_attrs); + return true; + } + if ((*in_attrs)[0].ndim() >= 2) { + TShape axis(2, 0); + axis[0] = (*in_attrs)[0].ndim() - 2; + axis[1] = (*in_attrs)[0].ndim() - 1; + const_cast(param).axis = axis; + return NumpyMatrixNormShape(attrs, in_attrs, out_attrs); + } else { + TShape axis(1, (*in_attrs)[0].ndim() - 1); + const_cast(param).axis = axis; + assign_svd_empty(out_attrs); + return NumpyLpNormShape(attrs, in_attrs, out_attrs); + } + } else { + TShape axis(param.axis.value().ndim(), 0); + for (int i = 0; i < param.axis.value().ndim(); ++i) { + axis[i] = param.axis.value()[i] < 0 ? 
+ (*in_attrs)[0].ndim() + param.axis.value()[i] : + param.axis.value()[i]; + } + const_cast(param).axis = axis; + if (param.axis.value().ndim() == 2) { + return NumpyMatrixNormShape(attrs, in_attrs, out_attrs); + } else { + assign_svd_empty(out_attrs); + return NumpyLpNormShape(attrs, in_attrs, out_attrs); + } + } +} + +TShape swapMatDims(const TShape &shape, const TShape &axis) { + TShape ret(shape.ndim(), 1); + int i, j = 0; + for (i = 0; i < shape.ndim(); ++i) { + if (i != axis[0] && i != axis[1]) { + ret[j++] = shape[i]; + } + } + ret[j++] = shape[axis[0]]; + ret[j] = shape[axis[1]]; + return ret; +} + +TShape inverseTranspose(const TShape &axes) { + TShape ret(axes.ndim(), 1); + for (int i = 0; i < axes.ndim(); ++i) { + ret[axes[i]] = i; + } + return ret; +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_norm_backward.cc b/src/operator/numpy/linalg/np_norm_backward.cc new file mode 100644 index 000000000000..f6db6864b5db --- /dev/null +++ b/src/operator/numpy/linalg/np_norm_backward.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_norm_backward.cc + * \brief CPU registration of np.linalg.norm + */ + +#include "./np_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_backward_npi_norm) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_num_inputs(2 * 4 + 1) +.set_attr("FCompute", NumpyNormComputeBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_norm_backward.cu b/src/operator/numpy/linalg/np_norm_backward.cu new file mode 100644 index 000000000000..09e85ab36f19 --- /dev/null +++ b/src/operator/numpy/linalg/np_norm_backward.cu @@ -0,0 +1,33 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +/*! 
+* Copyright (c) 2019 by Contributors +* \file np_norm_backward.cu +* \brief GPU implementation of Operators for advanced linear algebra. +*/ +#include "./np_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_backward_npi_norm) +.set_attr("FCompute", NumpyNormComputeBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_norm_forward.cc b/src/operator/numpy/linalg/np_norm_forward.cc new file mode 100644 index 000000000000..2c8f27638206 --- /dev/null +++ b/src/operator/numpy/linalg/np_norm_forward.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_norm_forward.cc + * \brief CPU registration of np.linalg.norm + */ + +#include "./np_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_norm) +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(4) +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { return 1; }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyNormShape) +.set_attr("FInferType", NumpyNormType) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_npi_norm"}) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", NumpyNormComputeForward) +.add_argument("data", "NDArray-or-Symbol", "The input"); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_norm_forward.cu b/src/operator/numpy/linalg/np_norm_forward.cu new file mode 100644 index 000000000000..6feecb09a09e --- /dev/null +++ b/src/operator/numpy/linalg/np_norm_forward.cu @@ -0,0 +1,33 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +/*! +* Copyright (c) 2019 by Contributors +* \file np_norm_forward.cu +* \brief GPU implementation of Operators for advanced linear algebra. 
+*/ +#include "./np_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_norm) +.set_attr("FCompute", NumpyNormComputeForward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 704f35949bd0..24a4da984414 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -367,12 +367,9 @@ def _add_workload_transpose(): def _add_workload_linalg_norm(): OpArgMngr.add_workload('linalg.norm', np.random.uniform(size=(4, 1))) - for dt in ["double", "float32", "int64"]: + for dt in ["float64", "float32"]: OpArgMngr.add_workload('linalg.norm', np.array([], dtype=dt)) OpArgMngr.add_workload('linalg.norm', np.array([np.array([]), np.array([])], dtype=dt)) - # numerical error exceed the tolerance - if dt == "int64": - continue for v in ([1, 2, 3, 4], [-1, -2, -3, -4], [-1, 2, -3, 4]): OpArgMngr.add_workload('linalg.norm', np.array(v, dtype=dt)) A = np.array([[1, 2, 3], [4, 5, 6]], dtype=dt) @@ -401,7 +398,7 @@ def _add_workload_linalg_norm(): OpArgMngr.add_workload('linalg.norm', np.take(B[:], np.array(k), axis=k_index).T) A = np.arange(1, 25, dtype=dt).reshape(2, 3, 4) OpArgMngr.add_workload('linalg.norm', A, ord=None, axis=None) - OpArgMngr.add_workload('linalg.norm', A, ord=None,axis=None, keepdims=True) + OpArgMngr.add_workload('linalg.norm', A, ord=None, axis=None, keepdims=True) for k in range(A.ndim): OpArgMngr.add_workload('linalg.norm', A, axis=k) OpArgMngr.add_workload('linalg.norm', A, axis=k, keepdims=True) @@ -410,12 +407,15 @@ def _add_workload_linalg_norm(): OpArgMngr.add_workload('linalg.norm', A, axis=k, keepdims=True) OpArgMngr.add_workload('linalg.norm', np.array([[]], dtype=dt)) A = np.array([[1, 3], [5, 7]], dtype=dt) - OpArgMngr.add_workload('linalg.norm', A) - OpArgMngr.add_workload('linalg.norm', A, 'fro') + OpArgMngr.add_workload('linalg.norm', A, 2) + OpArgMngr.add_workload('linalg.norm', A, -2) + OpArgMngr.add_workload('linalg.norm', A, 'nuc') A = (1 / 10) * np.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=dt) OpArgMngr.add_workload('linalg.norm', A) OpArgMngr.add_workload('linalg.norm', A, 'fro') - for dt in [np.float16, np.float32, np.float64]: + OpArgMngr.add_workload('linalg.norm', A, 1) + OpArgMngr.add_workload('linalg.norm', A, -1) + for dt in [np.float32, np.float64]: OpArgMngr.add_workload('linalg.norm', np.array([[1, 0, 1], [0, 1, 1]], dtype=dt)) OpArgMngr.add_workload('linalg.norm', np.array([[1, 0, 1], [0, 1, 1]], dtype=dt), 'fro') diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c4e3b9a9c2df..f8d6817ea444 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3894,7 +3894,6 @@ def hybrid_forward(self, F, x): @with_seed() @use_np def test_np_linalg_norm(): - @use_np class TestLinalgNorm(HybridBlock): def __init__(self, ord=None, axis=None, keepdims=False): super(TestLinalgNorm, self).__init__() @@ -3905,21 +3904,110 @@ def __init__(self, ord=None, axis=None, keepdims=False): def hybrid_forward(self, F, x): return F.np.linalg.norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims) - a = np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8)) - ords = [None, 'fro'] - axes = [None, (0, 2), (1, 0), (1, 2)] - for ord in ords: - for axis in axes: - if ord == 'fro' and axis is None and a.ndim > 2: - continue - for keepdims in [False, True]: - for hybridize in [False, True]: - 
net = TestLinalgNorm(ord, axis, keepdims) - if hybridize: - net.hybridize() - mx_ret = net(a) - np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims) - assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4) + configs = [ + ((2, 3, 4), 1, (2, 1)), + ((2, 3, 4), 2, (1, 2)), + ((2, 3, 4), None, None), + ((3,), None, None), + ((2, 3), 2, 1), + ((2, 3, 4), 1, 1), + ((2, 3, 4), -1, 2), + ((2, 3, 4), 2, 1), + ((2, 3, 4), 4, 1), + ((2, 3, 0, 4), -2, 1), + ((2, 3, 4, 5), 2, (2, 3)), + ((2, 3), -1, None), + ((2, 3, 4), 'inf', 1), + ((2, 3, 4), '-inf', (1, 0)), + ((2, 3), None, (0, 1)), + ((3, 2, 3), None, (1, 2)), + ((2, 3), None, None), + ((2, 3, 4), 'fro', (0, 2)), + ((2, 0, 4), 'fro', (0, 2)), + ((2, 3, 4), None, (0, 2)), + ((2, 3, 4), -3.2, 2), + ((2, 3, 4), -1, (0, 1)), + ((2, 3, 4), 'inf', (0, 2)), + ((2, 3, 4), '-inf', (0, 2)), + ((4, 4, 4, 4), -2, (0, 2)), + ((2, 3, 4), 'nuc', (0, 2)), + ((2, 2), 'nuc', None), + ] + + def spectral_norm_grad(data): + with mx.autograd.record(): + UT, S, V = np.linalg.svd(data) + norm = np.max(np.abs(S), axis=-1) + norm.backward() + return data.grad.asnumpy() + + # numpy is flaky under float16, also gesvd does not support fp16 + dtypes = [np.float32, np.float64] + for hybridize, itype, (shape, ord, axis), keepdims in \ + itertools.product([True, False], dtypes, configs, [True, False]): + net = TestLinalgNorm(ord, axis, keepdims) + rtol = 1e-2 + atol = 1e-2 + if hybridize: + net.hybridize() + a = mx.nd.random.uniform(-10.0, 10.0, shape=shape, dtype=itype).as_np_ndarray() + a.attach_grad() + with mx.autograd.record(): + mx_ret = net(a) + if ord == 'inf': + np_ret = _np.linalg.norm(a.asnumpy(), ord=_np.inf, axis=axis, keepdims=keepdims) + elif ord == '-inf': + np_ret = _np.linalg.norm(a.asnumpy(), ord=-_np.inf, axis=axis, keepdims=keepdims) + else: + np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims) + + assert np_ret.shape == mx_ret.shape + assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol) + + mx_ret.backward() + + grad_axis = axis + if axis is None and len(shape) >= 2 and ord is not None: + grad_axis = (len(shape) - 2, len(shape) - 1) + elif axis is None and ord is None: + grad_axis = tuple([i for i in range(len(shape))]) + elif axis is None: + grad_axis = len(shape) - 1 + + if not keepdims and isinstance(grad_axis, tuple): + if len(grad_axis) == 2 and grad_axis[0] > grad_axis[1] and grad_axis[0] > len(np_ret.shape): + grad_axis = (grad_axis[1], grad_axis[0]) + for i in grad_axis: + np_ret = _np.expand_dims(np_ret, axis=i) + elif not keepdims: + np_ret = _np.expand_dims(np_ret, axis=grad_axis) + + if ord == 4: + backward_expected = _np.sign(a.asnumpy()) * _np.power(_np.abs(a.asnumpy()) / np_ret, ord - 1) + assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + + if ord == 2 and not isinstance(grad_axis, tuple): + backward_expected = _np.divide(a.asnumpy(), np_ret) + assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + elif ord == 2 and isinstance(grad_axis, tuple): + backward_expected = spectral_norm_grad(a) + assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + + if ord == 'fro': + backward_expected = _np.divide(a.asnumpy(), np_ret) + assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + + assert a.grad.shape == a.shape + + # Test imperative once again + if ord == 'inf': + np_ret = _np.linalg.norm(a.asnumpy(), ord=_np.inf, axis=axis, keepdims=keepdims) + elif ord == '-inf': + 
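+            # MXNet's np.linalg.norm accepts the strings 'inf'/'-inf' for ord; translate them
+            # to NumPy's float infinities when computing the reference result.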
np_ret = _np.linalg.norm(a.asnumpy(), ord=-_np.inf, axis=axis, keepdims=keepdims) + else: + np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims) + mx_ret = np.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) + assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol) @with_seed()