Skip to content

Commit

Permalink
Lower jax.numpy matmul functions to mixed-precision dot_general
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 17, 2023
1 parent 68ea651 commit ac6410d
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,8 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
@util._wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "asarray")
if hasattr(a, "__jax_array__"):
a = a.__jax_array__()
if dtype is not None:
dtype = dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True)
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
Expand Down Expand Up @@ -3087,13 +3089,13 @@ def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
@partial(jit, static_argnames=('precision',), inline=True)
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
util.check_arraylike("matmul", a, b)
a, b = asarray(a), asarray(b)
for i, x in enumerate((a, b)):
if ndim(x) < 1:
msg = (f"matmul input operand {i} must have ndim at least 1, "
f"but it has ndim {ndim(x)}")
raise ValueError(msg)

a, b = util.promote_dtypes(a, b)
out_dtype = dtypes.result_type(a, b)

a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
a_batch_dims = shape(a)[:-2] if a_is_mat else ()
Expand Down Expand Up @@ -3142,7 +3144,7 @@ def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
b = lax.squeeze(b, tuple(b_squeeze))
out = lax.dot_general(
a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
precision=precision)
precision=precision, preferred_element_type=out_dtype)
return lax.transpose(out, perm)


Expand All @@ -3160,10 +3162,11 @@ def tensordot(a: ArrayLike, b: ArrayLike,
axes: Union[int, Sequence[int], Sequence[Sequence[int]]] = 2,
*, precision=None) -> Array:
util.check_arraylike("tensordot", a, b)
a, b = asarray(a), asarray(b)
a_ndim = ndim(a)
b_ndim = ndim(b)

a, b = util.promote_dtypes(a, b)
out_dtype = dtypes.result_type(a, b)
if type(axes) is int:
if axes > min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
Expand All @@ -3189,7 +3192,7 @@ def tensordot(a: ArrayLike, b: ArrayLike,
"of lists/tuples of ints.")
raise TypeError(msg)
return lax.dot_general(a, b, (contracting_dims, ((), ())),
precision=precision)
precision=precision, preferred_element_type=out_dtype)


_EINSUM_DOC = _PRECISION_DOC + """\
Expand Down

0 comments on commit ac6410d

Please sign in to comment.