jax cpp interface #1871
Comments
What would you like to do exactly? It's currently possible to save a |
I was thinking of building and evaluating networks on a GPU and computing
gradients quickly.
…On Mon, Dec 16, 2019 at 1:37 PM Skye Wanderman-Milne < ***@***.***> wrote:
What would you like to do exactly? It's currently possible to save a jit'd
version of a function as an XLA program, which can then be invoked from C++
via XLA's C++ client
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/local_client.h>.
Check out jax/tools/jax_to_hlo.py
<https://github.com/google/jax/blob/master/jax/tools/jax_to_hlo.py> (HLO
is how XLA programs are represented).
--
Marin Bukov, PhD
Gordon and Betty Moore postdoctoral fellow,
Dept. of Physics
University of California, Berkeley
366 LeConte Hall MC 7300,
Berkeley, CA 94720-7300
url: https://mgbukov.github.io/
|
With the above tooling I mentioned, you could build the network + gradient function in Python, save the resulting XLA HLO, then quickly evaluate both via the C++ interface. (I'm gonna close this issue because I don't think there's any concrete work to do here, but feel free to continue commenting.) |
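As a rough illustration of the Python half of that workflow, here is a minimal sketch. It assumes a recent JAX release where jitted functions expose `.lower(...)` and the lowered object exposes `.as_text()`; this is not the `jax_to_hlo.py` tool itself. The tiny `loss` function, the parameter shapes, and the variable names are all made up for the example.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    """Toy one-layer network with a mean-squared-error loss (hypothetical)."""
    w, b = params
    pred = jnp.tanh(x @ w + b)
    return jnp.mean((pred - y) ** 2)


# One jitted function that returns both the loss value and its gradient
# with respect to the first argument (params).
loss_and_grad = jax.jit(jax.value_and_grad(loss))

# Example inputs; the shapes are fixed at lowering time.
params = (jnp.ones((3, 2)), jnp.zeros((2,)))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# Lower the function for these input shapes and get a textual IR dump.
# The exact dialect of the text depends on the JAX version.
lowered = loss_and_grad.lower(params, x, y)
ir_text = lowered.as_text()

# Evaluate in Python for reference.
value, grads = loss_and_grad(params, x, y)
```

The `ir_text` string could be written to a file and handed to a C++ loader. How it is parsed on the C++ side depends on the XLA/PJRT version in use and is not shown here.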
Hi @skye! |
@dawgster I think the key thing to notice is that the |
@hawkinsp Yes, that helped quite a bit. Thanks for your help! |
#5337 shows an explicit example. It can easily be adapted for a GPU client. |
Do you think it's a good alternative to call JAX python code in C++? |
Something like #include <Python.h>? I wonder if there is a more elegant way to invoke JAX from C++. I need to write a plugin for OpenMM, and the reason I use JAX is that it is smaller and more flexible than TF. |
@skye the links you posted don't exist anymore, is there a new way of using jax via C++? |
Here's where the C++ PJRT client is these days: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h We also have a C API now too, which is very similar to the C++ API: https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h To get HLO from jax functions, check out https://jax.readthedocs.io/en/latest/aot.html |
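For readers following the AOT link, a minimal sketch of the lower/compile steps described there might look like the following. It assumes a recent JAX version with the `jax.jit(f).lower(...)` and `.compile()` stages; the function `f` and its input are placeholders invented for the example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Placeholder computation; any jittable function works here.
    return jnp.sin(x) * 2.0


x = jnp.arange(4.0)

# Stage 1: trace and lower for the given input shape and dtype.
lowered = jax.jit(f).lower(x)

# Textual form of the lowered program, which can be saved for other tools.
program_text = lowered.as_text()

# Stage 2: compile ahead of time, then call the compiled executable directly.
compiled = lowered.compile()
out = compiled(x)
```

The compiled object only accepts inputs matching the shapes and dtypes used at lowering time, which is the same constraint a C++ caller of the exported program would face.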
Is there (any work towards) a c++ interface for jax?