
jax.numpy: put_along_axis #9954

Open
LaraFuhrmann opened this issue Mar 18, 2022 · 11 comments · May be fixed by #24871
Labels
contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. enhancement New feature or request

Comments

@LaraFuhrmann

Hi :),
I am using jax.numpy to define a NumPyro model. For this, it would be very useful to have a function that does essentially the same thing as numpy's put_along_axis. For now, I have modified numpy's code slightly so that it outputs a jax array. But I was wondering whether there are plans to add this feature to jax.numpy in the future?
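For reference, here is a small NumPy example of the behaviour being requested. The array values are made up for illustration. Note that `np.put_along_axis` writes into `arr` in place and returns `None`, which is the detail that matters for JAX.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])
# column index of the largest entry in each row, kept 2-D: shape (2, 1)
idx = np.argmax(a, axis=1)[:, None]

# overwrite those entries with 0; NumPy mutates `a` and returns None
result = np.put_along_axis(a, idx, 0, axis=1)
# result is None; a is now [[10, 0, 20], [0, 40, 50]]
```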

@LaraFuhrmann LaraFuhrmann added the enhancement New feature or request label Mar 18, 2022
@jakevdp
Collaborator

jakevdp commented Mar 18, 2022

We haven't implemented this because the semantics of np.put_along_axis are to modify the array in-place, and this is not possible in JAX because JAX arrays are immutable. I suspect you could accomplish what you want to do using some combination of index update operators; see https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at
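A minimal illustration of the immutability point and of the index update operators. The arrays are arbitrary examples.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# In-place item assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 0
except TypeError as err:
    print("TypeError:", err)

# The index update operators return an updated copy instead.
y = x.at[jnp.array([0, 2])].set(0)
# x is still [1, 2, 3, 4]; y is [0, 2, 0, 4]
```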

@LaraFuhrmann
Author

I see - thanks a lot for the explanation :) I was able to implement a custom version of np.put_along_axis using the index update operators as you suggested. Thanks again!

@mattjj mattjj closed this as completed Mar 21, 2022
@LaraFuhrmann
Author

As there is some interest in the custom version I am using for my problem, I am sharing the code here:
(I'm happy to hear any suggestions on how to solve this better.)

"""
Defining put_along_axis() from numpy for jax.numpy.
Essentially copied the code from
https://github.com/numpy/numpy/blob/4adc87dff15a247e417d50f10cc4def8e1c17a03/numpy/lib/shape_base.py#L29

"""

import numpy.core.numeric as _nx
import jax

def _make_along_axis_idx(arr_shape, indices, axis):
    # compute dimensions to iterate over
    if not _nx.issubdtype(indices.dtype, _nx.integer):
        raise IndexError("`indices` must be an integer array")
    if len(arr_shape) != indices.ndim:
        raise ValueError("`indices` and `arr` must have the same number of dimensions")
    shape_ones = (1,) * indices.ndim
    dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))

    # build a fancy index, consisting of orthogonal aranges, with the
    # requested index inserted at the right location
    fancy_index = []
    for dim, n in zip(dest_dims, arr_shape):
        if dim is None:
            fancy_index.append(indices)
        else:
            ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
            fancy_index.append(_nx.arange(n).reshape(ind_shape))

    return tuple(fancy_index)

def custom_put_along_axis(arr, indices, values, axis):
    """
    Parameters
    ----------
    arr : ndarray (Ni..., M, Nk...)
        Destination array.
    indices : ndarray (Ni..., J, Nk...)
        Indices to change along each 1d slice of `arr`. This must match the
        dimension of arr, but dimensions in Ni and Nj may be 1 to broadcast
        against `arr`.
    values : array_like (Ni..., J, Nk...)
        values to insert at those indices. Its shape and dimension are
        broadcast to match that of `indices`.
    axis : int
        The axis to take 1d slices along. If axis is None, the destination
        array is treated as if a flattened 1d view had been created of it.

    """

    # normalize inputs
    if axis is None:
        arr = arr.flat
        axis = 0
        arr_shape = (len(arr),)  # flatiter has no .shape
    else:
        # axis = normalize_axis_index(axis, arr.ndim)
        arr_shape = arr.shape

    # use the fancy index
    arr = arr.at[tuple(_make_along_axis_idx(arr_shape, indices, axis))].set(values)
    return arr
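One possible alternative for the common 2-D case, sketched here under the assumption that `arr` is 2-D and the update runs along axis 1: write the 1-D update for a single row and let `jax.vmap` handle the row dimension, instead of building the fancy index by hand. `put_along_axis_2d` is a hypothetical name. As far as I know, JAX does not guarantee which update wins when `indices` repeats a position within a row, so results there may differ from NumPy.

```python
import jax
import jax.numpy as jnp


def put_along_axis_2d(arr, indices, values):
    """Hypothetical helper: functional put_along_axis for 2-D `arr`, axis=1.

    arr: (N, M); indices: (N, J) integer; values: broadcastable to (N, J).
    Returns a new array; `arr` is not modified.
    """
    values = jnp.broadcast_to(values, indices.shape).astype(arr.dtype)

    def put_row(row, row_idx, row_vals):
        # 1-D update of a single row: row[row_idx] = row_vals, as a copy
        return row.at[row_idx].set(row_vals)

    # vmap maps put_row over the leading (row) axis of all three arguments
    return jax.vmap(put_row)(arr, indices, values)


a = jnp.array([[10.0, 30.0, 20.0],
               [60.0, 40.0, 50.0]])
idx = jnp.argmax(a, axis=1)[:, None]   # shape (2, 1)
b = put_along_axis_2d(a, idx, 0.0)     # `a` is left unchanged
```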

@jakevdp
Collaborator

jakevdp commented Apr 20, 2022

One idea may be to define jax.numpy.put_along_axis, but add an extra inplace keyword that defaults to True, such that the function errors with the default value. Users could set inplace=False and then have a version of the function that returns the updated array. It would avoid the potential pitfall of users assuming the function works in-place and then being confused why the array isn't changing. What do you think?
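The `inplace` proposal could look roughly like the sketch below. The function name mirrors NumPy, but the signature, the exception type and the message are assumptions for illustration, not an existing JAX API. The sketch also assumes `indices` has the same number of dimensions as `arr` and that `axis` is a valid integer.

```python
import jax.numpy as jnp


def put_along_axis(arr, indices, values, axis, inplace=True):
    """Hypothetical sketch of the proposed API: errors unless inplace=False."""
    if inplace:
        raise ValueError(
            "JAX arrays are immutable, so put_along_axis cannot modify `arr` "
            "in place; pass inplace=False to get an updated copy instead."
        )
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    axis = axis % arr.ndim  # accept negative axes
    # one broadcastable arange per non-`axis` dimension, `indices` on `axis`
    idx = tuple(
        indices if d == axis
        else jnp.arange(n).reshape((1,) * d + (-1,) + (1,) * (arr.ndim - d - 1))
        for d, n in enumerate(arr.shape)
    )
    return arr.at[idx].set(values)


a = jnp.array([[10, 30, 20], [60, 40, 50]])
idx = jnp.array([[1], [0]])
b = put_along_axis(a, idx, 0, axis=1, inplace=False)  # returns a copy
```

With the default `inplace=True`, the call fails loudly instead of silently returning a new array that the caller might ignore.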

I suspect we could refactor things to share much of the index processing with take_along_axis.
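To illustrate why the index processing could be shared: the same fancy index serves as a gather (`arr[idx]`, matching `jnp.take_along_axis`) and as a scatter (`arr.at[idx].set(...)`, the put counterpart). `along_axis_index` is a hypothetical helper, and JAX's actual `take_along_axis` implementation may process indices differently.

```python
import jax.numpy as jnp


def along_axis_index(shape, indices, axis):
    """Hypothetical shared helper: the fancy index used by both take and put."""
    ndim = len(shape)
    return tuple(
        indices if d == axis
        else jnp.arange(n).reshape((1,) * d + (-1,) + (1,) * (ndim - d - 1))
        for d, n in enumerate(shape)
    )


a = jnp.arange(12).reshape(3, 4)
indices = jnp.array([[3, 0], [1, 2], [2, 0]])
idx = along_axis_index(a.shape, indices, axis=1)

taken = a[idx]           # gather: same result as jnp.take_along_axis
put = a.at[idx].set(-1)  # scatter: the put_along_axis counterpart
```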

@jakevdp jakevdp reopened this Apr 20, 2022
@jakevdp jakevdp added the contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. label Apr 20, 2022
@riven314

Hi @jakevdp
I would like to contribute to this ticket.

I would break this ticket down as follows; let me know if it makes sense:

  1. refactor some helper functions from take_along_axis that can be shared with put_along_axis
  2. implement put_along_axis in JAX
  3. default its inplace argument to True and raise an error when asked to modify the array in place

@jdeschena

Any update on this?

@jakevdp
Collaborator

jakevdp commented Jul 20, 2023

No, but note that you should be able to use jnp.ndarray.at[] to do anything that you might do with put_along_axis.
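For example, the 2-D `axis=1` case and the flattened `axis=None` case can each be written directly with `.at[]`. The arrays and indices here are illustrative, and the NumPy calls in the comments are the in-place equivalents.

```python
import jax.numpy as jnp

a = jnp.array([[10, 30, 20],
               [60, 40, 50]])

# Like np.put_along_axis(a, idx, 0, axis=1), but returning a copy:
idx = jnp.array([[1], [0]])               # one column index per row
rows = jnp.arange(a.shape[0])[:, None]    # shape (2, 1), broadcasts with idx
b = a.at[rows, idx].set(0)

# Like np.put_along_axis(a, flat_idx, 0, axis=None), but returning a copy:
flat_idx = jnp.array([0, 5])
c = a.ravel().at[flat_idx].set(0).reshape(a.shape)
```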

@jdeschena

Assuming I have an implementation that is functionally equivalent to the numpy version, except that it returns a new array instead of working in place, would you be interested in a PR? I know it is doable with `at` alone, but sometimes it is convenient to have such a function directly 🥲

@jakevdp
Collaborator

jakevdp commented Jul 20, 2023

Sure, I'd review a put_along_axis PR

@DCtheTall

I am interested in taking a crack at this if it is still something you'd want. The work for this PR already seems well-scoped, and given that it's been almost 2 years, I figure no one is working on it.

Btw, I am a Googler, so the CLA is already signed; I am just not on a JAX team, so I'm doing this as volunteer work :)

@jakevdp
Collaborator

jakevdp commented Apr 23, 2024

Sounds good - feel free to take a look!

@carlosgmartin carlosgmartin linked a pull request Nov 12, 2024 that will close this issue

6 participants