From 95d67a545f2d76e53b3952652d9490b8b791aacd Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 16 Sep 2024 14:18:29 -0400 Subject: [PATCH] Unconditionally lower jnp.dot to lax.dot_general. https://github.com/google/jax/pull/16721 added a condition to lower calls to `jnp.dot` with scalar inputs to `lax.mul` instead of `lax.dot_general`. AFAICT, https://github.com/google/jax/pull/16826 fixed the issue that this was solving, so this condition should no longer be necessary. Removing this condition simplifies the addition of new arguments to `dot` and `dot_general`, including the `algorithm` parameter that I am currently working on in https://github.com/google/jax/pull/23574, so now seemed like a good time to remove it! --- jax/_src/numpy/lax_numpy.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5137b20bc898..450ff06c2669 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7364,20 +7364,17 @@ def dot(a: ArrayLike, b: ArrayLike, *, batch_dims = ((), ()) a_ndim, b_ndim = ndim(a), ndim(b) if a_ndim == 0 or b_ndim == 0: - # TODO(jakevdp): lower this case to dot_general as well? - # Currently, doing so causes issues in remat tests due to #16805 - if preferred_element_type is not None: - a = a.astype(preferred_element_type) - b = b.astype(preferred_element_type) - result = lax.mul(a, b) + contract_dims = ((), ()) else: if b_ndim == 1: contract_dims = ((a_ndim - 1,), (0,)) else: contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) - result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), - precision=precision, preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), + precision=precision, + preferred_element_type=preferred_element_type) + return lax_internal._convert_element_type(result, preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)