-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Jax custom call #12396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax custom call #12396
Conversation
jaxlib/jax_custom_call.h
Outdated
| JaxCustomCallStatus* status; | ||
| }; | ||
|
|
||
| void JaxCustomCallStatusSetSuccess(JaxCustomCallStatus* status); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, linkage may be hard here. My suggestion to Sharad was that we pass function pointers for all API entry points a custom call might need (i.e., like a c++ class implemented in C).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to use runtime linkage instead of link-time linkage a la Numpy import_array() style.
a78db4a to
a379de5
Compare
bbeca61 to
01e660b
Compare
Co-authored-by: Kuangyuan Chen <[email protected]> Co-authored-by: Qiao Zhang<[email protected]>
Co-authored-by: Kuangyuan Chen <[email protected]> Co-authored-by: Qiao Zhang <[email protected]>
40c90e5 to
3a48ea8
Compare
a18f38c to
17fb10c
Compare
17fb10c to
79febdc
Compare
|
We are importing this to explore some things. @hawkinsp hold off on reviews. Thanks! |
|
Hi! This looks really useful. Is anyone still committed to making this happen? |
|
@jlu-spins yes, we are making an improved version with better C++ typed argument support. But it creates a new dependency on some XLA component that's still WIP. We are actively working on this, and will hopefully have a version in a few weeks. In the meanwhile, https://github.com/dfm/extending-jax is still a good guide for doing JAX custom call at head. |
|
@zhangqiaorjc Thank you! |
|
How create a custom class instance with several call method and its own stateful resource. Something like Extending TorchScript with Custom C++ Classes |
No description provided.