Skip to content

Commit 45905fa

Browse files
committed
jnp.solve: handle corner case in input shapes
1 parent 78543f7 commit 45905fa

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

jax/_src/numpy/linalg.py

+2
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,8 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]
606606
def solve(a: ArrayLike, b: ArrayLike) -> Array:
607607
check_arraylike("jnp.linalg.solve", a, b)
608608
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
609+
if a.ndim >= 2 and b.ndim > a.ndim:
610+
a = lax.expand_dims(a, tuple(range(b.ndim - a.ndim)))
609611
return lax_linalg._solve(a, b)
610612

611613

tests/linalg_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,7 @@ def tensor_maker():
897897
((4, 4), (4,)),
898898
((8, 8), (8, 4)),
899899
((1, 2, 2), (3, 2)),
900+
((2, 2), (3, 2, 2)),
900901
((2, 1, 3, 3), (1, 4, 3, 4)),
901902
((1, 0, 0), (1, 0, 2)),
902903
]

0 commit comments

Comments
 (0)