Using JAX as a framework to chain together primitive operations #3950
-
Hi, let's say I'm trying to compute the derivative of a function using reverse-mode AD. That function will be made up of primitive operations. Let's say one of those primitive operations is too complex or specialized for me to program in JAX, but I have a FORTRAN callable that computes the primitive and a FORTRAN callable that computes the VJP of that primitive. Is there a way to write a custom primitive in JAX that wraps that FORTRAN backend, while writing the other primitive operations in native JAX? My guess is that something like this would work, but that it couldn't be JITted. Is that right? Can JAX automatically JIT the composition of Python primitives, exit from the JIT to call FORTRAN, and then JIT the rest of the function? Or do I need to set that up manually by JITting twice (or more), keeping track of which primitives in the composition cannot be JITted? Thanks!
-
This documentation page is a good reference on writing a JAX primitive. For what it's worth, I agree that this is a little painful currently. Ideally, JAX would provide a mechanism for calling back into Python without requiring a CustomCall to be written from scratch. See #766 for that feature request.
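Later JAX releases include `jax.pure_callback`, which lets jitted code call a host Python function without a hand-written CustomCall. The following is a minimal sketch under stated assumptions: the FORTRAN routines are assumed to already be callable from Python (e.g. via f2py or ctypes), and the hypothetical `fortran_primitive` and `fortran_primitive_vjp` below are NumPy stand-ins for them. `jax.custom_vjp` connects the VJP callable to reverse-mode AD, so the surrounding native JAX code can still be jitted and differentiated.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the FORTRAN callables (e.g. exposed via f2py).
# Both receive and return plain NumPy arrays on the host.
def fortran_primitive(x):
    return np.asarray(np.sin(x), dtype=x.dtype)

def fortran_primitive_vjp(x, g):
    # VJP of sin: cotangent g times cos(x)
    return np.asarray(g * np.cos(x), dtype=x.dtype)

@jax.custom_vjp
def external_op(x):
    # Declare the output shape/dtype so JAX can trace without running the callback.
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(fortran_primitive, out, x)

def external_op_fwd(x):
    # Keep the input as the residual needed by the backward pass.
    return external_op(x), x

def external_op_bwd(x, g):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(fortran_primitive_vjp, out, x, g),)

external_op.defvjp(external_op_fwd, external_op_bwd)

@jax.jit
def loss(x):
    y = 2.0 * jnp.tanh(x)   # native JAX
    z = external_op(y)      # host call through the callback
    return jnp.sum(z ** 2)  # native JAX

x = jnp.linspace(-1.0, 1.0, 5)
print(loss(x))
print(jax.jit(jax.grad(loss))(x))
```

Each callback leaves the compiled program to run host Python, which has some overhead, while the native JAX parts of `loss` stay compiled. `pure_callback` also treats the callback as side-effect free (JAX may skip it if its result is unused), which fits a routine that only computes a value.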
-
In order to write a custom primitive that can't be jitted, all you need to do is create a `Primitive` instance and implement `eval` and `abstract_eval` rules (along with whatever transformation rules you need, e.g. `vmap` and `jvp` and/or `transpose`).

If you want to be able to jit code that calls your primitive, you need to additionally register your external code (with a C API) as an XLA CustomCall on CPU and/or GPU. There are examples in the repo that use Cython to call LAPACK functions in FORTRAN and C++ to call PRNG functions in CUDA C, as well as an external repo that uses this mechanism to wrap MPI. Then you need to add an XLA translation rule to your primitive that creates this CustomCall.