Improving performance and numerical stability for solving Toeplitz systems. #26015
tillahoffmann asked this question in Q&A · Unanswered · Replies: 0 comments
Summary
The covariance matrix of Gaussian processes often has a Toeplitz structure, i.e., the elements on each diagonal are constant[^1]. This structure can be exploited, e.g., using Levinson-Durbin recursion, to evaluate the log likelihood of Gaussian process realizations more efficiently than by solving a general linear system. Running variational inference through `numpyro`, I get about a 5x speedup in training time using the implementation below (although I'm struggling to reproduce the performance improvement in simple benchmarks). It would be great to get your opinion on two questions.

- […] `jax.lax.scan` likely has room for improvement.

Details
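The log-likelihood for a realization $\mathbf{x}$ of a Gaussian process with zero mean and covariance matrix $\mathbf{C}$ is (up to constants)

$$\log p(\mathbf{x}) = -\frac{1}{2}\mathbf{x}^\top\mathbf{C}^{-1}\mathbf{x} - \frac{1}{2}\log\left|\det\mathbf{C}\right|.$$

If $\mathbf{C}$ is Toeplitz, we can solve $\mathbf{C}\mathbf{z} = \mathbf{x}$ for $\mathbf{z}$ and evaluate the log absolute determinant in one pass to get

$$\log p(\mathbf{x}) = -\frac{1}{2}\mathbf{x}^\top\mathbf{z} - \frac{1}{2}\log\left|\det\mathbf{C}\right|.$$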
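The determinant can come out of the same pass because, assuming Levinson-Durbin recursion is used (the symbols $E_m$, $k_m$ and $c_0$ are notation introduced here for illustration, not taken from the implementation), the prediction-error variances multiply to the determinant:

$$\log\det\mathbf{C} = \sum_{m=0}^{n-1}\log E_m,\qquad E_0 = c_0,\qquad E_m = E_{m-1}\left(1 - k_m^2\right),$$

where $c_0$ is the first element of the first row of $\mathbf{C}$ and $k_m$ is the reflection coefficient computed at step $m$ of the recursion. For a positive-definite $\mathbf{C}$ every $E_m$ is positive, so the absolute value is not needed in that case.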
We also don't need to evaluate the full matrix $\mathbf{C}$: for a symmetric Toeplitz matrix, all elements are specified by the first row, so evaluating that row is enough. The following function implements the solver and the evaluation of the log absolute determinant.
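A minimal sketch of such a function, assuming a symmetric positive-definite Toeplitz matrix given by its first row `r`. This is not necessarily the implementation discussed in this post: the name `solve_toeplitz_logdet`, the use of `jnp.roll` for the fixed-shape reversal, and the carry layout are illustrative assumptions. It runs Levinson-Durbin recursion inside `jax.lax.scan` and accumulates the log determinant from the prediction-error variances.

```python
import jax
import jax.numpy as jnp


def solve_toeplitz_logdet(r, x):
    """Solve T z = x, where T is symmetric positive-definite Toeplitz with first row r.

    Returns (z, logdet) with logdet = log(det(T)). Illustrative sketch only;
    r and x are assumed to be 1-D arrays of the same length and dtype.
    """
    n = r.shape[0]
    r_flip = r[::-1]

    def body(carry, m):
        a, z, err, logdet = carry
        # Rolling by m + 1 gives r_shifted[j] = r[m - j] for j <= m. Entries
        # with j > m are wrapped-around values, but they only ever multiply
        # the zero padding of `a` and `z`.
        r_shifted = jnp.roll(r_flip, m + 1)
        # Reflection coefficient and update of the prediction-error filter.
        k = -jnp.dot(a, r_shifted) / err
        a = a + k * jnp.roll(a[::-1], m + 1)
        err = err * (1.0 - k * k)
        # Extend the solution of the leading m x m system to size m + 1.
        delta = x[m] - jnp.dot(z, r_shifted)
        z = z + (delta / err) * jnp.roll(a[::-1], m + 1)
        return (a, z, err, logdet + jnp.log(err)), None

    a0 = jnp.zeros(n, r.dtype).at[0].set(1.0)
    z0 = jnp.zeros(n, r.dtype).at[0].set(x[0] / r[0])
    init = (a0, z0, r[0], jnp.log(r[0]))
    (_, z, _, logdet), _ = jax.lax.scan(body, init, jnp.arange(1, n))
    return z, logdet


# Small well-conditioned example: first row r[k] = 0.5 ** k.
n = 6
r = jnp.power(0.5, jnp.arange(n, dtype=jnp.float32))
x = jnp.array([1.0, -2.0, 0.5, 3.0, -1.0, 2.0], dtype=jnp.float32)
z, logdet = solve_toeplitz_logdet(r, x)

# Dense reference for comparison: T[i, j] = r[|i - j|].
idx = jnp.arange(n)
T = r[jnp.abs(idx[:, None] - idx[None, :])]
```

All arrays keep a fixed length `n`; the not-yet-used tail of `a` and `z` stays zero, which is what lets the dot products run over the full vectors without an explicit mask. Whether this is faster or slower than a mask-based variant is not something this sketch establishes.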
Notes Regarding Performance
- Keeping the `mask` in the `carry` and updating it seems to be more performant than creating a `mask = jnp.arange(n) < m` in the @body.
- […] `g` might be a bit of a time hog, but I couldn't figure out a better way without indexing (and thus creating intermediate arrays with variable shape).

Notes Regarding Numerical Stability
Covariance matrices for Gaussian processes with relatively long correlation lengths have high condition numbers, which can degrade the accuracy of the above algorithm. Below, I've shown an example of the first row of the covariance matrix and the reconstruction error using different methods for a random vector drawn from a standard normal distribution. The `solve_toeplitz` method exhibits relatively strong oscillatory behavior around index 5.

This same oscillatory behavior is observed in some of the working memory of the algorithm. I've kept track of some of the working variables and plotted them below (each row corresponds to one iteration of the loop over `m` in the implementation).

Squinting at these, maybe the oscillatory errors have something to do with rounding errors in the summation of `a_rev * x`?

Thanks for making it this far and for any input you might have!
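For anyone who wants to reproduce the conditioning issue described in the stability notes, here is a minimal sketch. It assumes a squared-exponential kernel on a unit-spaced grid, which may not be the kernel used in the experiment above; the helper names `sq_exp_first_row` and `toeplitz_from_first_row` are hypothetical. It prints the condition number of the dense covariance matrix for a few correlation lengths.

```python
import jax.numpy as jnp


def sq_exp_first_row(n, length_scale):
    # First row of the covariance matrix for a squared-exponential kernel
    # (an assumed kernel) evaluated on a grid with unit spacing.
    lags = jnp.arange(n, dtype=jnp.float32)
    return jnp.exp(-0.5 * (lags / length_scale) ** 2)


def toeplitz_from_first_row(r):
    # Dense symmetric Toeplitz matrix with T[i, j] = r[|i - j|].
    idx = jnp.arange(r.shape[0])
    return r[jnp.abs(idx[:, None] - idx[None, :])]


n = 32
length_scales = [0.5, 1.0, 1.5]
conds = [
    float(jnp.linalg.cond(toeplitz_from_first_row(sq_exp_first_row(n, ls))))
    for ls in length_scales
]
for ls, c in zip(length_scales, conds):
    print(f"length scale {ls}: condition number ~ {c:.3g}")
```

Under these assumptions the condition number grows quickly with the length scale. As a rough rule of thumb, relative rounding errors can be amplified by up to the condition number, so single precision may lose most of its digits well before the correlation length looks large.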
Footnotes
[^1]: That's the case on a regularly spaced grid in one dimension for a stationary covariance kernel.