Skip to content

Commit

Permalink
Better docs for jnp.meshgrid
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 20, 2024
1 parent 1acf956 commit 28b9212
Showing 1 changed file with 59 additions and 6 deletions.
65 changes: 59 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,11 +819,6 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
return hist, bin_edges_by_dim


_ARRAY_VIEW_DOC = """
The JAX version of this function may in some cases return a copy rather than a
view of the input.
"""

def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
"""Return a transposed version of an N-dimensional array.
Expand Down Expand Up @@ -5951,9 +5946,67 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
return lax.convert_element_type(res, dtype)


@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> list[Array]:
"""Construct N-dimensional grid arrays from N 1-dimensional vectors.
JAX implementation of :func:`numpy.meshgrid`.
Args:
xi: N arrays to convert to a grid.
copy: whether to copy the input arrays. JAX supports only ``copy=True``,
though under JIT compilation the compiler may opt to avoid copies.
sparse: if False (default), then each returned arrays will be of shape
``[len(x1), len(x2), ..., len(xN)]``. If False, then returned arrays
will be of shape ``[1, 1, ..., len(xi), ..., 1, 1]``.
indexing: options are ``'xy'`` for cartesian indexing (default) or ``'ij'``
for matrix indexing.
Returns:
A length-N list of grid arrays.
See also:
- :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax.
- :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax.
Examples:
For the following examples, we'll use these 1D arrays as inputs:
>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])
2D cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
[1 2]
[1 2]]
>>> print(y_grid)
[[10 10]
[20 20]
[30 30]]
2D sparse cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
[20]
[30]]
2D matrix-index mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
[2 2 2]]
>>> print(y_grid)
[[10 20 30]
[10 20 30]]
"""
util.check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
Expand Down

0 comments on commit 28b9212

Please sign in to comment.