diff --git a/CHANGELOG.md b/CHANGELOG.md index 83b3ab874f2a..15c5633c62b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ Remember to align the itemized text with the first line of an item within a list * Bug fixes * Fixed a bug that meant that negative static_argnums to a jit were mishandled by the jit dispatch fast path. + * Fixed a bug that meant triangular solves of batches of singular matrices + produce nonsensical finite values, instead of inf or nan (#3589, #15429). ## jax 0.4.30 (June 18, 2024) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2a64b95b9452..dd0ae38d9aa4 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,7 @@ from functools import partial import itertools +import unittest import numpy as np import scipy @@ -33,6 +34,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import xla_extension_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -1623,6 +1625,15 @@ def testTriangularSolveGradPrecision(self): (a, b), (a, b)) + @unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30") + def testTriangularSolveSingularBatched(self): + x = jnp.array([[1, 1], [0, 0]], dtype=np.float32) + y = jnp.array([[1], [1.]], dtype=np.float32) + out = jax.lax.linalg.triangular_solve(x[None], y[None], left_side=True) + # x is singular. The triangular solve may contain either nans or infs, but + # it should not consist of only finite values. + self.assertFalse(np.all(np.isfinite(out))) + @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],