jax.numpy: put_along_axis #9954
Comments
We haven't implemented this because the semantics of |
I see - thanks a lot for the explanation :) I was able to implement a custom version of |
As there is some interest in this custom version that I am using for my problem, I am sharing the code here:
|
One idea may be to define I suspect we could refactor things to share much of the index processing with |
Hi @jakevdp, I will break this ticket down as follows; let me know if it makes sense:
|
Any update regarding this? |
No, but note that you should be able to use |
Assuming I have an implementation functionally equivalent to the NumPy version, not in place but returning a new array, would you be interested in me making a PR? I know it is doable with only |
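For readers following along, here is a minimal sketch of what such an out-of-place version could look like. It is not the code shared earlier in this thread and not an official JAX API: the name `put_along_axis_sketch` is hypothetical, and the approach (one index array per dimension, then `.at[...].set(...)`) is just one assumed way to do it. Since JAX arrays are immutable, it returns a new array instead of modifying `arr`.

```python
import numpy as np
import jax.numpy as jnp


def put_along_axis_sketch(arr, indices, values, axis):
    """Hypothetical out-of-place analogue of numpy.put_along_axis."""
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    axis = axis % arr.ndim  # normalize a negative axis

    # One index array per dimension: `indices` along `axis`, and a
    # broadcastable arange along every other dimension.
    idx = []
    for dim in range(arr.ndim):
        if dim == axis:
            idx.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = arr.shape[dim]
            idx.append(jnp.arange(arr.shape[dim]).reshape(shape))

    # Functional update: returns a new array, `arr` is left untouched.
    return arr.at[tuple(idx)].set(values)


# Usage: compare against NumPy's in-place version on a small example.
a = np.arange(12, dtype=np.float32).reshape(3, 4)
ind = np.array([[0], [2], [1]])

expected = a.copy()
np.put_along_axis(expected, ind, -1.0, axis=1)

out = put_along_axis_sketch(a, ind, -1.0, axis=1)
```

Like the NumPy function, this assumes `indices` has the same number of dimensions as `arr`. Behavior with duplicate indices is not addressed here.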
Sure, I'd review a |
I am interested in taking a crack at this if this is something you'd still want. The work for this PR seems well-scoped already, and given that it's been almost 2 years, I figure no one is working on it. Btw, I am a Googler, so the CLA is already signed; I am just not on a JAX team, so I am doing this as volunteer work :) |
Sounds good - feel free to take a look! |
Hi :),
I am using jax.numpy to define a NumPyro model. For this, it would be very useful to have a function that does essentially the same as NumPy's put_along_axis. For now, I have modified NumPy's code slightly to output a JAX object, but I was wondering whether there are plans to add this feature to jax.numpy in the future?