From 28b9212b4c75b18ef91a6a02e4a4257a9a10dc76 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Sep 2024 08:25:54 -0700 Subject: [PATCH] Better docs for jnp.meshgrid --- jax/_src/numpy/lax_numpy.py | 65 +++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b71412e586bd..a8a8195de521 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -819,11 +819,6 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim -_ARRAY_VIEW_DOC = """ -The JAX version of this function may in some cases return a copy rather than a -view of the input. -""" - def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -5951,9 +5946,67 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) -@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: + """Construct N-dimensional grid arrays from N 1-dimensional vectors. + + JAX implementation of :func:`numpy.meshgrid`. + + Args: + xi: N arrays to convert to a grid. + copy: whether to copy the input arrays. JAX supports only ``copy=True``, + though under JIT compilation the compiler may opt to avoid copies. + sparse: if False (default), then each returned arrays will be of shape + ``[len(x1), len(x2), ..., len(xN)]``. If False, then returned arrays + will be of shape ``[1, 1, ..., len(xi), ..., 1, 1]``. + indexing: options are ``'xy'`` for cartesian indexing (default) or ``'ij'`` + for matrix indexing. + + Returns: + A length-N list of grid arrays. + + See also: + - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. + - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. + + Examples: + For the following examples, we'll use these 1D arrays as inputs: + + >>> x = jnp.array([1, 2]) + >>> y = jnp.array([10, 20, 30]) + + 2D cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y) + >>> print(x_grid) + [[1 2] + [1 2] + [1 2]] + >>> print(y_grid) + [[10 10] + [20 20] + [30 30]] + + 2D sparse cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True) + >>> print(x_grid) + [[1 2]] + >>> print(y_grid) + [[10] + [20] + [30]] + + 2D matrix-index mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij') + >>> print(x_grid) + [[1 1 1] + [2 2 2]] + >>> print(y_grid) + [[10 20 30] + [10 20 30]] + """ util.check_arraylike("meshgrid", *xi) args = [asarray(x) for x in xi] if not copy: