Skip to content

Commit

Permalink
Don't mask out zero elements on the diagonal of the matrix when inver…
Browse files Browse the repository at this point in the history
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes #3589
Fixes #15429

PiperOrigin-RevId: 653562611
  • Loading branch information
hawkinsp authored and jax authors committed Jul 18, 2024
1 parent 174429d commit 47e6da3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ Remember to align the itemized text with the first line of an item within a list
* Bug fixes
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
by the jit dispatch fast path.
* Fixed a bug that meant triangular solves of batches of singular matrices
produce nonsensical finite values, instead of inf or nan (#3589, #15429).

## jax 0.4.30 (June 18, 2024)

Expand Down
11 changes: 11 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from functools import partial
import itertools
import unittest

import numpy as np
import scipy
Expand All @@ -33,6 +34,7 @@
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import promote_dtypes_inexact

config.parse_flags_with_absl()
Expand Down Expand Up @@ -1623,6 +1625,15 @@ def testTriangularSolveGradPrecision(self):
(a, b),
(a, b))

@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
def testTriangularSolveSingularBatched(self):
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
y = jnp.array([[1], [1.]], dtype=np.float32)
out = jax.lax.linalg.triangular_solve(x[None], y[None], left_side=True)
# x is singular. The triangular solve may contain either nans or infs, but
# it should not consist of only finite values.
self.assertFalse(np.all(np.isfinite(out)))

@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
Expand Down

0 comments on commit 47e6da3

Please sign in to comment.