Skip to content

Commit

Permalink
Unconditionally lower jnp.dot to lax.dot_general.
Browse files Browse the repository at this point in the history
jax-ml#16721 added a condition to lower
calls to `jnp.dot` with scalar inputs to `lax.mul` instead of
`lax.dot_general`. AFAICT, jax-ml#16826
fixed the issue that this was solving, so this condition should no
longer be necessary. Removing this condition simplifies the addition of
new arguments to `dot` and `dot_general`, including the `algorithm`
parameter that I am currently working on in
jax-ml#23574, so now seemed like a good time
to remove it!
  • Loading branch information
dfm committed Sep 16, 2024
1 parent 8c39d03 commit 95d67a5
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7364,20 +7364,17 @@ def dot(a: ArrayLike, b: ArrayLike, *,
batch_dims = ((), ())
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
# TODO(jakevdp): lower this case to dot_general as well?
# Currently, doing so causes issues in remat tests due to #16805
if preferred_element_type is not None:
a = a.astype(preferred_element_type)
b = b.astype(preferred_element_type)
result = lax.mul(a, b)
contract_dims = ((), ())
else:
if b_ndim == 1:
contract_dims = ((a_ndim - 1,), (0,))
else:
contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
precision=precision, preferred_element_type=preferred_element_type)
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
precision=precision,
preferred_element_type=preferred_element_type)
return lax_internal._convert_element_type(result, preferred_element_type,
output_weak_type)


@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
Expand Down

0 comments on commit 95d67a5

Please sign in to comment.