Hi Garry, unfortunately there are no official guarantees on TPU VM image stability, which would be necessary to ensure that a given JAX version continues to run on TPU VMs (mostly because of the low-level TPU driver shipped as part of the VM image). There are some internal discussions around maintaining a guaranteed support window, although it's unlikely to extend to 3 years. In practice the TPU images have been relatively stable since Cloud TPU VMs went GA in May (roughly corresponding to jax 0.3.14), but I would understand if you don't want to rely on this. I suggest bringing this up with your Cloud TPU rep if you haven't already.

Also, just to clarify, JAX itself doesn't have any kind of LTS release branches, meaning we can't easily support a particular JAX version -- if you encounter any JAX bugs or missing features, the only choice is to upgrade along the main line. We do our best to maintain at least a 3-month compatibility window outside of [...].

If possible, I would recommend upgrading the JAX version of your application with some regularity, or at least testing your application against new JAX versions and addressing any known issues so you can more easily upgrade if needed. I realize this is more work for you, but it is currently the only guaranteed way to ensure things keep working over a 3-year period, and it has the added benefit of giving you new JAX features, bug fixes, and performance improvements :) We're of course happy to help if upgrading causes any issues for you!
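As a rough illustration of the "test against new JAX versions" advice, here is a minimal sketch of a startup check that reports when the installed packages differ from the versions the application was last validated against. The helper name `find_version_drift` and the `PINNED` table are hypothetical (the pins are simply the versions mentioned in the question); only the standard-library `importlib.metadata` module is used, so this is a sketch under those assumptions and not anything JAX itself provides.

```python
from importlib import metadata

# Hypothetical pins for illustration: the versions the application was
# last validated against (taken from the question in this thread).
PINNED = {"jax": "0.2.22", "jaxlib": "0.1.69"}


def find_version_drift(pinned, get_version=metadata.version):
    """Return {package: (expected, installed)} for every package off its pin.

    `installed` is None when the package is not installed at all.
    `get_version` is injectable so the check can be exercised without JAX.
    """
    drift = {}
    for package, expected in pinned.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            drift[package] = (expected, installed)
    return drift


if __name__ == "__main__":
    # Report any mismatch so an image or dependency change is noticed early.
    for package, (expected, installed) in find_version_drift(PINNED).items():
        print(f"{package}: validated against {expected}, found {installed}")
```

Running a check like this in CI against both the pinned environment and a fresh environment with the latest release could make it easier to notice when an upgrade is needed, though it does not replace running the application's own tests.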
To minimize compatibility issues, yes. However, you may decide that some features are useful enough to take the risk. You can always ask here if you have questions about the future of any particular feature -- we can try to make a prediction about how radically a feature is likely to change (e.g. pjit is in the process of coming out of experimental by being merged into jit, whereas xmap is more likely to see further changes before graduating from experimental). Hope this helps.
-
We have been working for a year on a deep neural network using JAX 0.2.22 and jaxlib 0.1.69 on TPUs to build an NLP model for commercial use. We are still in development, but as we move forward we want to make sure we have at least three years of runway for our application before Google Cloud TPUs stop supporting JAX 0.2.22. As an example, JAX 0.2.12 was released last year and is no longer supported by TPU VMs. Can you please provide some guidance so we can properly plan our application lifecycle?
One more question: should we avoid using experimental functions in JAX (we really like and use some of them) to avoid compatibility issues in the future? Thanks.