Plans with jax.experimental.export #19733
-
Hey all, I'm currently evaluating different ways to deploy JAX models, and a native way to save and load jitted inference functions would be a strong contender. Given that there are no docs for the module yet, could you give us a glimpse of what's on the horizon?
Replies: 3 comments 3 replies
-
Hi,
We are preparing the documentation for the export module, but you are right that it is meant to be an alternative to jax2tf and SavedModel. In fact, the logic of the module is to a large extent extracted from jax2tf, minus the TF parts, plus a few features that TF does not support, such as effects. You can get a good sense of the supported features and the limitations if you look at the jax2tf documentation (the "native serialization" parts).
We intend to support all JAX programs, although at the moment `pmap` and `xmap` are not supported. This module should already be ready to use for saving and loading Flax models and other JAX programs. The serialization part of …
-
Do you guys have a test suite of exported models yet? We'd like to support this better, and it is fairly easy to run a project that has a good test suite of model variants (based on experience with ONNX and torch). It's harder if the backend team has to synthesize one, because that requires pulling from the more limited pool of folks who are competent to put together a JAX model garden.
-
We do not have a test suite of exported models. Other repos, such as Flax and MaxText, have model examples, although not in exported form. For backend testing, the best we have are the JAX tests, which have really good coverage. There is even a way to generate several thousand test harnesses that should cover the JAX primitives well. See how …
Hi,
We are preparing the documentation for the export module, but you are right that it is meant to be an alternative to jax2tf and SavedModel. In fact, the logic of the module is to a large extent extracted from jax2tf, minus the TF parts, plus a few features that TF does not support, such as effects. You can get a good sense of the features supported and the limitations if you look at the jax2tf documentation (the "native serialization" parts).
We intend to support all JAX programs, although at the moment `pmap` and `xmap` are not supported. `vmap` is supported in the sense that you can export a program that uses `vmap`. The same is true for other transformations. There are, however, limitations …
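
For readers who want to try this, here is a minimal sketch of the save/load round trip for a jitted function that uses `vmap` internally. It assumes a recent JAX release in which the module is importable as `jax.export` (the thread refers to it by its earlier name, `jax.experimental.export`, whose API may differ). The `predict` function, the shapes, and the variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import export  # assumption: recent JAX where the module lives at jax.export


def predict(w, x):
    # Hypothetical toy "model": one dense layer followed by tanh.
    return jnp.tanh(x @ w)


# Use vmap inside the program to be exported: map over the leading axis of x,
# while sharing the weights w across all mapped calls.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones((3, 2), dtype=jnp.float32)
x = jnp.ones((4, 5, 3), dtype=jnp.float32)

# Export the jitted function for these argument shapes and dtypes.
exported = export.export(jax.jit(batched_predict))(w, x)

# Serialize to bytes (this is what you would write to disk) ...
blob = exported.serialize()

# ... and later deserialize and call it, without needing the Python source.
restored = export.deserialize(blob)
out = restored.call(w, x)
```

The exported artifact is specialized to the argument shapes and dtypes used at export time, so calling `restored.call` with different shapes is expected to fail unless the function was exported with symbolic shapes.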