Lower jax.numpy.dot to mixed-precision dot_general
jakevdp committed Jul 20, 2023
1 parent 60bb3bc commit ed2ee8e
Showing 3 changed files with 50 additions and 25 deletions.
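
This change teaches jnp.dot to emit a mixed-precision lax.dot_general. Previously, jnp.dot promoted both operands to a common dtype up front (so a bfloat16 x float32 product first upcast the bfloat16 argument); now the operands are passed through unchanged and the promoted result dtype is requested via dot_general's preferred_element_type. A minimal way to observe the difference (a sketch; the jaxpr text in the comments is illustrative and assumes a build containing this commit):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4, 8), dtype=jnp.bfloat16)
    y = jnp.ones((8, 2), dtype=jnp.float32)
    print(jax.make_jaxpr(jnp.dot)(x, y))
    # Before this commit: convert_element_type[new_dtype=float32] applied to x,
    # followed by dot_general on two float32 operands.
    # After: a single dot_general[..., preferred_element_type=float32] that
    # consumes the bf16 and f32 operands directly.
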
26 changes: 15 additions & 11 deletions jax/_src/numpy/lax_numpy.py
@@ -3052,21 +3052,25 @@ def apply_over_axes(func, a, axes):
 
 @util._wraps(np.dot, lax_description=_PRECISION_DOC)
 @partial(jit, static_argnames=('precision',), inline=True)
-def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
+def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None) -> Array:  # pylint: disable=missing-docstring
   util.check_arraylike("dot", a, b)
-  a, b = util.promote_dtypes(a, b)
+  a, b = asarray(a), asarray(b)
+  output_dtype, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)
+
+  batch_dims = ((), ())
   a_ndim, b_ndim = ndim(a), ndim(b)
   if a_ndim == 0 or b_ndim == 0:
-    return lax.mul(a, b)
-  if max(a_ndim, b_ndim) <= 2:
-    return lax.dot(a, b, precision=precision)
-
-  if b_ndim == 1:
-    contract_dims = ((a_ndim - 1,), (0,))
+    # TODO(jakevdp): lower this case to dot_general as well?
+    # Currently, doing so causes issues in remat tests due to #16805
+    result = lax.mul(a.astype(output_dtype), b.astype(output_dtype))
   else:
-    contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
-  batch_dims = ((), ())
-  return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
+    if b_ndim == 1:
+      contract_dims = ((a_ndim - 1,), (0,))
+    else:
+      contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
+    result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
+                             precision=precision, preferred_element_type=output_dtype)
+  return lax_internal._convert_element_type(result, output_dtype, output_weak_type)
 
 
 @util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
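
All non-scalar cases now funnel into a single dot_general call. A hand-expanded sketch of the call emitted for a matrix-vector product (a_ndim = 2, b_ndim = 1), assuming a JAX version in which dot_general accepts operands of different dtypes once preferred_element_type is given:

    import jax.numpy as jnp
    from jax import lax

    a = jnp.ones((3, 4), dtype=jnp.bfloat16)
    b = jnp.ones((4,), dtype=jnp.float32)
    # contract_dims = ((a_ndim - 1,), (0,)); batch_dims = ((), ())
    out = lax.dot_general(
        a, b,
        dimension_numbers=(((1,), (0,)), ((), ())),
        preferred_element_type=jnp.float32)  # the promoted output dtype
    print(out.shape, out.dtype)  # (3,) float32
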
3 changes: 2 additions & 1 deletion tests/api_test.py
@@ -1527,7 +1527,8 @@ def f(x, y):
       return jnp.dot(x, y)
 
     self.assertRaisesRegex(
-        TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
+        TypeError, ("dot_general requires contracting dimensions to have "
+                    "the same shape, got \\(3L?,\\) and \\(4L?,\\)."),
         lambda: grad(f)(np.zeros(3), np.zeros(4)))
 
   def test_abstract_error_message(self):
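
The updated assertion reflects that shape errors now surface from dot_general rather than from lax.dot. The failing program from the test can be run directly to see the new message; a sketch (exact wording may differ slightly across versions):

    import jax
    import jax.numpy as jnp
    import numpy as np

    try:
        jax.grad(lambda x, y: jnp.dot(x, y))(np.zeros(3), np.zeros(4))
    except TypeError as e:
        # e.g. "dot_general requires contracting dimensions to have the same
        # shape, got (3,) and (4,)."
        print(e)
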
46 changes: 33 additions & 13 deletions tests/lax_numpy_test.py
@@ -469,27 +469,30 @@ def np_fun(a, b):
     self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
 
   @jtu.sample_product(
-    [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
-     for name, lhs_shape, rhs_shape in [
-         ("matrix-scalar", (3, 3), ()),
-         ("scalar-matrix", (), (3, 3)),
-         ("matrix-vector", (4, 5), (5,)),
-         ("vector-matrix", (6,), (6, 4)),
-         ("matrix-matrix", (3, 4), (4, 5)),
-         ("tensor-vector", (4, 3, 2), (2,)),
-         ("vector-tensor", (2,), (3, 2, 4)),
-         ("tensor-matrix", (4, 3, 2), (2, 5)),
-         ("matrix-tensor", (5, 2), (3, 2, 4)),
-         ("tensor-tensor", (2, 3, 4), (5, 4, 1))]],
+    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
+     for lhs_shape, rhs_shape in [
+         ((3, 3), ()),
+         ((), (3, 3)),
+         ((4, 5), (5,)),
+         ((6,), (6, 4)),
+         ((3, 4), (4, 5)),
+         ((4, 3, 2), (2,)),
+         ((2,), (3, 2, 4)),
+         ((4, 3, 2), (2, 5)),
+         ((5, 2), (3, 2, 4)),
+         ((2, 3, 4), (5, 4, 1))]],
     lhs_dtype=number_dtypes,
     rhs_dtype=number_dtypes,
   )
   @jax.default_matmul_precision("float32")
-  def testDot(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
+  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
     tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14,
            np.complex128: 1e-14}
+    if (lhs_dtype in [np.float16, jnp.bfloat16] and
+        rhs_dtype in [np.float16, jnp.bfloat16]):
+      tol = 1e-2
     def np_dot(x, y):
       x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
       y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
@@ -498,6 +501,23 @@ def np_dot(x, y):
     self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
     self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
 
+  @jtu.sample_product(
+    lhs_dtype=number_dtypes,
+    rhs_dtype=number_dtypes,
+  )
+  @jax.numpy_dtype_promotion('standard')
+  def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype):
+    # This test confirms that jnp.dot lowers to a single dot_general call,
+    # avoiding explicit type casting of inputs and outputs.
+    lhs = jax.ShapeDtypeStruct((5,), lhs_dtype)
+    rhs = jax.ShapeDtypeStruct((5,), rhs_dtype)
+    jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs)
+    prims = [eqn.primitive for eqn in jaxpr.eqns]
+    self.assertIn(prims, [
+      [lax.dot_general_p],
+      [lax.dot_general_p, lax.convert_element_type_p]
+    ])
+
   @jtu.sample_product(
     [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
      for name, lhs_shape, rhs_shape in [
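
The new test can be reproduced interactively; a sketch using abstract inputs so no arrays are materialized (the printed jaxpr is illustrative):

    import jax
    import jax.numpy as jnp

    lhs = jax.ShapeDtypeStruct((5,), jnp.int32)
    rhs = jax.ShapeDtypeStruct((5,), jnp.float32)
    print(jax.make_jaxpr(jnp.dot)(lhs, rhs))
    # Roughly:
    # { lambda ; a:i32[5] b:f32[5]. let
    #     c:f32[] = dot_general[
    #       dimension_numbers=(([0], [0]), ([], []))
    #       preferred_element_type=float32] a b
    #   in (c,) }
    # A trailing convert_element_type can appear when the result's weak type
    # needs adjusting, which is why the test accepts either primitive sequence.
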
