This implementation builds a transformer decoder from the ground up. It doesn't use any higher-level frameworks like Flax; labml is used only for logging and experiment tracking.
I have implemented a simple Module class to build the basic building blocks upon.
This was my first JAX project and many implementations were taken from PyTorch implementations at nn.labml.ai.
JAX can optimize and differentiate pure Python functions. Pure functions are functions that take a bunch of arguments and return a result without side effects, that is, without changing anything outside the function such as global variables. JAX can also compile these functions (with jax.jit) as well as vectorize them so they run efficiently.
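As a small illustration of differentiating and compiling a pure function, here is a minimal sketch. The function `squared_error` and its inputs are hypothetical and made up for this example; they are not part of this implementation. Only `jax.grad` and `jax.jit` are actual JAX APIs.

```python
import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    """Hypothetical pure function: the result depends only on the arguments."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2


# jax.grad differentiates with respect to the first argument (w) by default
grad_fn = jax.grad(squared_error)
# jax.jit compiles the gradient function with XLA
fast_grad_fn = jax.jit(grad_fn)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
y = 10.0

# Gradient of (w.x - y)^2 with respect to w is 2 * (w.x - y) * x
g = fast_grad_fn(w, x, y)
```

Because `squared_error` has no side effects, JAX can trace it once and reuse the compiled version on later calls with inputs of the same shape.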
In JAX you don't have to worry about batches.
The functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
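The idea of writing a function for one sample and vectorizing it across the batch can be sketched as follows. The `linear` function and the shapes used here are assumptions chosen for illustration, not the layers from this implementation; `jax.vmap` and its `in_axes` argument are the actual JAX API.

```python
import jax
import jax.numpy as jnp


def linear(w, b, x):
    """Hypothetical layer written for a single sample.

    w: [d_out, d_in], b: [d_out], x: [d_in]
    """
    return jnp.dot(w, x) + b


# Map over axis 0 of x only; w and b are shared across the batch (in_axes=None)
batched_linear = jax.vmap(linear, in_axes=(None, None, 0))

w = jnp.ones((3, 4))
b = jnp.zeros(3)
xs = jnp.ones((8, 4))  # a batch of 8 samples

ys = batched_linear(w, b, xs)  # one output row per sample
```

Setting an entry of `in_axes` to another integer would map over that dimension of the corresponding argument instead, so the same single-sample function can be reused for sequence or head dimensions.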