From ed2ee8ea590186b566d05047b998a666cd6a5940 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Thu, 20 Jul 2023 13:47:02 -0700
Subject: [PATCH] Lower jax.numpy.dot to mixed-precision dot_general

---
 jax/_src/numpy/lax_numpy.py | 26 ++++++++++++---------
 tests/api_test.py           |  3 ++-
 tests/lax_numpy_test.py     | 46 ++++++++++++++++++++++++++----------
 3 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index b1b5f28089eb..0be45c0e4243 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3052,21 +3052,25 @@ def apply_over_axes(func, a, axes):
 
 @util._wraps(np.dot, lax_description=_PRECISION_DOC)
 @partial(jit, static_argnames=('precision',), inline=True)
-def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
+def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None) -> Array:  # pylint: disable=missing-docstring
   util.check_arraylike("dot", a, b)
-  a, b = util.promote_dtypes(a, b)
+  a, b = asarray(a), asarray(b)
+  output_dtype, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)
+
+  batch_dims = ((), ())
   a_ndim, b_ndim = ndim(a), ndim(b)
   if a_ndim == 0 or b_ndim == 0:
-    return lax.mul(a, b)
-  if max(a_ndim, b_ndim) <= 2:
-    return lax.dot(a, b, precision=precision)
-
-  if b_ndim == 1:
-    contract_dims = ((a_ndim - 1,), (0,))
+    # TODO(jakevdp): lower this case to dot_general as well?
+    # Currently, doing so causes issues in remat tests due to #16805
+    result = lax.mul(a.astype(output_dtype), b.astype(output_dtype))
   else:
-    contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
-  batch_dims = ((), ())
-  return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
+    if b_ndim == 1:
+      contract_dims = ((a_ndim - 1,), (0,))
+    else:
+      contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
+    result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
+                             precision=precision, preferred_element_type=output_dtype)
+  return lax_internal._convert_element_type(result, output_dtype, output_weak_type)
 
 
 @util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
diff --git a/tests/api_test.py b/tests/api_test.py
index 0d30e80ec1e4..355340efbfbd 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -1527,7 +1527,8 @@ def f(x, y):
       return jnp.dot(x, y)
 
     self.assertRaisesRegex(
-        TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
+        TypeError, ("dot_general requires contracting dimensions to have "
+                    "the same shape, got \\(3L?,\\) and \\(4L?,\\)."),
         lambda: grad(f)(np.zeros(3), np.zeros(4)))
 
   def test_abstract_error_message(self):
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 079b1bba6bcc..4e9cd546d97f 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -469,27 +469,30 @@ def np_fun(a, b):
     self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
 
   @jtu.sample_product(
-    [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
-      for name, lhs_shape, rhs_shape in [
-          ("matrix-scalar", (3, 3), ()),
-          ("scalar-matrix", (), (3, 3)),
-          ("matrix-vector", (4, 5), (5,)),
-          ("vector-matrix", (6,), (6, 4)),
-          ("matrix-matrix", (3, 4), (4, 5)),
-          ("tensor-vector", (4, 3, 2), (2,)),
-          ("vector-tensor", (2,), (3, 2, 4)),
-          ("tensor-matrix", (4, 3, 2), (2, 5)),
-          ("matrix-tensor", (5, 2), (3, 2, 4)),
-          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]],
+    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
+      for lhs_shape, rhs_shape in [
+          ((3, 3), ()),
+          ((), (3, 3)),
+          ((4, 5), (5,)),
+          ((6,), (6, 4)),
+          ((3, 4), (4, 5)),
+          ((4, 3, 2), (2,)),
+          ((2,), (3, 2, 4)),
+          ((4, 3, 2), (2, 5)),
+          ((5, 2), (3, 2, 4)),
+          ((2, 3, 4), (5, 4, 1))]],
     lhs_dtype=number_dtypes,
     rhs_dtype=number_dtypes,
   )
   @jax.default_matmul_precision("float32")
-  def testDot(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
+  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
     tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14,
            np.complex128: 1e-14}
+    if (lhs_dtype in [np.float16, jnp.bfloat16] and
+        rhs_dtype in [np.float16, jnp.bfloat16]):
+      tol = 1e-2
     def np_dot(x, y):
       x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
       y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
@@ -498,6 +501,23 @@ def np_dot(x, y):
     self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
     self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
 
+  @jtu.sample_product(
+    lhs_dtype=number_dtypes,
+    rhs_dtype=number_dtypes,
+  )
+  @jax.numpy_dtype_promotion('standard')
+  def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype):
+    # This test confirms that jnp.dot lowers to a single dot_general call,
+    # avoiding explicit type casting of inputs and outputs.
+    lhs = jax.ShapeDtypeStruct((5,), lhs_dtype)
+    rhs = jax.ShapeDtypeStruct((5,), rhs_dtype)
+    jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs)
+    prims = [eqn.primitive for eqn in jaxpr.eqns]
+    self.assertIn(prims, [
+      [lax.dot_general_p],
+      [lax.dot_general_p, lax.convert_element_type_p]
+    ])
+
   @jtu.sample_product(
     [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for name, lhs_shape, rhs_shape in [
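
Not part of the patch itself: a minimal sketch of how the new lowering can be inspected interactively, assuming a jax build that includes this change. The shapes and dtypes below are arbitrary illustration choices, not values taken from the diff.

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Mixed-precision operands: float16 x float32 promotes to float32.
    lhs = jax.ShapeDtypeStruct((4, 8), np.float16)
    rhs = jax.ShapeDtypeStruct((8, 2), np.float32)

    # With this change, the jaxpr should show a single dot_general equation
    # (possibly followed by a convert_element_type) with
    # preferred_element_type=float32, mirroring testMixedPrecisionDot above,
    # rather than converting both operands before the contraction.
    print(jax.make_jaxpr(jnp.dot)(lhs, rhs))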