From 054ceeb5fa6d3153216c6bf3c158ad48acb87192 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 5 Nov 2019 09:59:36 +0000 Subject: [PATCH 1/3] Add NumPy support for inv --- python/mxnet/ndarray/numpy/linalg.py | 44 ++++++++++- python/mxnet/numpy/linalg.py | 44 ++++++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/linalg.py | 44 ++++++++++- src/operator/tensor/la_op-inl.h | 6 ++ src/operator/tensor/la_op.cc | 1 + .../unittest/test_numpy_interoperability.py | 6 ++ tests/python/unittest/test_numpy_op.py | 77 +++++++++++++++++++ 8 files changed, 220 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index b7766ad4700e..d443d68cda0d 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,7 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky'] +__all__ = ['norm', 'svd', 'cholesky', 'inv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -200,3 +200,45 @@ def cholesky(a): [ 4., 10.]]) """ return _npi.cholesky(a) + + +def inv(a): + r""" + Compute the (multiplicative) inverse of a matrix. + + Given a square matrix `a`, return the matrix `ainv` satisfying + ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrix to be inverted. + + Returns + ------- + ainv : (..., M, M) ndarray + (Multiplicative) inverse of the matrix `a`. + + Raises + ------ + MXNetError + If `a` is not square or inversion fails. + + Examples + -------- + >>> from mxnet import np + >>> a = np.array([[1., 2.], [3., 4.]]) + array([[-2. , 1. ], + [ 1.5, -0.5]]) + + Inverses of several matrices can be computed at once: + + >>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]]) + >>> np.linalg.inv(a) + array([[[-2. , 1. ], + [ 1.5 , -0.5 ]], + + [[-1.2500001 , 0.75000006], + [ 0.75000006, -0.25000003]]]) + """ + return _npi.inv(a) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 402171e24aed..b0552ee9f319 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['norm', 'svd', 'cholesky'] +__all__ = ['norm', 'svd', 'cholesky', 'inv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -218,3 +218,45 @@ def cholesky(a): [ 4., 10.]]) """ return _mx_nd_np.linalg.cholesky(a) + + +def inv(a): + r""" + Compute the (multiplicative) inverse of a matrix. + + Given a square matrix `a`, return the matrix `ainv` satisfying + ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrix to be inverted. + + Returns + ------- + ainv : (..., M, M) ndarray + (Multiplicative) inverse of the matrix `a`. + + Raises + ------ + MXNetError + If `a` is not square or inversion fails. + + Examples + -------- + >>> from mxnet import np + >>> a = np.array([[1., 2.], [3., 4.]]) + array([[-2. , 1. ], + [ 1.5, -0.5]]) + + Inverses of several matrices can be computed at once: + + >>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]]) + >>> np.linalg.inv(a) + array([[[-2. , 1. ], + [ 1.5 , -0.5 ]], + + [[-1.2500001 , 0.75000006], + [ 0.75000006, -0.25000003]]]) + """ + return _mx_nd_np.linalg.inv(a) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index ae92b9885217..025982cfc7a5 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -126,6 +126,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'zeros_like', 'linalg.norm', 'linalg.cholesky', + 'linalg.inv', 'trace', 'tril', 'meshgrid', diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index d06e2d9d4aac..3c8118550559 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,7 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky'] +__all__ = ['norm', 'svd', 'cholesky', 'inv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -187,3 +187,45 @@ def cholesky(a): [ 4., 10.]]) """ return _npi.cholesky(a) + + +def inv(a): + r""" + Compute the (multiplicative) inverse of a matrix. + + Given a square matrix `a`, return the matrix `ainv` satisfying + ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrix to be inverted. + + Returns + ------- + ainv : (..., M, M) ndarray + (Multiplicative) inverse of the matrix `a`. + + Raises + ------ + MXNetError + If `a` is not square or inversion fails. + + Examples + -------- + >>> from mxnet import np + >>> a = np.array([[1., 2.], [3., 4.]]) + array([[-2. , 1. ], + [ 1.5, -0.5]]) + + Inverses of several matrices can be computed at once: + + >>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]]) + >>> np.linalg.inv(a) + array([[[-2. , 1. ], + [ 1.5 , -0.5 ]], + + [[-1.2500001 , 0.75000006], + [ 0.75000006, -0.25000003]]]) + """ + return _npi.inv(a) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index c3ed4fef5cd3..35edbbd42565 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -463,6 +463,9 @@ struct inverse { const OpContext& ctx, const nnvm::NodeAttrs& attrs) { // Since inverse(A) = trans(inverse(trans(A))), so we don't need to transpose // A even if we are using the col-major version of getrf and getri routines. + if (B.shape_.Size() == 0U) { + return; + } linalg_batch_inverse(A, B, ctx); } }; @@ -882,6 +885,9 @@ struct inverse_backward { const Tensor& dB, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { // Backward of A = inverse(B) + if (dB.shape_.Size() == 0U) { + return; + } Stream *s = ctx.get_stream(); Tensor temp = ctx.requested[0] .get_space_typed(A.shape_, s); diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 3d0e43251e03..8f9da23711a0 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -893,6 +893,7 @@ NNVM_REGISTER_OP(_backward_linalg_syevd) NNVM_REGISTER_OP(_linalg_inverse) .add_alias("linalg_inverse") +.add_alias("_npi_inv") .describe(R"code(Compute the inverse of a matrix. Input is a tensor *A* of dimension *n >= 2*. diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index d2620cf8551e..6d9c63f9f857 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -273,6 +273,11 @@ def _add_workload_linalg_cholesky(): OpArgMngr.add_workload('linalg.cholesky', np.array(a, dtype=dtype)) +def _add_workload_linalg_inv(): + OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 0)), dtype=np.float32)) + OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) + + def _add_workload_trace(): OpArgMngr.add_workload('trace', np.random.uniform(size=(4, 1))) OpArgMngr.add_workload('trace', np.random.uniform(size=(3, 2))) @@ -1216,6 +1221,7 @@ def _prepare_workloads(): _add_workload_zeros_like(array_pool) _add_workload_linalg_norm() _add_workload_linalg_cholesky() + _add_workload_linalg_inv() _add_workload_trace() _add_workload_tril() _add_workload_outer() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index e6d12da23582..4e4c32eae74e 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2983,6 +2983,83 @@ def check_cholesky(L, data_np): check_cholesky(L, data_np) +@with_seed() +@use_np +def test_np_linalg_inv(): + class TestInverse(HybridBlock): + def __init__(self): + super(TestInverse, self).__init__() + + def hybrid_forward(self, F, data): + return F.np.linalg.inv(data) + + def get_grad(A): + if 0 in A.shape: + return A + + dA = _np.ones_like(A) + A_inv = _np.linalg.inv(A) + dA_inv = -_np.matmul(_np.matmul(A_inv, dA), A_inv) + return _np.swapaxes(dA_inv, -1, -2) + + def check_inv(A_inv, data_np): + assert A_inv.shape == data_np.shape + # catch error if numpy throws rank < 2 + try: + A_expected = _np.linalg.inv(data_np) + except Exception as e: + print(data_np) + print(data_np.shape) + print(e) + else: + assert A_inv.shape == A_expected.shape + assert_almost_equal(A_inv.asnumpy(), A_expected, rtol=rtol, atol=atol) + + shapes = [ + (1, 1), + (2, 2), + (3, 3), + (0, 0), + (5, 5), + (0, 1, 1), + (6, 1, 1), + (2, 3, 3, 3), + (4, 2, 1, 1), + (0, 5, 3, 3), + (5, 0, 0, 0), + (3, 3, 0, 0), + ] + dtypes = ['float32', 'float64'] + for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes): + atol = rtol = 1e-2 + + test_inv = TestInverse() + if hybridize: + test_inv.hybridize() + # use LU decomposition to generate invertible matrices + if 0 in shape: + data_np = _np.ones(shape) + else: + n = shape[-1] + L = _np.tril(_np.random.uniform(-10., 10., shape)) + U = _np.triu(_np.random.uniform(-10., 10., shape)) + data_np = _np.matmul(L, U) + data = np.array(data_np, dtype=dtype) + data.attach_grad() + with mx.autograd.record(): + A_inv = test_inv(data) + + # check cholesky validity + check_inv(A_inv, data_np) + # check backward. backward does not support empty input + mx.autograd.backward(A_inv) + backward_expected = get_grad(data.asnumpy()) + assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + # check imperative once again + A_inv = np.linalg.inv(data) + check_inv(A_inv, data_np) + + @with_seed() @use_np def test_np_vstack(): From 6a14f09ff335c339c51e798deea3d4031972db60 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 6 Nov 2019 10:35:55 +0000 Subject: [PATCH 2/3] fix CUDA float64 memory alignment bug --- src/operator/linalg_impl.h | 3 ++- tests/python/unittest/test_numpy_op.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 97d40099d4fd..d83eb0d08815 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1598,7 +1598,8 @@ void linalg_batch_inverse(const Tensor& A, \ get_space_typed(Shape1(workspace_size), s); \ const Tensor pivot(reinterpret_cast(workspace.dptr_), \ Shape2(A.size(0), A.size(1))); \ - const Tensor LU(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + int offset = pivot.MSize() & 1 ? pivot.MSize() + 1 : pivot.MSize(); \ + const Tensor LU(reinterpret_cast(pivot.dptr_ + offset), \ A.shape_); \ Copy(LU, B, s); \ linalg_batch_getrf(LU, pivot, true, s); \ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 4e4c32eae74e..47965b0086b8 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3016,11 +3016,11 @@ def check_inv(A_inv, data_np): assert_almost_equal(A_inv.asnumpy(), A_expected, rtol=rtol, atol=atol) shapes = [ - (1, 1), - (2, 2), - (3, 3), (0, 0), - (5, 5), + (4, 4), + (2, 2), + (1, 1), + (2, 1, 1), (0, 1, 1), (6, 1, 1), (2, 3, 3, 3), @@ -3028,6 +3028,7 @@ def check_inv(A_inv, data_np): (0, 5, 3, 3), (5, 0, 0, 0), (3, 3, 0, 0), + (3, 5, 5), ] dtypes = ['float32', 'float64'] for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes): From 981eef4cab98a56c972aa4cf8401fa9afa23a500 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 7 Nov 2019 07:18:50 +0000 Subject: [PATCH 3/3] make test_mixed_precision more tolerant --- tests/python/unittest/test_numpy_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2cc015bfd38f..1b710e30390e 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1669,7 +1669,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype) mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype) rtol = 1e-2 if ltype is np.float16 or rtype is np.float16 else 1e-3 - atol = 1e-4 if ltype is np.float16 or rtype is np.float16 else 1e-5 + atol = 1e-3 if ltype is np.float16 or rtype is np.float16 else 1e-5 for hybridize in [True, False]: if hybridize: mx_func.hybridize()