This repository shows examples of running JAX models from C++ code.
One option, which relies on just-in-time (JIT) compilation, is to export the JAX model to an HLO file and compile it when it is loaded from C++. To try it, run:
bazel run //cpp:hlo_example
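For reference, a JAX function can be lowered to HLO text in Python with the staged jit API (jax.jit(...).lower(...)). This is a minimal sketch, not the contents of call_jax_from_cpp/simple_jax_example.py: the toy model function, its input shapes, and the output filename model.hlo are hypothetical, and whether hlo_example expects HLO text or a serialized HloModuleProto is an assumption not confirmed here. It assumes jax is installed, for example in the virtual environment set up with the commands in this README.

```python
import jax
import jax.numpy as jnp


def model(x, w):
    # Hypothetical toy model: a single dense layer followed by tanh.
    return jnp.tanh(x @ w)


# Example inputs; lowering specializes the program to these shapes and dtypes.
x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)

# Trace and lower the jitted function without compiling it.
lowered = jax.jit(model).lower(x, w)

# Render the lowered program as HLO text (the default dialect is StableHLO).
hlo_text = lowered.as_text(dialect="hlo")

# Hypothetical output path for a C++ program to load.
with open("model.hlo", "w") as f:
    f.write(hlo_text)

# Print the module header line.
print(hlo_text.splitlines()[0])
```

Because the HLO is only compiled when the C++ side loads it, that program pays the compilation cost at startup, which is what separates this path from the AOT one.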
Another option, which relies on ahead-of-time (AOT) compilation, is to compile the model beforehand and serialize the resulting executable. Once the serialized executable exists, run it from C++ with:
bazel run //cpp:aot_example
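In JAX terms, ahead-of-time compilation means lowering and compiling for fixed input shapes before any call. The sketch below shows only that Python-side step, using a hypothetical toy model; how this repository serializes the executable for aot_example is not shown and is not assumed here.

```python
import jax
import jax.numpy as jnp


def model(x, w):
    # Hypothetical toy model: a single dense layer followed by tanh.
    return jnp.tanh(x @ w)


x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)

# Ahead of time: trace, lower, and compile for these exact shapes and dtypes.
compiled = jax.jit(model).lower(x, w).compile()

# Calling the compiled object runs the executable; no tracing or compilation happens here.
y = compiled(x, w)
print(y.shape)  # (2, 4)
```

The compiled object is specialized to the shapes and dtypes it was lowered with: calling it with, say, a (5, 3) input raises an error instead of recompiling. A serialized executable is likewise typically tied to the backend and XLA version that produced it.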
To set up a Python virtual environment, install this package, and execute the JAX example script, run:
python3 -m venv .venv
source .venv/bin/activate
pip install -e .
python3 call_jax_from_cpp/simple_jax_example.py