XLA runner (to support JAX) #163

Open
VivekPanyam opened this issue Sep 29, 2023 · 0 comments
Labels: new ml framework (Adding support for a new runner/ML framework)


@VivekPanyam
Owner

Background

XLA is an ML "compiler for GPUs, CPUs, and ML accelerators."

Carton support for XLA would primarily be used to provide JAX support, but in theory it could also support some PyTorch and TensorFlow models.

Here's a guide on how to export a JAX model from Python and run it from C++ using XLA: jax-ml/jax#5337 (comment).

This is an example of the above in the JAX codebase.

I've explored doing this in the past (outside of Carton), but no prebuilt XLA binaries were available at the time, so it required building XLA from source within the TensorFlow repo. Now that OpenXLA provides prebuilt binaries, this is a lot easier.

@LaurentMazare created Rust bindings to XLA that include a straightforward example of loading the HLO IR generated by the JAX export code. That should make it fairly easy to prototype an integration with Carton if anyone is interested in doing so.
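As a rough illustration of the export half of this pipeline, here is a minimal Python sketch that lowers a toy JAX function and dumps the resulting IR as text. It assumes a recent JAX release, where `jax.jit(...).lower(...)` returns a lowered object whose `as_text()` emits StableHLO by default. The `model` function and the output filename are hypothetical placeholders. The linked guide works with the HLO format, so whatever the Rust bindings or a future Carton runner expect may call for a different export format; the loading side is not shown here.

```python
import jax
import jax.numpy as jnp


def model(x):
    # Hypothetical toy "model": a fixed affine transform followed by a ReLU.
    w = jnp.full((4, 3), 0.5, dtype=jnp.float32)
    b = jnp.ones((3,), dtype=jnp.float32)
    return jax.nn.relu(x @ w + b)


# Describe the input by shape and dtype only; no real data is needed to lower.
example_input = jax.ShapeDtypeStruct((2, 4), jnp.float32)

# Trace and lower the jitted function to compiler IR without compiling or running it.
lowered = jax.jit(model).lower(example_input)

# In recent JAX releases, as_text() returns the module as StableHLO (MLIR text).
hlo_text = lowered.as_text()

# Hypothetical output path; a runner would load this artifact instead of Python code.
with open("model.stablehlo.mlir", "w") as f:
    f.write(hlo_text)

print(hlo_text)
```

Lowering against a `ShapeDtypeStruct` fixes the input shapes at export time, which matches the static-shape model that XLA compiles for.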

Implementation

Concretely, this could be implemented as follows:

Recommended reading:

VivekPanyam added the "new ml framework" label on Sep 29, 2023.