From cebd8c04be166edc401497629722ff653d1509c7 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 22 May 2024 16:34:54 -0400 Subject: [PATCH 1/4] [ADD] Implement adjoint for sub-matrix linear operator --- curvlinops/submatrix.py | 12 +++++++++++ test/test_submatrix.py | 47 +++++++++++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/curvlinops/submatrix.py b/curvlinops/submatrix.py index 7f3e260e..3fda4a8e 100644 --- a/curvlinops/submatrix.py +++ b/curvlinops/submatrix.py @@ -1,5 +1,7 @@ """Implements slices of linear operators.""" +from __future__ import annotations + from typing import List from numpy import column_stack, ndarray, zeros @@ -78,3 +80,13 @@ def _matmat(self, X: ndarray) -> ndarray: ``A[row_idxs, :][:, col_idxs] @ x``. Has shape ``[len(row_idxs), N]``. """ return column_stack([self @ col for col in X.T]) + + def _adjoint(self) -> SubmatrixLinearOperator: + """Return the adjoint of the sub-matrix. + + For that, we need to take the adjoint operator, and swap row and column indices. + + Returns: + The linear operator for the adjoint sub-matrix. + """ + return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs) diff --git a/test/test_submatrix.py b/test/test_submatrix.py index 216efc36..dadeb4ba 100644 --- a/test/test_submatrix.py +++ b/test/test_submatrix.py @@ -3,7 +3,7 @@ from typing import List, Tuple from numpy import eye, ndarray, random -from pytest import fixture, raises +from pytest import fixture, mark, raises from scipy.sparse.linalg import aslinearoperator from curvlinops.examples.utils import report_nonclose @@ -34,29 +34,62 @@ def submatrix_case(request) -> Tuple[ndarray, List[int], List[int]]: return case["A_fn"](), case["row_idxs_fn"](), case["col_idxs_fn"]() -def test_SubmatrixLinearOperator__matvec(submatrix_case): +@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"]) +def test_SubmatrixLinearOperator__matvec( + submatrix_case: Tuple[ndarray, List[int], List[int]], adjoint: bool +): + """Test the matrix-vector multiplication of a submatrix linear operator. + + Args: + submatrix_case: A tuple with a random matrix and two index lists. + adjoint: Whether to take the operator's adjoint before multiplying. + """ A, row_idxs, col_idxs = submatrix_case A_sub = A[row_idxs, :][:, col_idxs] A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs) - x = random.rand(len(col_idxs)) + if adjoint: + A_sub = A_sub.conj().T + A_sub_linop = A_sub_linop.adjoint() + + x = random.rand(A_sub.shape[1]) A_sub_linop_x = A_sub_linop @ x - assert A_sub_linop_x.shape == (len(row_idxs),) + assert (len(col_idxs),) if adjoint else A_sub_linop_x.shape == (len(row_idxs),) report_nonclose(A_sub @ x, A_sub_linop_x) -def test_SubmatrixLinearOperator__matmat(submatrix_case, num_vecs: int = 3): +@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"]) +def test_SubmatrixLinearOperator__matmat( + submatrix_case: Tuple[ndarray, List[int], List[int]], + adjoint: bool, + num_vecs: int = 3, +): + """Test the matrix-matrix multiplication of a submatrix linear operator. + + Args: + submatrix_case: A tuple with a random matrix and two index lists. + adjoint: Whether to take the operator's adjoint before multiplying. + num_vecs: The number of vectors to multiply. Default: ``3``. + """ A, row_idxs, col_idxs = submatrix_case A_sub = A[row_idxs, :][:, col_idxs] A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs) - X = random.rand(len(col_idxs), num_vecs) + if adjoint: + A_sub = A_sub.conj().T + A_sub_linop = A_sub_linop.adjoint() + + X = random.rand(A_sub.shape[1], num_vecs) A_sub_linop_X = A_sub_linop @ X - assert A_sub_linop_X.shape == (len(row_idxs), num_vecs) + assert ( + (len(col_idxs), num_vecs) + if adjoint + else A_sub_linop_X.shape == (len(row_idxs), num_vecs) + ) report_nonclose(A_sub @ X, A_sub_linop_X) From 8b776df643c92b7e6d80a151cd453fede861606b Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 22 May 2024 16:42:30 -0400 Subject: [PATCH 2/4] [FIX] assert statements --- test/test_submatrix.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/test_submatrix.py b/test/test_submatrix.py index dadeb4ba..7532905b 100644 --- a/test/test_submatrix.py +++ b/test/test_submatrix.py @@ -56,7 +56,7 @@ def test_SubmatrixLinearOperator__matvec( x = random.rand(A_sub.shape[1]) A_sub_linop_x = A_sub_linop @ x - assert (len(col_idxs),) if adjoint else A_sub_linop_x.shape == (len(row_idxs),) + assert A_sub_linop_x.shape == ((len(col_idxs),) if adjoint else (len(row_idxs),)) report_nonclose(A_sub @ x, A_sub_linop_x) @@ -85,10 +85,8 @@ def test_SubmatrixLinearOperator__matmat( X = random.rand(A_sub.shape[1], num_vecs) A_sub_linop_X = A_sub_linop @ X - assert ( - (len(col_idxs), num_vecs) - if adjoint - else A_sub_linop_X.shape == (len(row_idxs), num_vecs) + assert A_sub_linop_X.shape == ( + (len(col_idxs), num_vecs) if adjoint else (len(row_idxs), num_vecs) ) report_nonclose(A_sub @ X, A_sub_linop_X) From ba51fbb9d5e28c24a4109b66bce7b022b5360918 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 22 May 2024 17:03:20 -0400 Subject: [PATCH 3/4] [REQ] Try adding `setuptools` to requirements to fix RTD build --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 3527794b..ece6a8ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,6 +75,7 @@ lint = # Dependencies needed to build/view the documentation (semicolon/line-separated) docs = + setuptools transformers datasets matplotlib From 17ccf6dfc7d0889b7be6b82a5c27c060a1ea2922 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 22 May 2024 17:15:38 -0400 Subject: [PATCH 4/4] [REQ] Try using `setuptools<70` to fix RTD build --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index ece6a8ec..75924ff2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,7 +75,7 @@ lint = # Dependencies needed to build/view the documentation (semicolon/line-separated) docs = - setuptools + setuptools==69.5.1 # RTD fails with setuptools>=70, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15863 transformers datasets matplotlib