Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[NumPy] NumPy support for linalg.inv #16730

Merged
merged 4 commits into from
Nov 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky']
__all__ = ['norm', 'svd', 'cholesky', 'inv']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -200,3 +200,45 @@ def cholesky(a):
[ 4., 10.]])
"""
return _npi.cholesky(a)


def inv(a):
r"""
Compute the (multiplicative) inverse of a matrix.

Given a square matrix `a`, return the matrix `ainv` satisfying
``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``.

Parameters
----------
a : (..., M, M) ndarray
Matrix to be inverted.

Returns
-------
ainv : (..., M, M) ndarray
(Multiplicative) inverse of the matrix `a`.

Raises
------
MXNetError
If `a` is not square or inversion fails.

Examples
--------
>>> from mxnet import np
>>> a = np.array([[1., 2.], [3., 4.]])
array([[-2. , 1. ],
[ 1.5, -0.5]])

Inverses of several matrices can be computed at once:

>>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]])
>>> np.linalg.inv(a)
array([[[-2. , 1. ],
[ 1.5 , -0.5 ]],

[[-1.2500001 , 0.75000006],
[ 0.75000006, -0.25000003]]])
"""
return _npi.inv(a)
44 changes: 43 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

__all__ = ['norm', 'svd', 'cholesky']
__all__ = ['norm', 'svd', 'cholesky', 'inv']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -218,3 +218,45 @@ def cholesky(a):
[ 4., 10.]])
"""
return _mx_nd_np.linalg.cholesky(a)


def inv(a):
r"""
Compute the (multiplicative) inverse of a matrix.

Given a square matrix `a`, return the matrix `ainv` satisfying
``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``.

Parameters
----------
a : (..., M, M) ndarray
Matrix to be inverted.

Returns
-------
ainv : (..., M, M) ndarray
(Multiplicative) inverse of the matrix `a`.

Raises
------
MXNetError
If `a` is not square or inversion fails.

Examples
--------
>>> from mxnet import np
>>> a = np.array([[1., 2.], [3., 4.]])
array([[-2. , 1. ],
[ 1.5, -0.5]])

Inverses of several matrices can be computed at once:

>>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]])
>>> np.linalg.inv(a)
array([[[-2. , 1. ],
[ 1.5 , -0.5 ]],

[[-1.2500001 , 0.75000006],
[ 0.75000006, -0.25000003]]])
"""
return _mx_nd_np.linalg.inv(a)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'zeros_like',
'linalg.norm',
'linalg.cholesky',
'linalg.inv',
'trace',
'tril',
'meshgrid',
Expand Down
44 changes: 43 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky']
__all__ = ['norm', 'svd', 'cholesky', 'inv']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -187,3 +187,45 @@ def cholesky(a):
[ 4., 10.]])
"""
return _npi.cholesky(a)


def inv(a):
r"""
Compute the (multiplicative) inverse of a matrix.

Given a square matrix `a`, return the matrix `ainv` satisfying
``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``.

Parameters
----------
a : (..., M, M) ndarray
Matrix to be inverted.

Returns
-------
ainv : (..., M, M) ndarray
(Multiplicative) inverse of the matrix `a`.

Raises
------
MXNetError
If `a` is not square or inversion fails.

Examples
--------
>>> from mxnet import np
>>> a = np.array([[1., 2.], [3., 4.]])
array([[-2. , 1. ],
[ 1.5, -0.5]])

Inverses of several matrices can be computed at once:

