Skip to content

Commit

Permalink
kron_and_outer_docstring_added
Browse files Browse the repository at this point in the history
description_fixed_and_kron_desc_added
  • Loading branch information
selamw1 committed Sep 5, 2024
1 parent 38184dd commit 571400b
Showing 1 changed file with 57 additions and 2 deletions.
59 changes: 57 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7405,9 +7405,33 @@ def inner(
preferred_element_type=preferred_element_type)


@util.implements(np.outer, skip_params=['out'])
@partial(jit, inline=True)
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
"""Compute the outer product of two arrays.
JAX implementation of :func:`numpy.outer`.
Args:
a: first input array, if not 1D it will be flattened.
b: second input array, if not 1D it will be flattened.
out: unsupported by JAX.
Returns:
The outer product of the inputs ``a`` and ``b``. Returned array
will be of shape (``a.shape``, ``a.shape``).
See also:
- :func:`jax.numpy.inner`: compute the inner product of two arrays.
- :func:`jax.numpy.einsum`: Einstein summation.
Examples:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]], dtype=int32)
"""
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
util.check_arraylike("outer", a, b)
Expand Down Expand Up @@ -7443,9 +7467,40 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
return moveaxis(c, 0, axisc)


@util.implements(np.kron)
@jit
def kron(a: ArrayLike, b: ArrayLike) -> Array:
"""Compute the Kronecker product of two input arrays.
JAX implementation of :func:`numpy.kron`.
The Kronecker product is an operation on two matrices of arbitrary size that
produces a block matrix. Each element of the first matrix ``a`` is multiplied by
the entire second matrix ``b``. If ``a`` has shape (m, n) and ``b`` has shape (p, q),
the resulting matrix will have shape (m * p, n * q). The blocks of the resulting
matrix are formed by multiplying each element of ``a`` by the entire matrix ``b``.
Args:
a: first input array with any shape.
b: second input array with any shape.
Returns:
A new array representing the Kronecker product of the inputs ``a`` and ``b``.
The shape of the output is the element-wise product of the input shapes.
See also:
- :func:`jax.numpy.outer`: compute the outer product of two arrays.
Examples:
>>> a = jnp.array([[1, 2],
... [3, 4]])
>>> b = jnp.array([[5, 6],
... [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5, 6, 10, 12],
[ 7, 8, 14, 16],
[15, 18, 20, 24],
[21, 24, 28, 32]], dtype=int32)
"""
util.check_arraylike("kron", a, b)
a, b = util.promote_dtypes(a, b)
if ndim(a) < ndim(b):
Expand Down

0 comments on commit 571400b

Please sign in to comment.