vpj/jax_transformer

Autoregressive transformer in JAX from scratch

This implementation builds a transformer decoder from the ground up. It doesn't use any higher-level frameworks like Flax; I have used labml for logging and experiment tracking.

I have implemented a simple Module class to serve as the base for the basic building blocks, as sketched below.
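Here is a minimal sketch of that idea; the Module and Linear classes below are illustrative assumptions, not the repository's exact API. The point is that each module holds its parameters so they can be passed around and composed:

```python
import jax.numpy as jnp
from jax import random

# Hypothetical sketch of a minimal Module base class; the actual
# class in this repository may differ.
class Module:
    def __init__(self):
        self.params = {}

class Linear(Module):
    """A single linear layer built on the Module base."""
    def __init__(self, key, n_in, n_out):
        super().__init__()
        self.params['w'] = random.normal(key, (n_in, n_out)) * 0.01
        self.params['b'] = jnp.zeros(n_out)

    def __call__(self, x):
        # Written for a single sample; batching is handled by jax.vmap.
        return x @ self.params['w'] + self.params['b']

layer = Linear(random.PRNGKey(0), 16, 32)
print(layer(jnp.ones(16)).shape)  # (32,)
```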

This was my first JAX project, and many of the implementations were adapted from the PyTorch implementations at nn.labml.ai.

JAX can optimize and differentiate pure Python functions. Pure functions take a set of arguments and return a result without side effects such as mutating external state. JAX can also compile these functions with XLA, as well as vectorize them, so that they run efficiently.
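For example, here is a small sketch (the loss function and shapes are illustrative, not taken from this repository) of differentiating and JIT-compiling a pure function:

```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the inputs,
# and nothing outside the function is modified.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(loss)          # differentiate w.r.t. the first argument
fast_grad_fn = jax.jit(grad_fn)   # compile with XLA

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
print(fast_grad_fn(w, x, y))      # gradient of the loss w.r.t. w
```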

In JAX you don't have to worry about batching. Functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
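As a sketch (the attention-score function and shapes here are illustrative assumptions), a function written for one sample can be mapped over a batch without changing its body:

```python
import jax
import jax.numpy as jnp

# Scaled dot-product attention scores for a single sample:
# q and k have shape (seq_len, d_k).
def attention_scores(q, k):
    return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)

# jax.vmap maps the function over a leading batch dimension,
# so inputs of shape (batch, seq_len, d_k) work unchanged.
batched_scores = jax.vmap(attention_scores)

q = jnp.ones((8, 10, 16))
k = jnp.ones((8, 10, 16))
print(batched_scores(q, k).shape)  # (8, 10, 10)
```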


View Run · Twitter thread
