
jax cpp interface #1871

Closed
mgbukov opened this issue Dec 16, 2019 · 11 comments

Comments

@mgbukov

mgbukov commented Dec 16, 2019

Is there (any work towards) a c++ interface for jax?

@skye
Member

skye commented Dec 16, 2019

What would you like to do exactly? It's currently possible to save a jit'd version of a function as an XLA program, which can then be invoked from C++ via XLA's C++ client. Check out jax/tools/jax_to_hlo.py (HLO is how XLA programs are represented).
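As a minimal sketch of the Python side (assuming a recent JAX release; `predict` and the output file name are hypothetical), a jitted function can be traced and lowered without running it via `jax.jit(...).lower(...)`, and the lowered program written out as text. As far as we know, recent releases emit StableHLO (an MLIR dialect) by default rather than the classic HLO text that jax_to_hlo.py produced.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical toy function standing in for "a jit'd version of a function".
    return jnp.tanh(x @ w)


# Example inputs; only their shapes and dtypes matter for lowering.
w = jnp.ones((3, 2), dtype=jnp.float32)
x = jnp.ones((4, 3), dtype=jnp.float32)

# Trace and lower to XLA's input IR without executing the function.
lowered = jax.jit(predict).lower(w, x)
hlo_text = lowered.as_text()

# Save it so a separate C++ program can load it (file name is arbitrary).
out_path = os.path.join(tempfile.gettempdir(), "predict.mlir")
with open(out_path, "w") as f:
    f.write(hlo_text)
```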

@mgbukov
Author

mgbukov commented Dec 16, 2019 via email

@skye
Member

skye commented Dec 16, 2019

With the above tooling I mentioned, you could build the network + gradient function in Python, save the resulting XLA HLO, then quickly evaluate both via the C++ interface.
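A sketch of that workflow, assuming a small MLP with a mean-squared-error loss (`init_params`, `mlp`, and `loss` are hypothetical names, not from this thread): the forward pass and its gradient are lowered as two separate programs, each of which could then be saved and evaluated from C++.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes):
    # Hypothetical helper: one (W, b) pair per dense layer.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (n_in, n_out)) * 0.1, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    # tanh hidden layers, linear output layer.
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((mlp(params, x) - y) ** 2)


params = init_params(jax.random.PRNGKey(0), [3, 8, 1])
x = jnp.ones((5, 3))
y = jnp.zeros((5, 1))

# Lower the forward pass and the gradient of the loss as two programs.
forward_ir = jax.jit(mlp).lower(params, x).as_text()
grad_ir = jax.jit(jax.grad(loss)).lower(params, x, y).as_text()
```

Because `params` is a list of `(W, b)` tuples, the lowered programs should take its leaves as separate array arguments; `jax.tree_util.tree_leaves(params)` lists them in the flattened order a C++ caller would need to match.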

(I'm gonna close this issue because I don't think there's any concrete work to do here, but feel free to continue commenting.)

@dawgster

Hi @skye!
Could you give a few pointers on how to get started with this? I've gone through the C++ client you linked but haven't found a way to load an HLO file, at first or second glance. I've already built a network in JAX and transformed it to HLO; the final piece of the puzzle is running it via that interface. Any help would be greatly appreciated. Thanks!

@hawkinsp
Collaborator

@dawgster I think the key thing to notice is that the XlaComputation passed to LocalClient::Compile() has a constructor that takes an HloModuleProto. Does that help?

@dawgster

@hawkinsp Yes, that helped quite a bit. Thanks for your help!

@zhangqiaorjc
Collaborator

#5337 shows an explicit example

It can easily be adapted for the GPU client.

@uduse

uduse commented May 16, 2021

Do you think it's a good alternative to call JAX python code in C++?

@Roy-Kid

Roy-Kid commented Sep 6, 2021

Do you think it's a good alternative to call JAX python code in C++?

Something like #include <Python.h>? I wonder if there is a more elegant way to invoke JAX from C++. I need to write a plugin for OpenMM, and the reason I use JAX is that it is smaller and more flexible than TF.

@Firestorm-253

@skye, the links you posted no longer exist. Is there a new way of using JAX from C++?

@skye
Member

skye commented Aug 12, 2024

Here's where the C++ PJRT client is these days: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h

We now also have a C API, which is very similar to the C++ API: https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h
The advantage is that you don't need to build your application code with Bazel in order to link in all the XLA stuff: you can just copy that header file into your project, and then separately build or use a prebuilt PJRT library to actually run the functions. (Lemme know if you want more detail on this; I think there are other comment threads somewhere...)

To get HLO from jax functions, check out https://jax.readthedocs.io/en/latest/aot.html
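A minimal sketch of the ahead-of-time flow that page describes, assuming a recent JAX release (`f` is a hypothetical example function): inputs can be described by shape and dtype alone with `jax.ShapeDtypeStruct`, the function lowered and compiled, and the compiled executable called later with real arrays.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function.
    return 2 * x + y


# Describe the inputs by shape and dtype only; no data is needed to compile.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(f).lower(spec, spec)  # trace and lower; nothing runs yet
compiled = lowered.compile()            # compile for the default backend

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.ones(4, dtype=jnp.float32)
result = compiled(x, y)                 # [1., 3., 5., 7.]

# Text of the optimized module; may be None on backends that don't expose it.
optimized_hlo = compiled.as_text()
```

Note that `compiled` lives inside the Python process; for a C++ caller, the exported text from `lowered.as_text()` is the artifact you would hand over and compile again through a PJRT client.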


8 participants