Skip to content

Commit

Permalink
Add numpy.put_along_axis.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Nov 12, 2024
1 parent c4a0369 commit e419da0
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ namespace; they are listed below.
promote_types
ptp
put
put_along_axis
quantile
r_
rad2deg
Expand Down
70 changes: 70 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11433,6 +11433,76 @@ def replace(tup, val):
mode="fill" if mode is None else mode, fill_value=fill_value)


@partial(jit, static_argnames=('axis', 'inplace'))
def put_along_axis(
arr: ArrayLike,
indices: ArrayLike,
values: ArrayLike,
axis: int | None,
inplace: bool = True,
) -> Array:
"""Put values into the destination array by matching 1d index and data slices.
JAX implementation of :func:`numpy.put_along_axis`.
The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
is not possible for JAX's immutable arrays. The JAX version returns a modified
copy of the input, and adds the ``inplace`` parameter which must be set to
`False`` by the user as a reminder of this API difference.
Args:
arr: array into which values will be put.
indices: array of indices at which to put values.
values: array of values to put into the array.
axis: the axis along which to put values. If not specified, the array will
be flattened before indexing is applied.
inplace: must be set to False to indicate that the input is not modified
in-place, but rather a modified copy is returned.
Returns:
A copy of ``a`` with specified entries updated.
See Also:
- :func:`jax.numpy.put`: put elements into an array at given indices.
- :func:`jax.numpy.place`: place elements into an array via boolean mask.
- :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
- :func:`jax.numpy.take`: extract values from an array at given indices.
- :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
Examples:
>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> b = jnp.put_along_axis(a, i, 99, axis=1)
[[10 99 20]
[99 40 50]]
"""
if inplace:
raise ValueError(
"jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays"
"are immutable. Pass inplace=False to instead return an updated array.")

util.check_arraylike("put_along_axis", arr, indices, values)
arr = asarray(arr)
indices = asarray(indices)
values = asarray(values)

axis = _canonicalize_axis(axis, arr.ndim)

if axis is None:
arr = arr.ravel()
axis = 0

idx = tuple(
indices
if i == axis else
arange(dim).reshape((-1,) + (1,) * (arr.ndim - 1 - i))
for i, dim in enumerate(arr.shape)
)

return arr.at[idx].set(values)


### Indexing

def _is_integer_index(idx: Any) -> bool:
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
printoptions as printoptions,
promote_types as promote_types,
put as put,
put_along_axis as put_along_axis,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
repeat as repeat,
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ...) -> Array: ...
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
axis: int | None, inplace: bool = True) -> Array: ...
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
Expand Down
34 changes: 33 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5962,6 +5962,39 @@ def np_fun(a, i, v):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=array_shapes,
dtype=jtu.dtypes.all,
M=range(5),
J=range(5),
axis=range(5),
)
def testPutAlongAxis(self, shape, dtype, M, J, axis):
if axis > len(shape):
self.skipTest("invalid axis")

a_shape = shape[:axis] + (M,) + shape[axis:]
i_shape = shape[:axis] + (J,) + shape[axis:]
v_shape = shape[:axis] + (J,) + shape[axis:]

a_rng = jtu.rand_default(self.rng())
i_rng = jtu.rand_int(self.rng(), high=M)

def args_maker():
a = a_rng(a_shape, dtype)
i = i_rng(i_shape, np.int32)
v = a_rng(v_shape, dtype)
return a, i, v

def np_fun(a, i, v):
a_copy = a.copy()
np.put_along_axis(a_copy, i, v, axis=axis)
return a_copy

jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def test_rot90_error(self):
with self.assertRaisesRegex(
ValueError,
Expand Down Expand Up @@ -6229,7 +6262,6 @@ def testWrappedSignaturesMatch(self):
'nditer',
'nested_iters',
'poly1d',
'put_along_axis',
'putmask',
'real_if_close',
'recarray',
Expand Down

0 comments on commit e419da0

Please sign in to comment.