Skip to content

Commit

Permalink
Add numpy.put_along_axis.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Nov 13, 2024
1 parent c4a0369 commit 66664d7
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ namespace; they are listed below.
promote_types
ptp
put
put_along_axis
quantile
r_
rad2deg
Expand Down
97 changes: 96 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
Expand Down Expand Up @@ -11433,6 +11433,101 @@ def replace(tup, val):
mode="fill" if mode is None else mode, fill_value=fill_value)


_indices = indices # argument below named 'indices' shadows the function


def _make_along_axis_idx(shape, indices, axis):
return tuple_replace(_indices(shape, sparse=True), axis, indices)


@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
def put_along_axis(
arr: ArrayLike,
indices: ArrayLike,
values: ArrayLike,
axis: int | None,
inplace: bool = True,
*,
mode: str | None = None,
) -> Array:
"""Put values into the destination array by matching 1d index and data slices.
JAX implementation of :func:`numpy.put_along_axis`.
The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
arr: array into which values will be put.
indices: array of indices at which to put values.
values: array of values to put into the array.
axis: the axis along which to put values. If not specified, the array will
be flattened before indexing is applied.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
see :attr:`jax.numpy.ndarray.at`.
Returns:
A copy of ``a`` with specified entries updated.
See Also:
- :func:`jax.numpy.put`: put elements into an array at given indices.
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
- :func:`jax.numpy.take`: extract values from an array at given indices.
- :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
Examples:
>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> b = jnp.put_along_axis(a, i, 99, axis=1)
[[10 99 20]
[99 40 50]]
"""
if inplace:
raise ValueError(
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
"are immutable. Pass inplace=False to instead return an updated array.")

util.check_arraylike("put_along_axis", arr, indices, values)
arr = asarray(arr)
indices = asarray(indices)
values = asarray(values)

original_axis = axis
original_arr_shape = arr.shape

if axis is None:
arr = arr.ravel()
axis = 0

if not arr.ndim == indices.ndim:
raise ValueError(
"put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
f"{arr.ndim=} and {indices.ndim=}."
)

try:
values = broadcast_to(values, indices.shape)
except ValueError:
raise ValueError(
"put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
f"{values.shape=} and {indices.shape=}."
)

idx = _make_along_axis_idx(arr.shape, indices, axis)
result = arr.at[idx].set(values, mode=mode)

if original_axis is None:
result = result.reshape(original_arr_shape)

return result


### Indexing

def _is_integer_index(idx: Any) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Sequence
from functools import partial
from typing import Any
import itertools

import warnings

Expand Down Expand Up @@ -265,3 +266,9 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr

def broadcastable_shapes(shape):
# yields all shapes that can be broadcasted to the given shape
# https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules
for i in range(len(shape) + 1):
yield from itertools.product(*({dim, 1} for dim in shape[i:]))
4 changes: 4 additions & 0 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ def tuple_update(t, idx, val):
assert 0 <= idx < len(t), (idx, len(t))
return t[:idx] + (val,) + t[idx+1:]

def tuple_replace(tupl, index, item):
# unlike tuple_update, works with negative indices as well
return tupl[:index] + (item,) + tupl[index:][1:]

class HashableFunction:
"""Decouples function equality and hash from its identity.
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
printoptions as printoptions,
promote_types as promote_types,
put as put,
put_along_axis as put_along_axis,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
repeat as repeat,
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ...) -> Array: ...
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ...
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
Expand Down
53 changes: 51 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.util import safe_zip, NumpyComplexWarning
from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
from jax._src.numpy.util import broadcastable_shapes

config.parse_flags_with_absl()

Expand Down Expand Up @@ -160,6 +161,15 @@ def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])


def _broadcastable_to(shape_1, shape_2):
try:
assert jnp.broadcast_shapes(shape_1, shape_2) == shape_2
except (ValueError, AssertionError):
return False
else:
return True


def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]:
"""Generate multiple arrays with some overlapping values.
Expand Down Expand Up @@ -5962,6 +5972,46 @@ def np_fun(a, i, v):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis)
for a_shape in array_shapes
for axis in list(range(-len(a_shape), len(a_shape)))
if a_shape[axis] > 0
for i_shape in [tuple_replace(a_shape, axis, J) for J in range(5)]
for v_shape in broadcastable_shapes(i_shape)
] + [
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None)
for a_shape in nonempty_array_shapes
for i_shape in [(J,) for J in range(5)]
for v_shape in broadcastable_shapes(i_shape)
],
dtype=jtu.dtypes.all,
mode=[None, "promise_in_bounds", "clip"],
)
def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode):
a_rng = jtu.rand_default(self.rng())
if axis is None:
size = math.prod(a_shape)
else:
size = a_shape[axis]
i_rng = jtu.rand_int(self.rng(), -size, size)

def args_maker():
a = a_rng(a_shape, dtype)
i = i_rng(i_shape, np.int32)
v = a_rng(v_shape, dtype)
return a, i, v

def np_fun(a, i, v):
a_copy = a.copy()
np.put_along_axis(a_copy, i, v, axis=axis)
return a_copy

jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def test_rot90_error(self):
with self.assertRaisesRegex(
ValueError,
Expand Down Expand Up @@ -6229,7 +6279,6 @@ def testWrappedSignaturesMatch(self):
'nditer',
'nested_iters',
'poly1d',
'put_along_axis',
'putmask',
'real_if_close',
'recarray',
Expand Down

0 comments on commit 66664d7

Please sign in to comment.