
# Vi-T in Flax

A very basic implementation of the Vi-T (Vision Transformer) paper using the Flax neural network framework. The main goal of this project is to learn the device-agnostic framework, not to get the best results. All results are collected via `wandb.sweep` using a small custom logger wrapper; a hypothetical sketch of this setup is shown below.
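A minimal sketch of such a sweep setup, assuming the standard `wandb` API; the sweep parameters, metric name, project name, and `train` stub are placeholders, not the repo's actual configuration:

```python
import wandb

# Hypothetical sweep configuration; all names and values are placeholders.
sweep_config = {
    "method": "random",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {"learning_rate": {"values": [1e-4, 3e-4, 1e-3]}},
}


def train():
    # A thin logger wrapper would sit around calls like these
    # inside the real training loop.
    with wandb.init() as run:
        run.log({"val_accuracy": 0.0})  # placeholder metric


sweep_id = wandb.sweep(sweep_config, project="vit-flax")  # placeholder project name
wandb.agent(sweep_id, function=train, count=5)
```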

## Architecture and some implementation details

The model architecture is only suitable for classification tasks:

- Adam optimizer with a cosine learning-rate schedule and gradient clipping;
- Multi-head self-attention with n = 8 heads and a hidden dimension of 768;
- Both learnable and sinusoidal positional embeddings are implemented, but the learnable ones are used (a minimal sketch of these pieces follows this list).
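A minimal sketch of these pieces with `flax.linen` and `optax`; the module names (`ViTEmbeddings`, `TransformerBlock`) and all hyperparameters other than the 8 heads and 768 hidden dimension stated above are illustrative placeholders, not the repo's actual code:

```python
import flax.linen as nn
import jax.numpy as jnp
import optax


class ViTEmbeddings(nn.Module):  # hypothetical module name
    num_patches: int
    hidden_dim: int = 768

    @nn.compact
    def __call__(self, patches):
        # Linear projection of flattened patches, plus a class token.
        x = nn.Dense(self.hidden_dim)(patches)
        cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_dim))
        x = jnp.concatenate([jnp.tile(cls, (x.shape[0], 1, 1)), x], axis=1)
        # Learnable positional embeddings (a sinusoidal variant also
        # exists, but the learnable one is used).
        pos = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (1, self.num_patches + 1, self.hidden_dim),
        )
        return x + pos


class TransformerBlock(nn.Module):  # hypothetical module name
    hidden_dim: int = 768
    num_heads: int = 8

    @nn.compact
    def __call__(self, x):
        # Pre-norm multi-head self-attention with 8 heads.
        y = nn.LayerNorm()(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.hidden_dim
        )(y, y)
        x = x + y
        # Pre-norm MLP with GELU, as in the ViT paper.
        y = nn.LayerNorm()(x)
        y = nn.Dense(4 * self.hidden_dim)(y)
        y = nn.gelu(y)
        return x + nn.Dense(self.hidden_dim)(y)


# Adam with a cosine learning-rate schedule and global-norm gradient
# clipping; the peak rate, step count, and clip norm are placeholders.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=schedule),
)
```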

## Helpful links

  1. https://huggingface.co/flax-community/vit-gpt2/tree/main/vit_gpt2
  2. https://github.com/google/flax/blob/main/examples/imagenet/train.py
  3. Official implementation
  4. Good set of jax tutorials