Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: Improved documentation for jax.numpy.kron and jax.numpy.outer #23443

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7405,9 +7405,33 @@ def inner(
preferred_element_type=preferred_element_type)


@util.implements(np.outer, skip_params=['out'])
@partial(jit, inline=True)
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
"""Compute the outer product of two arrays.

JAX implementation of :func:`numpy.outer`.

Args:
a: first input array, if not 1D it will be flattened.
b: second input array, if not 1D it will be flattened.
out: unsupported by JAX.

Returns:
The outer product of the inputs ``a`` and ``b``. Returned array
will be of shape ``(a.size, b.size)``.

See also:
- :func:`jax.numpy.inner`: compute the inner product of two arrays.
- :func:`jax.numpy.einsum`: Einstein summation.

Examples:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]], dtype=int32)
"""
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
util.check_arraylike("outer", a, b)
Expand Down Expand Up @@ -7443,9 +7467,39 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
return moveaxis(c, 0, axisc)


@util.implements(np.kron)
@jit
def kron(a: ArrayLike, b: ArrayLike) -> Array:
"""Compute the Kronecker product of two input arrays.
jakevdp marked this conversation as resolved.
Show resolved Hide resolved

JAX implementation of :func:`numpy.kron`.

The Kronecker product is an operation on two matrices of arbitrary size that
produces a block matrix. Each element of the first matrix ``a`` is multiplied by
the entire second matrix ``b``. If ``a`` has shape (m, n) and ``b``
has shape (p, q), the resulting matrix will have shape (m * p, n * q).

Args:
a: first input array with any shape.
b: second input array with any shape.

Returns:
A new array representing the Kronecker product of the inputs ``a`` and ``b``.
The shape of the output is the element-wise product of the input shapes.

See also:
- :func:`jax.numpy.outer`: compute the outer product of two arrays.

Examples:
>>> a = jnp.array([[1, 2],
... [3, 4]])
>>> b = jnp.array([[5, 6],
... [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5, 6, 10, 12],
[ 7, 8, 14, 16],
[15, 18, 20, 24],
[21, 24, 28, 32]], dtype=int32)
"""
util.check_arraylike("kron", a, b)
a, b = util.promote_dtypes(a, b)
if ndim(a) < ndim(b):
Expand Down