Using JAX as a framework to chain together primitive operations #3950

To write a custom primitive that only needs to work outside of jit, all you need to do is create a Primitive instance and implement its eval and abstract_eval rules, along with whatever transformation rules you need (e.g. a batching rule for vmap, a jvp rule, and/or a transpose rule).
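Here is a minimal sketch of that first case. It assumes a reasonably recent JAX release. The names external_sin_p and external_sin are hypothetical, and NumPy stands in for the external (non-JAX) code you would actually be wrapping. The import location of Primitive has moved between JAX releases, which is why the sketch includes a fallback import.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX releases
    from jax.core import Primitive

# Hypothetical primitive whose implementation calls "external" code;
# NumPy is only a stand-in for a foreign library here.
external_sin_p = Primitive("external_sin")

def external_sin(x):
    return external_sin_p.bind(x)

# eval rule: runs eagerly on concrete arrays, outside of any jit trace.
def _external_sin_impl(x):
    return jnp.asarray(np.sin(np.asarray(x)))

external_sin_p.def_impl(_external_sin_impl)

# abstract_eval rule: the op is elementwise, so the output abstract value
# has the same shape and dtype as the input abstract value.
external_sin_p.def_abstract_eval(lambda x: x)

# jvp rule: d/dx sin(x) = cos(x); returns (primal_out, tangent_out).
def _external_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return external_sin(x), t * jnp.cos(x)

ad.primitive_jvps[external_sin_p] = _external_sin_jvp

# batching rule for vmap: elementwise, so apply the primitive to the whole
# batched array and report that the batch dimension is unchanged.
def _external_sin_batch(args, dims):
    (x,), (d,) = args, dims
    return external_sin(x), d

batching.primitive_batchers[external_sin_p] = _external_sin_batch
```

With these rules, external_sin can be used eagerly and under jax.jvp and jax.vmap. Wrapping it in jax.jit would be expected to fail when JAX lowers the traced computation, because no translation rule for the primitive has been registered.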

If you also want to be able to jit code that calls your primitive, you additionally need to register your external code (exposed through a C API) as an XLA CustomCall on CPU and/or GPU. There are examples of this in the repo: one uses Cython to call LAPACK functions written in Fortran, and another uses C++ to call PRNG kernels written in CUDA C. There is also an external repo that uses the same mechanism to wrap MPI. Finally, you need to add an XLA translation rule to your primitive that emits this CustomCall.

Answer selected by shoyer