Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numpy.put_along_axis. #24871

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

carlosgmartin
Copy link
Contributor

Fixes #9954.

@carlosgmartin carlosgmartin force-pushed the numpy_put_along_axis branch 3 times, most recently from 692feb7 to d6c8274 Compare November 12, 2024 23:20
@jakevdp
Copy link
Collaborator

jakevdp commented Nov 12, 2024

Thanks for looking into this! It's a good start, but we'll need to handle (and test!) all the broadcasting semantics of put_along_axis – i.e. both indices and values may be N-dimensional arrays that broadcast together along all axes except axis

@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Nov 12, 2024

@jakevdp Isn't that already handled? For example, testPutAlongAxis tests the case where M and J are different.

By the way, this seems to be equivalent to

def put_along_axis(arr, indices, values, axis):
  idx = jnp.indices(arr.shape, sparse=True)
  idx = idx[:axis] + (indices,) + idx[axis:][1:]  # replace the indices for axis
  return arr.at[idx].set(values)

tests/lax_numpy_test.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Show resolved Hide resolved
@carlosgmartin carlosgmartin force-pushed the numpy_put_along_axis branch 4 times, most recently from 747dd53 to 58e5b28 Compare November 13, 2024 02:53
@carlosgmartin
Copy link
Contributor Author

@jakevdp I think it should be fixed now.

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
@carlosgmartin
Copy link
Contributor Author

@jakevdp Updated.

tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
@carlosgmartin
Copy link
Contributor Author

@jakevdp Updated.

@carlosgmartin carlosgmartin force-pushed the numpy_put_along_axis branch 2 times, most recently from 66664d7 to f47dbd9 Compare November 13, 2024 23:41
tests/lax_numpy_test.py Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
@carlosgmartin carlosgmartin force-pushed the numpy_put_along_axis branch 3 times, most recently from 823ff0e to 47daf09 Compare November 13, 2024 23:57
@carlosgmartin
Copy link
Contributor Author

@jakevdp I reduced the number of test cases further. Let me know if you want to keep the utility function jax._src.numpy.util.broadcastable_shapes I created. (Perhaps it might be useful in the future.)

@carlosgmartin carlosgmartin force-pushed the numpy_put_along_axis branch 3 times, most recently from cdb1901 to 0a69dfa Compare November 14, 2024 00:05
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 14, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Nov 14, 2024

It looks like the example output has a stray character

jax/_src/numpy/util.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Nov 14, 2024

We're seeing test failures on H100. I think it's probably due to the fact that repeated indices within scatter may be updated in any order, and so the updated values may not match the updated values in NumPy, which always traverses the indices in order.

Probably the best way to fix this would be to ensure within tests that generated indices are unique within each slice, though I don't know of an easy way to do that in the general case using existing test utilities.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

jax.numpy: put_along_axis
3 participants