Skip to content

Commit

Permalink
Merge pull request jax-ml#23737 from jakevdp:digitize-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676111220
  • Loading branch information
Google-ML-Automation committed Sep 18, 2024
2 parents dbc03cf + 57a4b76 commit b51c653
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
48 changes: 42 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10808,11 +10808,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
}[method]
return impl(asarray(a), asarray(v), side, dtype) # type: ignore

@util.implements(np.digitize, lax_description=_dedent("""
Optionally, the ``method`` argument can be used to configure the
underlying :func:`jax.numpy.searchsorted` algorithm."""))

@partial(jit, static_argnames=('right', 'method'))
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str = 'scan') -> Array:
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
*, method: str | None = None) -> Array:
"""Convert an array to bin indices.
JAX implementation of :func:`numpy.digitize`.
Args:
x: array of values to digitize.
bins: 1D array of bin edges. Must be monotonically increasing or decreasing.
right: if true, the intervals include the right bin edges. If false (default)
the intervals include the left bin edges.
method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`.
See that function for available options.
Returns:
An integer array of the same shape as ``x`` indicating the bin number that
the values are in.
See also:
- :func:`jax.numpy.searchsorted`: find insertion indices for values in a
sorted array.
- :func:`jax.numpy.histogram`: compute frequency of array values within
specified bins.
Examples:
>>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
>>> bins = jnp.array([1, 2, 3])
>>> jnp.digitize(x, bins)
Array([1, 2, 2, 1, 3, 3], dtype=int32)
>>> jnp.digitize(x, bins, right=True)
Array([0, 1, 2, 1, 2, 3], dtype=int32)
``digitize`` supports reverse-ordered bins as well:
>>> bins = jnp.array([3, 2, 1])
>>> jnp.digitize(x, bins)
Array([2, 1, 1, 2, 0, 0], dtype=int32)
"""
util.check_arraylike("digitize", x, bins)
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
bins_arr = asarray(bins)
Expand All @@ -10821,10 +10856,11 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str
if bins_arr.shape[0] == 0:
return zeros_like(x, dtype=int32)
side = 'right' if not right else 'left'
kwds: dict[str, str] = {} if method is None else {'method': method}
return where(
bins_arr[-1] >= bins_arr[0],
searchsorted(bins_arr, x, side=side, method=method),
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, method=method)
searchsorted(bins_arr, x, side=side, **kwds),
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
)


Expand Down
3 changes: 2 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def diagonal(
def diff(a: ArrayLike, n: int = ..., axis: int = ...,
prepend: ArrayLike | None = ...,
append: ArrayLike | None = ...) -> Array: ...
def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ...
def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *,
method: str | None = ...) -> Array: ...
divide = true_divide
def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ...
def dot(
Expand Down

0 comments on commit b51c653

Please sign in to comment.