From d8678a143670c951c883f92e042f33d9a98876a8 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 3 Oct 2022 11:45:56 -0700
Subject: [PATCH 1/5] Add proposal for FFI JEP

Co-authored-by: Kuangyuan Chen
Co-authored-by: Qiao Zhang
---
 docs/jep/12535-ffi.md | 338 ++++++++++++++++++++++++++++++++++++++++++
 docs/jep/index.rst    |   1 +
 2 files changed, 339 insertions(+)
 create mode 100644 docs/jep/12535-ffi.md

diff --git a/docs/jep/12535-ffi.md b/docs/jep/12535-ffi.md
new file mode 100644
index 000000000000..02443ba6bc4d
--- /dev/null
+++ b/docs/jep/12535-ffi.md
@@ -0,0 +1,338 @@
+# JAX Foreign Function Interface
+
+* Authors: Sharad Vikram, Kuangyuan Chen, Qiao Zhang
+* Date: October 3, 2022
+
+## tl;dr
+
+We propose a new API for user foreign function calls, eliminating a lot of currently needed boilerplate and providing a tighter integration with JAX.
+
+### Example usage on CPU
+
+User-written C++ FFI function (exposed via `pybind11`)
+```cpp
+#include <pybind11/pybind11.h>
+#include "jax_ffi.h"  // Header that comes from JAX
+
+namespace py = pybind11;
+
+struct Info {
+  float n;
+};
+
+extern "C" void add_n(JaxFFI_API* api, JaxFFIStatus* status,
+                      void* static_info, void** outputs, void** inputs) {
+  Info* info = (Info*) static_info;
+  if (info->n < 0) {
+    JaxFFIStatusSetFailure(api, status, "Info must be >= 0");
+    return;
+  }
+  float* input = (float*) inputs[0];
+  float* output = (float*) outputs[0];
+  *output = *input + info->n;
+}
+
+PYBIND11_MODULE(add_n_lib, m) {
+  m.def("get_function", []() {
+    return py::capsule(reinterpret_cast<void*>(add_one), JAX_FFI_CALL_CPU);
+  });
+  m.def("get_static_info", [](float n) {
+    Info* info = new Info();
+    info->n = n;
+    return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
+      delete reinterpret_cast<Info*>(ptr);
+    });
+  });
+}
+```
+
+JAX registration code and usage
+```python
+import jax
+import add_n_lib
+
+add_n = jax.ffi.FFICall("add_one")
+add_n.register_function(add_n_lib.get_function(), platform="cpu")
+
+def f(x):
+  static_info = add_n_lib.get_static_info(4.)
+  return add_n(x, static_info=static_info, return_shape_dtype=x)
+
+print(jax.jit(f)(4.))
+```
+
+## Motivation
+Our motivation is three-fold:
+
+* We’d like to make JAX’s APIs for defining foreign calls simpler. Currently, we have no official API and defining calls to foreign functions involves using XLA’s custom call API and creating JAX primitives that wrap the custom calls. Not only does this involve a lot of boilerplate to hook XLA and JAX together, but it also uses private JAX APIs, making the code unstable with future JAX releases.
+* We’d like to better integrate foreign calls with JAX’s system. Currently, the only way foreign calls can be exposed to JAX is by creating custom primitives, which is not a stable integration point. By wrapping both the XLA custom call registration and the JAX primitive registration, we can provide a stable API and provide integration points with the rest of JAX, e.g. custom gradients, batching, and error handling.
+* We’d like to have centralized documentation on how to extend JAX with foreign function calls. Currently users need to combine documentation from multiple sources (XLA documentation, JAX primitive documentation). Centralizing documentation will not only prevent user confusion but also lower the barrier to using foreign functions. At the moment, nothing is documented on the JAX site itself, leading people to write and use external guides, namely [the dfm guide](https://github.com/dfm/extending-jax).
+
+Desiderata:
+* We’d like the system to be as expressive as XLA’s current custom call machinery. We don’t want to limit what sorts of programs the user can write, forcing them to use the underlying private APIs.
+* We’d like users to avoid having to learn about XLA’s custom call machinery and JAX’s primitive internals. For one, it adds a lot of mental overhead, but also we’d like a new FFI API to be agnostic to both 1) the choice of backend compiler and 2) JAX’s changing internals.
+
+## Background
+
+JAX uses XLA to compile staged-out Python programs, which are represented with MHLO. MHLO offers common operations used in machine learning, such as the dot-product, exponential function, and so on. Unfortunately, sometimes these operations (and compositions of these operations) are not sufficiently expressive or performant and users look to external libraries (Triton, Numba) or hand-written code (CUDA, C/C++) to implement operations for their desired program.
+
+The mechanism XLA offers to write custom ops is the XLA `CustomCall`. Its API is general purpose and JAX uses it for host callbacks and custom GPU/CPU library ops (`lapack`, `cublas`, `ducc`). To register your own custom op, the steps are (at a high level):
+1. Compile a Python extension that exposes your code as a capsule (e.g. via `pybind11`, Cython).
+2. Import that Python extension in your Python code and register it with JAX’s `xla_extension` (via `jax._src.lib.xla_client`) with a name for the custom call.
+3. Create a new JAX primitive and register an MLIR lowering that emits a custom call with a matching name.
+
+Note that in order to do this, users need to be familiar with 1) the XLA custom call API (documented separately from JAX), 2) XLA’s Python extension API (how to register custom calls), 3) JAX’s primitive API (which is internal and subject to change), and 4) the MHLO/MLIR Python bindings.
+
+There are also many details associated with each of these steps, outlined in [the dfm guide](https://github.com/dfm/extending-jax). Note that the dfm guide uses the out-of-date XLA builder, not the MLIR builder.
+
+## Technical challenges
+JAX FFI fundamentally offers a C API to JAX users to extend the JAX system. Concretely, JAX users write a C/C++ custom kernel that needs to be registered with the XLA custom call registry. Further, users may want to use C helpers from JAX that, for example, report errors in a way that JAX can consume. We observe these three main challenges:
+1. Avoiding building `jaxlib` – we do not want users to rebuild `jaxlib` because of possible version mismatch with JAX and also the notoriously long build time for the TensorFlow/XLA subtree.
+2. Registering a function (C function for ABI compatibility) defined in a user shared object with the XLA custom call registry defined in the `jaxlib` shared object without running into duplicate global object issues (the XLA custom call registry is a global variable).
+3. Referencing helper functions defined in `jaxlib` in user-defined shared objects.
+
+At a high level, all three challenges are about linking and symbol resolution in the presence of more than one shared object.
+
+## Proposal
+
+Here we go over the various parts of the proposed API, starting with how users expose their foreign functions to Python and ending with how those foreign functions are registered with and used in JAX.
+
+### User foreign function API
+We propose the following API for user-written foreign CPU functions (with an additional `CUstream` argument for GPU):
+```c
+#include "jax_ffi.h"
+void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** outputs, void** inputs) {
+  ...
+}
+```
+
+This roughly mirrors the XLA custom call API, except we provide `api`, which contains helper functions. `status`, like `XlaCustomCallStatus`, is used to indicate that there was an error in the computation. We also provide `static_info`, which contains compile-time/lowering-time information from JAX. The JaxFFI types will be exposed in a `jax_ffi.h` header that is shipped with `jaxlib`. Note that this user code example does not reference any TensorFlow/XLA headers: the user only needs the `jax_ffi.h` header and does not need to rebuild `jaxlib`, addressing challenge #1.
+
+We also propose APIs that avoid some potentially unnecessary arguments (`api`, `status`, `static_info`).
+
+Challenge #3 is about allowing users to find the helper functions exposed by `jaxlib`. Instead of the typical dynamic runtime linking (e.g., via `dlopen`), we offer a solution similar to NumPy C API linking.
+
+`jaxlib` needs to expose a few helper functions that user kernels can invoke. We implement these functions in a new file `jax_ffi.c`. To help symbol resolution, we store a pointer array of C function pointers called `JaxFFI_API_Table`:
+
+```c
+struct JaxFFIStatus {
+  std::optional<std::string> message;
+};
+
+int JAX_FFI_V1 = 1;
+
+int JaxFFIVersionFn() {
+  return JAX_FFI_V1;
+}
+
+void JaxFFIStatusSetFailureFn(JaxFFIStatus* status, const char* message) {
+  status->message = std::string(message);
+}
+
+void *JaxFFI_API_Table[] = {
+  (void *)&JaxFFIVersionFn,
+  (void *)&JaxFFIStatusSetFailureFn
+};
+```
+
+When `jaxlib` eventually invokes the user function, it will pass in the pointer array explicitly as the `JaxFFI_API* api` argument (note that `jaxlib` stores the pointer array). We then provide macros in `jax_ffi.h` as convenience methods to index into the pointer array and find the appropriate helper function:
+
+```c
+#define JAX_FFI_CALL_CPU "JAX_FFI_CALL_CPU"
+
+#define JaxFFIVersion() \
+  ((*((int (*)())(api[0])))())
+#define JaxFFIStatusSetFailure(api, status, msg) \
+  ((*((void (*)(JaxFFIStatus*, const char*))(api[1])))(status, msg))
+
+struct JaxFFIStatus;
+
+typedef void* JaxFFI_API;
+```
+
+
+### Exposing foreign functions to JAX
+
+To address challenge #2, namely that of needing to register in the XLA extension shared object, we require the user to expose the foreign function to Python. This can be done in a variety of ways, but fundamentally we need to produce Python bindings that expose the function pointer `&user_function`. The function pointer can be handled in Python opaquely via a `PyCapsule`. Here’s an example using `pybind11`:
+
+```cpp
+#include <pybind11/pybind11.h>
+#include "jax_ffi.h"
+
+void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** outputs, void** inputs) {
+  ...
+}
+
+PYBIND11_MODULE(user_function_lib, m) {
+  m.def("get_function", []() {
+    return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
+  });
+}
+```
+
+The capsule should be given a name to both indicate the type of signature (CPU vs GPU) and to throw errors early if the wrong function is registered.
+
+### Passing in custom descriptors
+Foreign functions may often need more than just “runtime information” like the values of input arrays. “Static information” that is provided in JAX at tracing/lowering/compile time also needs to be passed into the foreign function. XLA (currently) offers two separate mechanisms for providing this static information to custom calls.
+
+On CPU, there is no official mechanism but this information can often be provided by passing a pointer value as an argument to the custom call, which points to a heap allocated object. Inside of the custom call, the pointer can be dereferenced to get the object and access its information.
+
+On GPU, the custom call API offers `opaque`, a string that will be passed to the custom call. This requires the information passed to the custom call to be serializable. Note that we can also “sneak” pointer values into the opaque string, allowing us to pass heap allocated objects as well.
+
+From the user perspective, these details are unnecessary and can be handled internally by JAX. The user should have a single API for passing this static information into a custom call.
+
+Suppose the user wants to pass a struct `Info` into their foreign function. In order to do so, the `Info` struct (or a pointer to it) needs to be available to Python so that JAX can construct MHLO that passes it back into the custom call.
+
+#### Exposing static information to Python via a pointer (pass by reference)
+
+```cpp
+#include <pybind11/pybind11.h>
+#include "jax_ffi.h"
+
+struct Info {
+  float n;
+};
+
+void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** outputs, void** inputs) {
+  Info* info = (Info*) static_info;
+  ...
+}
+
+PYBIND11_MODULE(user_function_lib, m) {
+  m.def("get_function", []() {
+    return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
+  });
+  m.def("make_info", [](float n) {
+    Info* info = new Info();
+    info->n = n;
+    return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
+      delete reinterpret_cast<Info*>(ptr);
+    });
+  });
+}
+```
+
+This approach wraps a heap allocated object in a capsule, and destroys the object when the capsule object is destroyed. This hands the ownership of the object to Python. JAX will then take ownership of the object and give it to the executable, like it does with other capsule objects. This is how JAX handles host callbacks.
+
+#### Exposing static information to Python via serialization (pass by value)
+
+```cpp
+m.def("make_info", [](float n) {
+  Info info;
+  info.n = n;
+  return pybind11::bytes(std::string(reinterpret_cast<const char*>(&info), sizeof(Info)));
+});
+```
+
+
+This approach serializes the struct as a string, then returns it to Python as a bytes object. Since we’re not doing any heap allocation, we don’t need to worry about ownership and don’t require JAX to keep a heap object alive. This is how JAX handles custom calls to external libraries like cublas, lapack, and ducc.
+
+JAX should handle both cases (pass by reference and value) and pass the appropriate pointer back into the user foreign function.
+
+### Handling foreign functions and descriptors in Python
+
+We’ve shown how users expose foreign functions to Python. Now we’ll show how users register and use these functions with JAX.
+
+#### Registering FFI calls in JAX
+First we introduce a new JAX module, `jax.ffi`. `jax.ffi` will expose a `jax.ffi.FFICall` object.
+
+```python
+Platform = FunctionPointer = Any
+
+class FFICall:
+  name: str
+  _registry: Dict[Platform, FunctionPointer]
+
+  def register_function(self, function_ptr, *, platform):
+    ...
+
+  def __call__(self, *args, **kwargs):
+    ...
+```
+
+We can construct `FFICall`s with a string name that uniquely identifies them (we should error if the same name is used twice).
+
+```python
+import jax.ffi
+
+user_function = jax.ffi.FFICall("user_function")
+```
+
+We allow users to register platform-specific implementations for the FFI call.
+```python
+import user_function_lib # the Python extension
+user_function.register_function(user_function_lib.get_function(), platform="cpu")
+```
+
+This allows users to write a CPU version of their code and a GPU version as well.
+#### Calling foreign functions from JAX
+
+`FFICall` objects have a `__call__` method that invokes a JAX primitive, `jax_ffi_call`, that has already registered transformation rules.
+
+```python
+@jax.jit
+def f(...):
+  ... = user_function(..., return_shape_dtype=...)
+```
+
+The user needs to provide a `return_shape_dtype` information since that can’t be inferred by JAX and JAX requires statically known shapes and dtypes for all values.
+
+To pass in a descriptor as well, users can construct static information and pass it into the `user_function` via a reserved keyword argument `static_info`.
+
+```
+@jax.jit
+def f(...):
+  static_info = user_function_lib.make_info(4.)
+  ... = user_function(..., static_info=static_info, return_shape_dtype=...)
+```
+
+#### JAX custom call wrapper
+
+When the user eventually calls the `FFICall`, we emit a specific MHLO custom call (`jax_ffi_call`) during lowering. This custom call is already registered and is passed both the function pointer capsule (registered earlier) and the static info (passed into the primitive). It then prepares the `api` and `status` and calls the function pointer along with `api` and `status` with the input/output pointers.
+
+```cpp
+struct Descriptor {
+  void* function_ptr;
+  void* user_static_info;
+};
+
+extern "C" void JaxFFICallWrapper(void* output, void** inputs,
+                                  XlaCustomCallStatus* status) {
+  auto descriptor = reinterpret_cast<Descriptor*>(*static_cast<uintptr_t*>(inputs[0]));
+  inputs += 1;
+  JaxFFIStatus jax_ffi_status;
+  auto function_ptr = reinterpret_cast<void (*)(JaxFFI_API*, JaxFFIStatus*, void*, void**, void**)>(descriptor->function_ptr);
+  function_ptr(JaxFFI_API_Table, &jax_ffi_status,
+               descriptor->user_static_info,
+               reinterpret_cast<void**>(output), inputs);
+  if (jax_ffi_status.message) {
+    // Handle error!
+  }
+}
+```
+
+
+On CPU (for now), the `function_ptr` and `user_static_info` are passed by reference as the first argument to the custom call. On GPU, they can be passed by reference via the opaque string.
+
+Note that XLA custom calls support custom layouts for operands and results. Here we’ll generate MHLO that uses default layouts, which technically limits what users can express.
+
+### Handling JAX transformations
+
+Unlike in the dfm guide, users are not constructing JAX primitives and therefore don’t have the opportunity to register transformation rules for those primitives. Do we want to expose them and if so, how?
+
+For automatic differentiation, users have the option of wrapping their `FFICall` with a `jax.custom_jvp` or `jax.custom_vjp`. Alternatively we could expose additional methods on `FFICall` that do something similar. The `jax.custom_*` (`custom_vmap`, `custom_transpose`, etc.) API, in principle, could also handle any custom behavior users want from FFI calls. However, this API has not been fully built out yet.
+
+For now, we propose not committing to a specific transformation API for custom calls and waiting to see how the `custom_*` solution plays out. If users want very specific transformation behavior, they can rely on the (internal) primitive API, i.e. the status quo. Custom transformation behavior is orthogonal to the problem of enabling foreign function calls and fits into the larger discussion of how to expose custom primitives to users.
+
+We could also consider some default transformation rules under vmap, for example. If the user promises their function is pure, we can adopt a strategy like `pure_callback` where we can sequentially map over the batch dimensions. If the function is not pure, we disable most transformations.
+
+### Error handling
+
+In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in [the error handling JEP](). Creating this layer of indirection allows us to have functional error handling in custom calls as well.
+
+There are also a few different ways custom calls often fail. In `jaxlib`, custom calls will call `XlaCustomCallStatusSetFailure` usually when there is an unrecoverable failure (CUDA errors, OOMs, etc.). Arguably we shouldn’t handle these errors in JAX itself. Other sorts of errors, for example numerical errors in linear algebra routines, could be surfaced in JAX via the unified error handling API. We should consider having an extra bit (e.g. recoverable) in `JaxFFIStatusSetFailure` that distinguishes between these two types of errors.
+
+## Implementation Plan
+
+We provide a prototype of the proposed API in [this PR](https://github.com/google/jax/pull/12396).
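
A minimal sketch, kept outside the hunks of this patch, of the calling convention the JEP text above proposes (`api`, `status`, `static_info`, `outputs`, `inputs`), emulated in Python with `ctypes` so that it runs without a compiler or `jaxlib`. Everything here is an assumption made for illustration: `Status`, `Info`, `FFI_FN`, `add_n` and `call_add_n` are hypothetical names, the status is modeled as a fixed-size message buffer rather than the opaque `JaxFFIStatus`, and the `api` table is passed as a null pointer.

```python
import ctypes


class Status(ctypes.Structure):
    # Hypothetical stand-in for JaxFFIStatus: an empty message means success.
    _fields_ = [("message", ctypes.c_char * 64)]


class Info(ctypes.Structure):
    # Mirrors the `Info` struct used in the JEP examples.
    _fields_ = [("n", ctypes.c_float)]


# void fn(JaxFFI_API* api, JaxFFIStatus* status, void* static_info,
#         void** outputs, void** inputs)
FFI_FN = ctypes.CFUNCTYPE(
    None,
    ctypes.c_void_p,
    ctypes.POINTER(Status),
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_void_p),
    ctypes.POINTER(ctypes.c_void_p),
)


@FFI_FN
def add_n(api, status, static_info, outputs, inputs):
    # Pass by reference: static_info is the address of a live Info object.
    info = ctypes.cast(static_info, ctypes.POINTER(Info)).contents
    if info.n < 0:
        status.contents.message = b"n must be >= 0"
        return
    x = ctypes.cast(inputs[0], ctypes.POINTER(ctypes.c_float))
    out = ctypes.cast(outputs[0], ctypes.POINTER(ctypes.c_float))
    out[0] = x[0] + info.n


def call_add_n(x, n):
    # Plays the role of the wrapper: builds the buffers, calls the function
    # pointer, then turns a non-empty status message into a Python error.
    info, status = Info(n), Status()
    x_buf, out_buf = ctypes.c_float(x), ctypes.c_float()
    inputs = (ctypes.c_void_p * 1)(ctypes.addressof(x_buf))
    outputs = (ctypes.c_void_p * 1)(ctypes.addressof(out_buf))
    add_n(None, ctypes.byref(status), ctypes.addressof(info), outputs, inputs)
    if status.message:
        raise RuntimeError(status.message.decode())
    return out_buf.value


print(call_add_n(4.0, 4.0))  # 8.0
```

Keeping the status as a plain value that the caller inspects after the call is one way to leave room for the "recoverable" distinction discussed in the error handling section, since the caller decides what to do with the message.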
diff --git a/docs/jep/index.rst b/docs/jep/index.rst
index 417d2dbf5cb1..bae243faaf17 100644
--- a/docs/jep/index.rst
+++ b/docs/jep/index.rst
@@ -46,6 +46,7 @@ Then create a pull request that adds a file named
    10657: Sequencing side-effects in JAX <10657-sequencing-effects>
    11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
    12049: Type Annotation Roadmap for JAX <12049-type-annotations>
+   12535: Foreign Function Interface <12535-ffi>
 
 
 Several early JEPs were converted in hindsight from other documentation,

From 54779ca4b0ae7e3675017f14cb9cf8f7ec1460ed Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 3 Oct 2022 12:35:36 -0700
Subject: [PATCH 2/5] Use better bytes constructor

---
 docs/jep/12535-ffi.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/jep/12535-ffi.md b/docs/jep/12535-ffi.md
index 02443ba6bc4d..75607fee8509 100644
--- a/docs/jep/12535-ffi.md
+++ b/docs/jep/12535-ffi.md
@@ -220,7 +220,7 @@ This approach wraps a heap allocated object in a capsule, and destroys the objec
 m.def("make_info", [](float n) {
   Info info;
   info.n = n;
-  return pybind11::bytes(std::string(reinterpret_cast<const char*>(&info), sizeof(Info)));
+  return pybind11::bytes(reinterpret_cast<const char*>(&info), sizeof(Info));
 });
 ```
 

From 9082cb5e08d1e0bc8bc43fa5fe72a96cf594ab3e Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 4 Oct 2022 11:34:59 -0700
Subject: [PATCH 3/5] Added dfm suggestions

---
 docs/jep/12535-ffi.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/jep/12535-ffi.md b/docs/jep/12535-ffi.md
index 75607fee8509..4595313d59ab 100644
--- a/docs/jep/12535-ffi.md
+++ b/docs/jep/12535-ffi.md
@@ -279,7 +279,7 @@ The user needs to provide a `return_shape_dtype` information since that can’t
 
 To pass in a descriptor as well, users can construct static information and pass it into the `user_function` via a reserved keyword argument `static_info`.
 
-```
+```python
 @jax.jit
 def f(...):
   static_info = user_function_lib.make_info(4.)
@@ -329,7 +329,7 @@ We could also consider some default transformation rules under vmap, for example
 
 ### Error handling
 
-In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in [the error handling JEP](). Creating this layer of indirection allows us to have functional error handling in custom calls as well.
+In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in [the error handling JEP](TODO). Creating this layer of indirection allows us to have functional error handling in custom calls as well.
 
 There are also a few different ways custom calls often fail. In `jaxlib`, custom calls will call `XlaCustomCallStatusSetFailure` usually when there is an unrecoverable failure (CUDA errors, OOMs, etc.). Arguably we shouldn’t handle these errors in JAX itself. Other sorts of errors, for example numerical errors in linear algebra routines, could be surfaced in JAX via the unified error handling API. We should consider having an extra bit (e.g. recoverable) in `JaxFFIStatusSetFailure` that distinguishes between these two types of errors.
 
From 34d33a2649a21c75180f08fa3a4f0866f6a5ad14 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 4 Oct 2022 11:47:22 -0700
Subject: [PATCH 4/5] Make RTD pass by removing bad TODO anchor

---
 docs/jep/12535-ffi.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/jep/12535-ffi.md b/docs/jep/12535-ffi.md
index 4595313d59ab..afbf25b3eba4 100644
--- a/docs/jep/12535-ffi.md
+++ b/docs/jep/12535-ffi.md
@@ -329,7 +329,7 @@ We could also consider some default transformation rules under vmap, for example
 
 ### Error handling
 
-In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in [the error handling JEP](TODO). Creating this layer of indirection allows us to have functional error handling in custom calls as well.
+In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in the error handling JEP. Creating this layer of indirection allows us to have functional error handling in custom calls as well.
 
 There are also a few different ways custom calls often fail. In `jaxlib`, custom calls will call `XlaCustomCallStatusSetFailure` usually when there is an unrecoverable failure (CUDA errors, OOMs, etc.). Arguably we shouldn’t handle these errors in JAX itself. Other sorts of errors, for example numerical errors in linear algebra routines, could be surfaced in JAX via the unified error handling API. We should consider having an extra bit (e.g. recoverable) in `JaxFFIStatusSetFailure` that distinguishes between these two types of errors.
 
From 8a203011a75d11130b654e763053298c6c177e08 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Thu, 6 Oct 2022 14:12:41 -0700
Subject: [PATCH 5/5] Fix add_one typo

---
 docs/jep/12535-ffi.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/jep/12535-ffi.md b/docs/jep/12535-ffi.md
index afbf25b3eba4..c12be8cbe85a 100644
--- a/docs/jep/12535-ffi.md
+++ b/docs/jep/12535-ffi.md
@@ -34,7 +34,7 @@ extern "C" void add_n(JaxFFI_API* api, JaxFFIStatus* status,
 
 PYBIND11_MODULE(add_n_lib, m) {
   m.def("get_function", []() {
-    return py::capsule(reinterpret_cast<void*>(add_one), JAX_FFI_CALL_CPU);
+    return py::capsule(reinterpret_cast<void*>(add_n), JAX_FFI_CALL_CPU);
   });
   m.def("get_static_info", [](float n) {
     Info* info = new Info();
@@ -51,7 +51,7 @@ JAX registration code and usage
 import jax
 import add_n_lib
 
-add_n = jax.ffi.FFICall("add_one")
+add_n = jax.ffi.FFICall("add_n")
 add_n.register_function(add_n_lib.get_function(), platform="cpu")
 
 def f(x):