Lower jax.numpy matmul functions to mixed-precision dot_general #16736

Merged: 1 commit, Sep 5, 2023
jax/_src/numpy/lax_numpy.py (17 changes: 10 additions & 7 deletions)

@@ -3098,13 +3098,13 @@ def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None) -> Array:
 @partial(jit, static_argnames=('precision',), inline=True)
 def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None) -> Array:  # pylint: disable=missing-docstring
   util.check_arraylike("matmul", a, b)
+  a, b = asarray(a), asarray(b)
   for i, x in enumerate((a, b)):
     if ndim(x) < 1:
       msg = (f"matmul input operand {i} must have ndim at least 1, "
              f"but it has ndim {ndim(x)}")
       raise ValueError(msg)
-
-  a, b = util.promote_dtypes(a, b)
+  out_dtype, out_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)

   a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
   a_batch_dims: tuple[int | None, ...] = shape(a)[:-2] if a_is_mat else ()
@@ -3153,8 +3153,9 @@ def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None) -> Array:
   b = lax.squeeze(b, tuple(b_squeeze))
   out = lax.dot_general(
     a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
-    precision=precision)
-  return lax.transpose(out, perm)
+    precision=precision, preferred_element_type=out_dtype)
+  result = lax.transpose(out, perm)
+  return lax_internal._convert_element_type(result, out_dtype, out_weak_type)


 @util._wraps(np.vdot, lax_description=_PRECISION_DOC)
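With this change, matmul no longer promotes both operands before the contraction. It computes the promoted result dtype up front, asks lax.dot_general to accumulate in that dtype via preferred_element_type, and converts only the final result. A minimal sketch of the visible behavior (the shapes and dtypes here are illustrative, not taken from the PR):

import jax.numpy as jnp

# Mixed-precision operands: bfloat16 @ float32 promotes to float32.
a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 3), dtype=jnp.float32)

out = jnp.matmul(a, b)
print(out.dtype)  # float32: the promoted dtype, which dot_general is now
                  # asked to accumulate in via preferred_element_type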
@@ -3173,10 +3174,11 @@ def tensordot(a: ArrayLike, b: ArrayLike,
               axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
               *, precision: PrecisionLike = None) -> Array:
   util.check_arraylike("tensordot", a, b)
+  a, b = asarray(a), asarray(b)
   a_ndim = ndim(a)
   b_ndim = ndim(b)

-  a, b = util.promote_dtypes(a, b)
+  out_dtype, out_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)
   if type(axes) is int:
     if axes > min(a_ndim, b_ndim):
       msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
@@ -3201,8 +3203,9 @@ def tensordot(a: ArrayLike, b: ArrayLike,
     msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
            "of lists/tuples of ints.")
     raise TypeError(msg)
-  return lax.dot_general(a, b, (contracting_dims, ((), ())),
-                         precision=precision)
+  result = lax.dot_general(a, b, (contracting_dims, ((), ())),
+                           precision=precision, preferred_element_type=out_dtype)
+  return lax_internal._convert_element_type(result, out_dtype, out_weak_type)


 _EINSUM_DOC = _PRECISION_DOC + """\
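tensordot follows the same pattern: the contraction is lowered to lax.dot_general with preferred_element_type set to the promoted dtype, and the output is converted afterwards. A small sketch of the primitive-level call this targets (illustrative operands, not from the PR):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.bfloat16)
y = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Contract x's last axis with y's first axis, accumulating in float32
# instead of bfloat16.
out = lax.dot_general(x, y,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.float32)
print(out.dtype)  # float32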
jax/experimental/sparse/bcoo.py (3 changes: 2 additions & 1 deletion)

@@ -723,7 +723,8 @@ def result(out_array, lhs_data, lhs_indices, rhs):
     idx_right = (*idx_batch, *idx_right)
     batch_dims = list(range(len(lhs_contracting_b) + bool(lhs_contracting_s)))
     prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
-                           (([], []), (batch_dims, batch_dims)))
+                           (([], []), (batch_dims, batch_dims)),
+                           preferred_element_type=preferred_element_type)
     if idx_out:
       return out_array.at[idx_out].add(prod)
     else:
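This call sits inside the BCOO dot_general implementation, so a preferred_element_type requested by callers such as jnp.matmul is now forwarded to the dense dot_general it performs internally. A hedged sketch of the public wrapper this backs (shapes and values are made up):

import jax.numpy as jnp
from jax.experimental import sparse

lhs = sparse.BCOO.fromdense(jnp.array([[1., 0., 2.],
                                       [0., 3., 0.]], dtype=jnp.float32))
rhs = jnp.ones((3, 4), dtype=jnp.float32)

# Sparse-dense contraction over the shared axis of length 3.
out = sparse.bcoo_dot_general(lhs, rhs,
                              dimension_numbers=(((1,), (0,)), ((), ())))
print(out.shape, out.dtype)  # (2, 4) float32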
tests/sparse_test.py (4 changes: 2 additions & 2 deletions)

@@ -1937,7 +1937,7 @@ def test_bcsr_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs),
                           jnp.array(rng(rhs_shape, rhs_dtype))]

-    tol = {np.float64: 1E-13, np.complex128: 1E-13,
+    tol = {np.float64: 1E-7, np.complex128: 1E-7,
            np.float32: 2E-6, np.complex64: 2E-6}

     with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
@@ -1976,7 +1976,7 @@ def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs),
                                 jnp.array(rng(rhs_shape, rhs_dtype))]

-    tol = {np.float64: 1E-13, np.complex128: 1E-13,
+    tol = {np.float64: 1E-7, np.complex128: 1E-7,
            np.float32: 1E-6, np.complex64: 1E-6}

     with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
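The float64/complex128 tolerances are relaxed here, presumably because a matmul mixing 32-bit and 64-bit operands now runs through the mixed-precision dot_general path instead of being fully promoted first, so only float32-level agreement with the NumPy reference can be expected. A standalone sketch of that kind of comparison (not taken from the test suite, and with a deliberately loose tolerance so the sketch stays robust across backends):

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # keep float64 operands as float64

rng = np.random.default_rng(0)
a32 = rng.standard_normal((8, 8)).astype(np.float32)
b64 = rng.standard_normal((8, 8))  # float64

# Reference computed after promoting everything to float64.
expected = np.matmul(a32.astype(np.float64), b64)
# Mixed-precision path: float32 x float64 inputs, float64 result.
actual = jnp.matmul(a32, b64)
np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-5, atol=1e-5)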