LinearOperator is a PyTorch package for abstracting away the linear algebra routines needed for structured matrices (or operators).
This package is in beta. Currently, most of the functionality only supports positive semi-definite and triangular matrices. Package development TODOs:
- Support PSD operators
- Support triangular operators
- Interface to specify structure (i.e. symmetric, triangular, PSD, etc.)
- Add algebraic routines for symmetric operators
- Add algebraic routines for generic square operators
- Add algebraic routines for generic rectangular operators
- Add sparse operators
To get started, run either
pip install linear_operator
# or
conda install linear_operator -c gpytorch
or see below for more detailed instructions.
Before describing what linear operators are and why they make a useful abstraction, it's easiest to see an example. Let's say you wanted to compute a matrix solve:
If you didn't know anything about the matrix $\boldsymbol A$, the simplest way to accomplish this in code would be:
import torch

A = torch.randn(1000, 1000)
b = torch.randn(1000)
torch.linalg.solve(A, b)  # computes A^{-1} b
While this is easy, the solve routine is $\mathcal O(N^3)$, which gets very slow as $N$ grows large.
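However, let's imagine that we knew that $\boldsymbol A$ was equal to a low-rank matrix plus a diagonal (i.e. $\boldsymbol A = \boldsymbol C \boldsymbol C^\top + \boldsymbol D$ for some skinny matrix $\boldsymbol C$ and some diagonal matrix $\boldsymbol D$). In that case, there is a much faster routine for computing $\boldsymbol A^{-1} \boldsymbol b$ based on the Woodbury matrix identity.
Implementing the efficient solve that exploits $\boldsymbol A$'s low-rank-plus-diagonal structure would look something like this: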
import torch

def low_rank_plus_diagonal_solve(C, d, b):
    # A = C C^T + diag(d)
    # A^{-1} b = D^{-1} b - D^{-1} C (I + C^T D^{-1} C)^{-1} C^T D^{-1} b
    # where D = diag(d); d must be positive so the Cholesky factorization succeeds
    D_inv_b = b / d
    D_inv_C = C / d.unsqueeze(-1)
    eye = torch.eye(C.size(-1))  # k x k identity, where C is N x k
    L = torch.linalg.cholesky(eye + C.mT @ D_inv_C)  # lower-triangular factor of the k x k matrix
    inner = torch.cholesky_solve((C.mT @ D_inv_b).unsqueeze(-1), L, upper=False)
    return D_inv_b - D_inv_C @ inner.squeeze(-1)

C = torch.randn(1000, 20)
d = torch.rand(1000) + 1.0  # positive diagonal
b = torch.randn(1000)
low_rank_plus_diagonal_solve(C, d, b)  # computes A^{-1} b in O(N) time, instead of O(N^3)
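The comments above use the Woodbury matrix identity. Let $k$ be the number of columns of $\boldsymbol C$ (20 in the example). The $k \times k$ matrix inside the inverse is positive definite whenever every entry of $\boldsymbol d$ is positive, which is what makes its Cholesky factorization valid.

The cost of each step is roughly as follows. The solve is therefore linear in $N$ when $k$ is held fixed, compared with $\mathcal O(N^3)$ for the dense solve:

```latex
\boldsymbol A^{-1} \boldsymbol b
= \boldsymbol D^{-1} \boldsymbol b
- \boldsymbol D^{-1} \boldsymbol C
  \left( \boldsymbol I_k + \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C \right)^{-1}
  \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol b,
\qquad
\text{cost} =
\underbrace{\mathcal O(N k^2)}_{\boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C}
+ \underbrace{\mathcal O(k^3)}_{\text{Cholesky}}
+ \underbrace{\mathcal O(N k)}_{\text{remaining products}}
```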
While this is efficient code, it's not ideal for a number of reasons:
- It's a lot more complicated than torch.linalg.solve(A, b).
- There's no object that represents $\boldsymbol A$. To perform any math with $\boldsymbol A$, we have to pass around the matrix C and the vector d.
The LinearOperator package offers the best of both worlds:
import torch
from linear_operator.operators import DiagLinearOperator, LowRankRootLinearOperator

C = torch.randn(1000, 20)
d = torch.rand(1000) + 1.0  # positive diagonal
b = torch.randn(1000)
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d)  # represents C C^T + diag(d)
It provides an interface that lets us treat $\boldsymbol A$ as if it were a generic tensor, using the standard PyTorch API:
torch.linalg.solve(A, b) # computes A^{-1} b efficiently!
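Under the hood, the LinearOperator object keeps track of the algebraic structure of $\boldsymbol A$ (low rank plus diagonal) and determines the most efficient routine to use (for this structure, the Woodbury formula). This way, we get an efficient solve while abstracting away all of the details.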
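Crucially, $\boldsymbol A$ is never explicitly instantiated as a dense matrix, which makes it possible to scale to very large operators without running out of memory: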
# C = torch.randn(10000000, 20)
# d = torch.rand(10000000) + 1.0  # positive diagonal
# b = torch.randn(10000000)
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d) # represents a 10M x 10M matrix!
torch.linalg.solve(A, b) # computes A^{-1} b efficiently!
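For a sense of scale, here is a rough memory estimate. It assumes the default float32 dtype (4 bytes per entry) and compares a dense version of this operator with the factors that actually define it:

```latex
\underbrace{10^7 \times 10^7}_{\text{dense } \boldsymbol A} \times 4\,\text{B} = 4 \times 10^{14}\,\text{B} \approx 400\,\text{TB},
\qquad
\underbrace{(10^7 \times 20 + 10^7)}_{\boldsymbol C \text{ and } \boldsymbol d} \times 4\,\text{B} = 8.4 \times 10^{8}\,\text{B} \approx 840\,\text{MB}
```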
A linear operator is a generalization of a matrix. It is a linear function that is defined by its application to a vector. The most common linear operators are (potentially structured) matrices, where the function applying them to a vector is a (potentially efficient) matrix-vector multiplication routine.
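Here is a small illustration in plain PyTorch, with no LinearOperator involved, of an operator defined only by its action on a vector. Cumulative summation is linear, and its matrix is the lower-triangular matrix of ones. Applying it, however, never requires building that matrix. The function name cumsum_operator is hypothetical.

```python
import torch

def cumsum_operator(v: torch.Tensor) -> torch.Tensor:
    # Hypothetical operator defined only by its action on a vector:
    # (A v)_i = v_0 + v_1 + ... + v_i
    return torch.cumsum(v, dim=-1)

v = torch.tensor([1.0, 2.0, 3.0, 4.0])
w = torch.tensor([0.5, -1.0, 2.0, 0.0])

# Linearity: A(2v + 3w) equals 2 A v + 3 A w
lhs = cumsum_operator(2 * v + 3 * w)
rhs = 2 * cumsum_operator(v) + 3 * cumsum_operator(w)

# The matrix behind this operator: ones on and below the diagonal.
# It is built here only to check the operator, never to apply it.
A_dense = torch.tril(torch.ones(4, 4))

# cumsum_operator(v) has entries 1, 3, 6, 10
```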
In code, a LinearOperator is a class that
- specifies the tensor(s) needed to define the LinearOperator,
- specifies a _matmul function (how the LinearOperator is applied to a vector),
- specifies a _size function (how big the LinearOperator is if it is represented as a matrix, or batch of matrices),
- specifies a _transpose_nonbatch function (the adjoint of the LinearOperator), and
- (optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient structure-exploiting routines exist.
For example:
import torch
from linear_operator.operators import LinearOperator

class DiagLinearOperator(LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector that defines the diagonal of the matrix
        super().__init__(diag)  # register diag as the tensor that defines this operator
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # Diagonal matrices are symmetric

    # this function is optional, but it will accelerate computation
    def logdet(self):
        return self.diag.log().sum(dim=-1)
    # ...

D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
# [[1., 0., 0.],
#  [0., 2., 0.],
#  [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]
While _matmul, _size, and _transpose_nonbatch might seem like a limited set of functions, it turns out that most functions in the torch and torch.linalg namespaces can be efficiently implemented using only these three primitive functions.
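Here is a rough illustration of that claim: naive generic routines written only against the three primitives. TinyOperator and the helper functions are hypothetical stand-ins, not the package's actual implementations. A structured subclass would typically override routines such as diagonal with something cheaper than densifying.

```python
import torch

class TinyOperator:
    """Hypothetical stand-in for a LinearOperator subclass: it only exposes the three primitives."""

    def __init__(self, matrix: torch.Tensor):
        self._m = matrix  # private; the generic helpers below never read it directly

    def _matmul(self, rhs):
        return self._m @ rhs

    def _size(self):
        return self._m.shape

    def _transpose_nonbatch(self):
        return TinyOperator(self._m.mT)


def to_dense(op):
    # Apply the operator to the identity: column j of the result is A e_j
    n = op._size()[-1]
    return op._matmul(torch.eye(n))


def diagonal(op):
    # Naive fallback: densify, then read off the diagonal
    return to_dense(op).diagonal(dim1=-2, dim2=-1)


def rmatmul(lhs, op):
    # lhs @ A == (A^T lhs^T)^T, so only _transpose_nonbatch and _matmul are needed
    return op._transpose_nonbatch()._matmul(lhs.mT).mT
```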
Moreover, because _matmul is a linear function, it is very easy to compose linear operators in various ways. For example, adding two linear operators (SumLinearOperator) just requires adding the outputs of their _matmul functions. This makes it possible to define very complex compositional structures that still yield efficient linear algebraic routines.
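The composition rule can be illustrated with plain PyTorch closures standing in for _matmul. The helper names are hypothetical and not part of the package. A sum only needs the two outputs added, and a product only needs one matmul applied after the other.

```python
import torch

def sum_matmul(matmul_a, matmul_b):
    """Hypothetical helper: the matmul of A + B, built from the matmuls of A and B."""
    return lambda v: matmul_a(v) + matmul_b(v)

def product_matmul(matmul_a, matmul_b):
    """Hypothetical helper: the matmul of A @ B, applying B first and then A."""
    return lambda v: matmul_a(matmul_b(v))

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
d = torch.rand(5, dtype=torch.float64)

def apply_A(v):
    return A @ v

def apply_D(v):
    # diag(d) @ v as a row-wise scaling; the 5 x 5 diagonal matrix is never formed
    return d.unsqueeze(-1) * v

v = torch.randn(5, 3, dtype=torch.float64)
sum_result = sum_matmul(apply_A, apply_D)(v)          # (A + diag(d)) v
product_result = product_matmul(apply_A, apply_D)(v)  # (A diag(d)) v
```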
Finally, LinearOperator objects can be composed with one another, yielding new LinearOperator objects and automatically keeping track of algebraic structure after each computation. As a result, users never need to reason about which efficient linear algebra routines to use (so long as the input elements defined by the user encode known input structure). See the section on using LinearOperator objects for more details.
There are several use cases for the LinearOperator package. Here we highlight two general themes:
For example, let's say that you have a generative model that involves sampling from a high-dimensional multivariate Gaussian. This sampling operation will require storing and manipulating a large covariance matrix, so to speed things up you might want to experiment with different structured approximations of that covariance matrix. This is easy with the LinearOperator package.
import torch
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator

variance = torch.rand(10000) + 0.1  # variances must be positive
cov = DiagLinearOperator(variance)
# or
# cov = LowRankRootLinearOperator(...) + DiagLinearOperator(...)
# or
# cov = KroneckerProductLinearOperator(...)
# or
# cov = ToeplitzLinearOperator(...)
# or
# ...
mvn = MultivariateNormal(torch.zeros(cov.size(-1)), cov)  # 10000-dimensional MVN
mvn.rsample()  # returns a 10000-dimensional vector
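The following plain-PyTorch sketch shows why structure helps with sampling, using a low-rank-plus-diagonal covariance $\boldsymbol C \boldsymbol C^\top + \operatorname{diag}(\boldsymbol d)$. It draws $\boldsymbol x = \boldsymbol \mu + \boldsymbol C \boldsymbol\epsilon_1 + \sqrt{\boldsymbol d} \odot \boldsymbol\epsilon_2$, where $\boldsymbol\epsilon_1, \boldsymbol\epsilon_2$ are independent standard normals. The covariance of this draw is exactly $\boldsymbol C \boldsymbol C^\top + \operatorname{diag}(\boldsymbol d)$, and the $N \times N$ matrix is never formed. The function name is hypothetical, and this is not necessarily the routine MultivariateNormal.rsample uses.

```python
import torch

def sample_low_rank_plus_diag(mean, C, d, num_samples):
    """Hypothetical helper: draw samples from N(mean, C C^T + diag(d)) without forming the covariance."""
    k = C.size(-1)
    eps_low_rank = torch.randn(num_samples, k, dtype=C.dtype)
    eps_diag = torch.randn(num_samples, mean.size(-1), dtype=C.dtype)
    # Each row is mean + C eps1 + sqrt(d) * eps2
    return mean + eps_low_rank @ C.mT + eps_diag * d.sqrt()

torch.manual_seed(0)
N, k = 5, 2
C = torch.randn(N, k, dtype=torch.float64)
d = torch.rand(N, dtype=torch.float64) + 0.5  # positive diagonal
mean = torch.zeros(N, dtype=torch.float64)
samples = sample_low_rank_plus_diag(mean, C, d, num_samples=200_000)
```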
Many of the efficient linear algebra routines in LinearOperator are iterative algorithms based on matrix-vector multiplication. Since matrix-vector multiplication obeys many nice compositional properties, it is possible to obtain efficient routines for extremely complex compositional LinearOperators:
from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator, ToeplitzLinearOperator
# mat1 = 200 x 200 PSD matrix
# mat2 = 100 x 100 PSD matrix
# vec3 = 20000 vector
A = KroneckerProductLinearOperator(mat1, mat2) + RootLinearOperator(ToeplitzLinearOperator(vec3))
# represents a 20000 x 20000 matrix
torch.linalg.solve(A, torch.randn(20000)) # Sub O(N^3) routine!
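Conjugate gradients is one classic iterative solver that needs only matrix-vector products. The sketch below solves a low-rank-plus-diagonal system while touching the matrix only through a matmul closure. Any other structure with a fast matmul would plug in the same way. The helper names are hypothetical, and this is not the package's own solver.

```python
import torch

def conjugate_gradient(matmul, b, max_iter=100, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, accessing A only through matmul(v)."""
    x = torch.zeros_like(b)
    r = b - matmul(x)  # residual
    p = r.clone()      # search direction
    rs_old = r @ r
    b_norm = b.norm()
    for _ in range(max_iter):
        Ap = matmul(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol * b_norm:  # relative residual small enough
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

torch.manual_seed(0)
N, k = 500, 10
C = torch.randn(N, k, dtype=torch.float64)
d = torch.rand(N, dtype=torch.float64) + 1.0

def low_rank_plus_diag_matmul(v):
    # (C C^T + diag(d)) v without forming the N x N matrix
    return C @ (C.mT @ v) + d * v

b = torch.randn(N, dtype=torch.float64)
x = conjugate_gradient(low_rank_plus_diag_matmul, b, max_iter=200)
```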
LinearOperator objects share (mostly) the same API as torch.Tensor objects. Under the hood, these objects use __torch_function__ to dispatch all efficient linear algebra operations to the torch and torch.linalg namespaces.
This includes:
- torch.add
- torch.cat
- torch.clone
- torch.diagonal
- torch.dim
- torch.div
- torch.expand
- torch.logdet
- torch.matmul
- torch.numel
- torch.permute
- torch.prod
- torch.squeeze
- torch.sub
- torch.sum
- torch.transpose
- torch.unsqueeze
- torch.linalg.cholesky
- torch.linalg.eigh
- torch.linalg.eigvalsh
- torch.linalg.solve
- torch.linalg.svd
Each of these functions will either return a torch.Tensor or a new LinearOperator object, depending on the function.
For example:
# A = RootLinearOperator(...)
# B = ToeplitzLinearOperator(...)
# d = vec
C = torch.matmul(A, B)  # A new LinearOperator representing the product of A and B
torch.linalg.solve(C, d) # A torch.Tensor
For more examples, see the examples folder.
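The __torch_function__ protocol is a documented PyTorch extension point for tensor-like types. The sketch below uses it on a hypothetical DiagOp class, not the package's DiagLinearOperator, to route torch.matmul and torch.logdet to structure-exploiting code. It illustrates the mechanism only and makes no claim about the package's internals.

```python
import torch

class DiagOp:
    """Hypothetical diagonal operator that intercepts a few torch functions."""

    def __init__(self, diag: torch.Tensor):
        self.diag = diag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is torch.matmul:
            op, rhs = args
            # Diagonal matmul is an elementwise (row-wise) scaling
            return op.diag.unsqueeze(-1) * rhs if rhs.dim() > 1 else op.diag * rhs
        if func is torch.logdet:
            (op,) = args
            # log det of a diagonal matrix is the sum of the log-diagonal
            return op.diag.log().sum(dim=-1)
        return NotImplemented  # any other torch function raises a TypeError

D = DiagOp(torch.tensor([1.0, 2.0, 3.0]))
product = torch.matmul(D, torch.tensor([4.0, 5.0, 6.0]))  # dispatched to DiagOp
log_det = torch.logdet(D)                                 # log(1 * 2 * 3)
```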
LinearOperator objects operate naturally in batch mode. For example, to represent a batch of 3 100 x 100 diagonal matrices:
# d = torch.randn(3, 100)
D = DiagLinearOperator(d)  # Represents an operator of size 3 x 100 x 100
These objects fully support broadcasted operations:
D @ torch.randn(100, 2) # Returns a tensor of size 3 x 100 x 2
D2 = DiagLinearOperator(torch.randn([2, 1, 100])) # Represents an operator of size 2 x 1 x 100 x 100
D2 + D # Represents an operator of size 2 x 3 x 100 x 100
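Batched diagonal operators reduce to ordinary tensor broadcasting. This plain-PyTorch sketch, with no LinearOperator involved, reproduces the shapes from the example above:

```python
import torch

d = torch.randn(3, 100)  # a batch of 3 diagonals
X = torch.randn(100, 2)

# diag(d) @ X is a row-wise scaling; X broadcasts against the batch dimension
out = d.unsqueeze(-1) * X  # 3 x 100 x 2

# Adding two diagonal operators adds their diagonals, with the usual broadcasting
d2 = torch.randn(2, 1, 100)
sum_diag = d2 + d  # 2 x 3 x 100, i.e. the diagonal of a 2 x 3 x 100 x 100 operator

# Batch shapes broadcast like ordinary tensor shapes: (2, 1) with (3,) gives (2, 3)
batch_shape = torch.broadcast_shapes((2, 1), (3,))
```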
LinearOperator objects can be indexed in ways similar to torch Tensors. This includes:
- Integer indexing (get a row, column, or batch)
- Slice indexing (get a subset of rows, columns, or batches)
- LongTensor indexing (get a set of individual entries by index)
- Ellipses (support indexing operations with arbitrary batch dimensions)
D = DiagLinearOperator(torch.randn(2, 3, 100)) # Represents an operator of size 2 x 3 x 100 x 100
D[-1] # Returns a 3 x 100 x 100 operator
D[..., :10, -5:] # Returns a 2 x 3 x 10 x 5 operator
D[..., torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([0, 1, 2, 3])] # Returns a 2 x 3 x 4 tensor
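For a diagonal operator, entry lookups like the last one never need the dense matrix. This is a minimal sketch of the idea in plain PyTorch; diag_entries is a hypothetical helper, not a package function.

```python
import torch

def diag_entries(d, rows, cols):
    """Hypothetical helper: entries A[..., rows[i], cols[i]] of A = diag(d), without building A."""
    on_diagonal = rows == cols  # off-diagonal entries of a diagonal matrix are zero
    return torch.where(on_diagonal, d[..., rows], torch.zeros((), dtype=d.dtype))

d = torch.randn(2, 3, 100)
rows = torch.tensor([0, 1, 2, 3])
cols = torch.tensor([0, 1, 3, 3])
entries = diag_entries(d, rows, cols)  # shape 2 x 3 x 4
```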
LinearOperators can be composed with one another in various ways. This includes:
- Addition (LinearOpA + LinearOpB)
- Matrix multiplication (LinearOpA @ LinearOpB)
- Concatenation (torch.cat([LinearOpA, LinearOpB], dim=-2))
- Kronecker product (torch.kron(LinearOpA, LinearOpB))
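Kronecker products show why composition can stay cheap. Write $\boldsymbol x$ as the row-major flattening of a matrix $\boldsymbol X$, which is the layout torch.kron uses. Then $(\boldsymbol A \otimes \boldsymbol B)\boldsymbol x$ is the row-major flattening of $\boldsymbol A \boldsymbol X \boldsymbol B^\top$. A Kronecker matmul therefore costs two small matrix products instead of one large one. The helper below is a hypothetical sketch of that identity, not the package's implementation.

```python
import torch

def kron_matmul(A, B, x):
    """Hypothetical helper: compute torch.kron(A, B) @ x without forming the Kronecker product."""
    X = x.reshape(A.size(-1), B.size(-1))  # row-major: x[k * n_B + l] == X[k, l]
    return (A @ X @ B.mT).reshape(-1)

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
B = torch.randn(4, 4, dtype=torch.float64)
x = torch.randn(12, dtype=torch.float64)
y = kron_matmul(A, B, x)  # same result as torch.kron(A, B) @ x
```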
In addition, there are many ways to "decorate" LinearOperator objects. This includes:
- Elementwise multiplying by constants (torch.mul(2., LinearOpA))
- Summing over batches (torch.sum(LinearOpA, dim=-3))
- Elementwise multiplying over batches (torch.prod(LinearOpA, dim=-3))
See the documentation for a full list of supported composition and decoration operations.
LinearOperator requires Python >= 3.8.
We recommend installing via pip or Anaconda:
pip install linear_operator
# or
conda install linear_operator -c gpytorch
The installation requires the following packages:
- PyTorch >= 1.11
- Scipy
You can customize your PyTorch installation (i.e. CUDA version, CPU only option) by following the PyTorch installation instructions.
To install what is currently on the main branch (potentially buggy and unstable):
pip install --upgrade git+https://github.com/cornellius-gp/linear_operator.git
If you are contributing a pull request, it is best to perform a manual installation:
git clone https://github.com/cornellius-gp/linear_operator.git
cd linear_operator
pip install -e ".[dev,docs,test]"
See the contributing guidelines CONTRIBUTING.md for information on submitting issues and pull requests.
LinearOperator is MIT licensed.