-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numpy.put_along_axis. #24871
base: main
Are you sure you want to change the base?
Add numpy.put_along_axis. #24871
Conversation
692feb7
to
d6c8274
Compare
Thanks for looking into this! It's a good start, but we'll need to handle (and test!) all the broadcasting semantics of |
d6c8274
to
e419da0
Compare
@jakevdp Isn't that already handled? For example, By the way, this seems to be equivalent to def put_along_axis(arr, indices, values, axis):
idx = jnp.indices(arr.shape, sparse=True)
idx = idx[:axis] + (indices,) + idx[axis:][1:] # replace the indices for axis
return arr.at[idx].set(values) |
747dd53
to
58e5b28
Compare
@jakevdp I think it should be fixed now. |
58e5b28
to
a62a521
Compare
@jakevdp Updated. |
a62a521
to
c2c4285
Compare
@jakevdp Updated. |
66664d7
to
f47dbd9
Compare
823ff0e
to
47daf09
Compare
@jakevdp I reduced the number of test cases further. Let me know if you want to keep the utility function |
cdb1901
to
0a69dfa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
It looks like the example output has a stray character |
0a69dfa
to
4bab61d
Compare
4bab61d
to
15857c4
Compare
We're seeing test failures on H100. I think it's probably due to the fact that repeated indices within Probably the best way to fix this would be to ensure within tests that generated indices are unique within each slice, though I don't know of an easy way to do that in the general case using existing test utilities. |
Fixes #9954.