
Adds matrix sqrt #9544

Merged (2 commits, Feb 17, 2022)

Conversation

SaturdayGenfo (Contributor)

This PR adds an implementation of the matrix square root (progress toward #2478). The algorithm is a Schur-decomposition-based method, performed entirely with complex types. (Recent improvements to this method use Sylvester equation solvers, which aren't available in JAX for the moment, as mentioned in #6089.)
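For context, the Schur-based approach computes A = Z T Zᴴ with T upper triangular, takes the square root of the triangular factor, and rotates back. A sketch of the principle using SciPy's building blocks (this is not the PR's implementation, just the underlying identity):

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))

# Complex Schur form: A = Z @ T @ Z.conj().T, with T upper triangular
T, Z = sla.schur(A, output="complex")

# Principal square root of the triangular factor, rotated back
S = Z @ sla.sqrtm(T) @ Z.conj().T
print(np.allclose(S @ S, A))  # True
```

Since T is triangular, its square root can be computed by a direct recurrence, which is exactly what the PR's main routine does.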

To make things jittable, the main routine, which computes the square root of an upper-triangular matrix, uses lax.fori_loop, which makes it a bit obscure. Before the jit-friendly rewrite, the function looked like this:

# U starts as diag(sqrt(diag(T))); this fills in the strict upper triangle
for j in range(n):
    for i in range(j - 1, -1, -1):
        s = 0
        for k in range(i + 1, j):
            s += U[i, k] * U[k, j]
        num = T[i, j] - s
        denom = diag[i] + diag[j]  # diag holds the sqrt of T's diagonal
        if num == 0.0:
            U = U.at[i, j].set(0.0)
        else:
            U = U.at[i, j].set(num / denom)

(I'm unsure whether the proposed jit-friendly rewrite is the best approach.)
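For reference, the same recurrence can be checked end to end in plain NumPy on a small upper-triangular example (`sqrtm_triu` is a name used here for illustration, not the PR's function):

```python
import numpy as np

def sqrtm_triu(T):
    """Square root of an upper-triangular matrix T: reference (non-jittable)
    version of the recurrence above."""
    n = T.shape[0]
    U = np.zeros_like(T, dtype=complex)
    # Diagonal of the square root is the elementwise sqrt of T's diagonal
    np.fill_diagonal(U, np.sqrt(np.diag(T).astype(complex)))
    for j in range(n):
        for i in range(j - 1, -1, -1):
            s = sum(U[i, k] * U[k, j] for k in range(i + 1, j))
            num = T[i, j] - s
            denom = U[i, i] + U[j, j]
            U[i, j] = 0.0 if num == 0.0 else num / denom
    return U

T = np.triu(np.arange(1.0, 17.0).reshape(4, 4)) + 2 * np.eye(4)
U = sqrtm_triu(T)
print(np.allclose(U @ U, T))  # True
```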

I didn't squash the two commits, for clarity. The first wraps the existing Schur primitive in a jax.scipy function; the second adds jax.scipy.linalg.sqrtm.

shoyer (Member) commented Feb 12, 2022

Thanks for putting this together! Matrix square root is definitely very welcome functionality for JAX.

How does performance compare to SciPy? I am a little concerned that iterating over individual matrix elements in JAX is going to be painfully slow. Usually we rely on compiled routines for this sort of thing, e.g., from LAPACK or cuSOLVER.

SaturdayGenfo (Contributor, Author) commented Feb 12, 2022

I ran a few tests (in double precision), and it performs as well as the non-blocked version of scipy.linalg.sqrtm, but it is slower than the Sylvester-augmented version with the default blocksize=64.

[plot: execution times vs. matrix size for the three implementations below]

The execution times are the average of 5 runs of the following functions:

import time
import numpy as np
import scipy as sp
import scipy.linalg  # noqa: F401  (makes sp.linalg available)
import jax.scipy as jsp
import jax.scipy.linalg  # noqa: F401  (makes jsp.linalg available)

def time_jax(n):
    x = np.random.randn(n, n)
    tic = time.perf_counter()
    jsp.linalg.sqrtm(x).block_until_ready()
    return time.perf_counter() - tic

def time_scipy(n):
    x = np.random.randn(n, n)
    tic = time.perf_counter()
    sp.linalg.sqrtm(x, blocksize=1)  # non-blocked recurrence
    return time.perf_counter() - tic

def time_scipy_sylvester(n):
    x = np.random.randn(n, n)
    tic = time.perf_counter()
    sp.linalg.sqrtm(x)  # default blocksize=64, Sylvester-based blocks
    return time.perf_counter() - tic

@hawkinsp hawkinsp requested a review from shoyer February 14, 2022 14:58
hawkinsp (Member)

I agree with Stephan that the scalar-loop formulation likely isn't ideal, especially on GPU or TPU, where scalar code is very inefficient. But... we don't have a Schur decomposition on GPU or TPU anyway, so that likely doesn't have a huge practical impact.

If it were possible to avoid the scalar formulation, it would be preferable, but I don't think it should block merging this.

@hawkinsp hawkinsp self-assigned this Feb 14, 2022
SaturdayGenfo (Contributor, Author)

I thought some more about this. It does feel un-JAX-like to have these nested impure loops, and it does become slow for matrices larger than 5000x5000. To avoid these loops, one option is to write a custom call to Eigen, as TensorFlow does. The only issue is that custom calls are not the easiest to write.

This version is CPU-only and does not yet have a JVP defined (although the JVP is only a Sylvester solver away). For these reasons, there may not be much urgency to include it in JAX. So a possible choice is to table this PR and instead work toward writing the custom call or a Sylvester solver to make this sqrtm worthwhile. What do you think?
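(For the record, the "only a Sylvester solver away" remark comes from differentiating S @ S = A, which gives the Sylvester equation S @ dS + dS @ S = dA in the tangent dS. A sketch against SciPy's solver, assuming an SPD input so the square root is real:)

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4))
A = X @ X.T + 4 * np.eye(4)          # SPD, so sqrtm(A) is real
dA = rng.standard_normal((4, 4))

S = sla.sqrtm(A)
# Differentiating S @ S = A gives S @ dS + dS @ S = dA,
# a Sylvester equation solved directly for the tangent dS
dS = sla.solve_sylvester(S, S, dA)

# Finite-difference sanity check of the tangent
eps = 1e-6
fd = (sla.sqrtm(A + eps * dA) - S) / eps
print(np.allclose(dS, fd, atol=1e-4))  # True
```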

@@ -103,7 +103,23 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
   del overwrite_a, overwrite_b, turbo, check_finite
   return _eigh(a, b, lower, eigvals_only, eigvals, type)

@partial(jit, static_argnames=('output',))
def _schur(a, output):
  if output != "real":
Review comment (Member):
This should raise an error if an invalid output is provided

Reply (Contributor, Author):
I added a line to check the output argument of schur.

shoyer (Member) commented Feb 14, 2022

I think this implementation is actually probably OK for a first pass. Performance within a factor of 2-4x of SciPy is not terrible, and I think matrices of size 5000x5000 are very rare. JVPs and even VJPs should work fine out of the box, because you wrote this in terms of existing JAX primitives, though there may be more numerically stable implementations of the gradients.

@SaturdayGenfo force-pushed the adds-matrix-sqrt branch 2 times, most recently from c5a0c32 to 624236d on February 15, 2022
@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Feb 16, 2022
@copybara-service bot merged commit 15295a8 into google:main Feb 17, 2022