Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unconditionally lower jnp.dot to lax.dot_general.
jax-ml#16721 added a condition to lower calls to `jnp.dot` with scalar inputs to `lax.mul` instead of `lax.dot_general`. AFAICT, jax-ml#16826 fixed the issue that this was solving, so this condition should no longer be necessary. Removing this condition simplifies the addition of new arguments to `dot` and `dot_general`, including the `algorithm` parameter that I am currently working on in jax-ml#23574, so now seemed like a good time to remove it!
- Loading branch information