>>> a = np.array([[[1., 2.], [3., 4.]], [[1, 3], [3, 5]]])
>>> np.linalg.inv(a)
array([[[-2. , 1. ],
[ 1.5 , -0.5 ]],

[[-1.2500001 , 0.75000006],
[ 0.75000006, -0.25000003]]])
"""
return _npi.inv(a)
3 changes: 2 additions & 1 deletion src/operator/linalg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,8 @@ void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s); \
const Tensor<xpu, 2, int> pivot(reinterpret_cast<int *>(workspace.dptr_), \
Shape2(A.size(0), A.size(1))); \
const Tensor<xpu, 3, DType> LU(reinterpret_cast<DType *>(pivot.dptr_ + pivot.MSize()), \
int offset = pivot.MSize() & 1 ? pivot.MSize() + 1 : pivot.MSize(); \
const Tensor<xpu, 3, DType> LU(reinterpret_cast<DType *>(pivot.dptr_ + offset), \
A.shape_); \
Copy(LU, B, s); \
linalg_batch_getrf(LU, pivot, true, s); \
Expand Down
6 changes: 6 additions & 0 deletions src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ struct inverse {
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
// Since inverse(A) = trans(inverse(trans(A))), so we don't need to transpose
// A even if we are using the col-major version of getrf and getri routines.
if (B.shape_.Size() == 0U) {
return;
}
linalg_batch_inverse(A, B, ctx);
}
};
Expand Down Expand Up @@ -882,6 +885,9 @@ struct inverse_backward {
const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
// Backward of A = inverse(B)
if (dB.shape_.Size() == 0U) {
return;
}
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 3, DType> temp = ctx.requested[0]
.get_space_typed<xpu, 3, DType>(A.shape_, s);
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/la_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ NNVM_REGISTER_OP(_backward_linalg_syevd)

NNVM_REGISTER_OP(_linalg_inverse)
.add_alias("linalg_inverse")
.add_alias("_npi_inv")
.describe(R"code(Compute the inverse of a matrix.
Input is a tensor *A* of dimension *n >= 2*.

Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ def _add_workload_linalg_cholesky():
OpArgMngr.add_workload('linalg.cholesky', np.array(a, dtype=dtype))


def _add_workload_linalg_inv():
OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 0)), dtype=np.float32))
OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 1, 1)), dtype=np.float64))


def _add_workload_trace():
OpArgMngr.add_workload('trace', np.random.uniform(size=(4, 1)))
OpArgMngr.add_workload('trace', np.random.uniform(size=(3, 2)))
Expand Down Expand Up @@ -1216,6 +1221,7 @@ def _prepare_workloads():
_add_workload_zeros_like(array_pool)
_add_workload_linalg_norm()
_add_workload_linalg_cholesky()
_add_workload_linalg_inv()
_add_workload_trace()
_add_workload_tril()
_add_workload_outer()
Expand Down
80 changes: 79 additions & 1 deletion tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,7 +1669,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype)
mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype)
rtol = 1e-2 if ltype is np.float16 or rtype is np.float16 else 1e-3
atol = 1e-4 if ltype is np.float16 or rtype is np.float16 else 1e-5
atol = 1e-3 if ltype is np.float16 or rtype is np.float16 else 1e-5
for hybridize in [True, False]:
if hybridize:
mx_func.hybridize()
Expand Down Expand Up @@ -3046,6 +3046,84 @@ def check_cholesky(L, data_np):
check_cholesky(L, data_np)


@with_seed()
@use_np
def test_np_linalg_inv():
class TestInverse(HybridBlock):
def __init__(self):
super(TestInverse, self).__init__()

def hybrid_forward(self, F, data):
return F.np.linalg.inv(data)

def get_grad(A):
if 0 in A.shape:
return A

dA = _np.ones_like(A)
A_inv = _np.linalg.inv(A)
dA_inv = -_np.matmul(_np.matmul(A_inv, dA), A_inv)
return _np.swapaxes(dA_inv, -1, -2)

def check_inv(A_inv, data_np):
assert A_inv.shape == data_np.shape
# catch error if numpy throws rank < 2
try:
A_expected = _np.linalg.inv(data_np)
except Exception as e:
print(data_np)
print(data_np.shape)
print(e)
else:
assert A_inv.shape == A_expected.shape
assert_almost_equal(A_inv.asnumpy(), A_expected, rtol=rtol, atol=atol)

shapes = [
(0, 0),
(4, 4),
(2, 2),
(1, 1),
(2, 1, 1),
(0, 1, 1),
(6, 1, 1),
(2, 3, 3, 3),
(4, 2, 1, 1),
(0, 5, 3, 3),
(5, 0, 0, 0),
(3, 3, 0, 0),
(3, 5, 5),
]
dtypes = ['float32', 'float64']
for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes):
atol = rtol = 1e-2

test_inv = TestInverse()
if hybridize:
test_inv.hybridize()
# use LU decomposition to generate invertible matrices
if 0 in shape:
data_np = _np.ones(shape)
else:
n = shape[-1]
L = _np.tril(_np.random.uniform(-10., 10., shape))
U = _np.triu(_np.random.uniform(-10., 10., shape))
data_np = _np.matmul(L, U)
data = np.array(data_np, dtype=dtype)
data.attach_grad()
with mx.autograd.record():
A_inv = test_inv(data)

# check cholesky validity
check_inv(A_inv, data_np)
# check backward. backward does not support empty input
mx.autograd.backward(A_inv)
backward_expected = get_grad(data.asnumpy())
assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
# check imperative once again
A_inv = np.linalg.inv(data)
check_inv(A_inv, data_np)


@with_seed()
@use_np
def test_np_vstack():
Expand Down