srishti-git1110/torch-switch-transformers


Switch Transformers

PyTorch implementation of the Switch Transformer paper. Also read my blog post covering the paper.

[Figure: Switch Layer]
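In a Switch layer, a router picks exactly one expert feed-forward network (top-1 routing) for each token, and the chosen expert's output is scaled by its router probability so the router still receives gradients. The sketch below shows only that routing step. It is not this repo's implementation: the class name `TinySwitchFFN` and the `hidden` size are hypothetical, and it leaves out expert capacity, token dropping and the load-balancing loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySwitchFFN(nn.Module):
    """Hypothetical minimal top-1 ("switch") routed feed-forward layer.

    Illustration only: no expert capacity, no token dropping,
    no auxiliary load-balancing loss.
    """

    def __init__(self, dim: int, num_experts: int, hidden: int = 256):
        super().__init__()
        # Router: one logit per expert for every token.
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor):
        # x: (batch, seq, dim) -> flatten to (batch * seq, dim) tokens.
        tokens = x.reshape(-1, x.shape[-1])
        probs = F.softmax(self.router(tokens), dim=-1)  # router probabilities
        gate, expert_idx = probs.max(dim=-1)            # top-1 expert per token
        out = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                # Scale each expert output by its gate value so the router gets gradients.
                out[mask] = gate[mask].unsqueeze(-1) * expert(tokens[mask])
        return out.reshape(x.shape), probs, expert_idx
```

Looping over experts keeps the sketch readable; a production layer would typically dispatch tokens into fixed-size, per-expert buffers (the capacity factor from the paper) so the computation has static shapes.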

News

  • Now supports the latest auxiliary-loss-free load-balancing technique from this paper. Pass use_biased_gating=True when instantiating the SwitchTransformer class; everything else is handled for you.

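```python
from switch_transformers import SwitchTransformer

# inp_dim, num_experts, num_heads and vocab_size are defined as in the Usage example below
switch_transformer = SwitchTransformer(
    inp_dim,
    num_experts,
    num_heads,
    vocab_size,
    use_biased_gating=True,
).cuda()
```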
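The general idea behind auxiliary-loss-free balancing is to keep a per-expert bias that is added to the router scores only when choosing the expert, while the gate value still uses the unbiased score. After each step the bias is nudged up for underloaded experts and down for overloaded ones, with no gradient involved. The sketch below illustrates that idea for top-1 routing; it is not this repo's code, and the function names and `update_rate` value are hypothetical.

```python
import torch


def biased_top1_route(scores: torch.Tensor, bias: torch.Tensor):
    """scores: (num_tokens, num_experts); bias: (num_experts,)."""
    # The bias only influences which expert is selected...
    expert_idx = (scores + bias).argmax(dim=-1)
    # ...while the gate value comes from the original, unbiased score.
    gate = scores.gather(-1, expert_idx.unsqueeze(-1)).squeeze(-1)
    return expert_idx, gate


@torch.no_grad()
def update_bias(bias: torch.Tensor, expert_idx: torch.Tensor,
                num_experts: int, update_rate: float = 1e-3) -> torch.Tensor:
    """Move each expert's bias toward a balanced load (hypothetical helper)."""
    load = torch.bincount(expert_idx, minlength=num_experts).float()
    error = load.mean() - load  # positive for underloaded experts
    return bias + update_rate * torch.sign(error)
```

Because the bias never enters the gate value or the loss, balancing pressure does not compete with the language-modeling gradient, which is the main motivation for dropping the auxiliary loss.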

Usage

  1. Clone the repo:

     ```shell
     git clone https://github.com/srishti-git1110/torch-switch-transformers.git
     ```

  2. Change into the repository directory:

     ```shell
     cd torch-switch-transformers
     ```

  3. Install the required dependencies:

     ```shell
     pip install -r requirements.txt
     ```

  4. Run the model:

With aux_loss

```python
import torch

from switch_transformers import SwitchTransformer

inp_dim = 512
num_experts = 8
num_heads = 8
vocab_size = 50000

switch_transformer = SwitchTransformer(
    inp_dim,
    num_experts,
    num_heads,
    vocab_size,
    use_aux_loss=True,  # optional: this is the default when use_biased_gating is not True
).cuda()

x = torch.randn(2, 1024, inp_dim).cuda()
output, total_aux_loss = switch_transformer(x)
```
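For reference, the Switch Transformer paper defines the auxiliary load-balancing loss as alpha · N · Σᵢ fᵢ · Pᵢ, where N is the number of experts, fᵢ is the fraction of tokens dispatched to expert i, and Pᵢ is the mean router probability assigned to expert i (the paper uses alpha = 1e-2). The sketch below computes it for a single layer; the function name is hypothetical, and whether this repo's `total_aux_loss` is exactly this quantity summed over layers, with alpha already applied, is an assumption.

```python
import torch
import torch.nn.functional as F


def switch_load_balancing_loss(router_logits: torch.Tensor, alpha: float = 1e-2) -> torch.Tensor:
    """router_logits: (num_tokens, num_experts) for one Switch layer."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)
    expert_idx = probs.argmax(dim=-1)
    # f_i: fraction of tokens routed to expert i (argmax, so no gradient flows here).
    f = F.one_hot(expert_idx, num_experts).float().mean(dim=0)
    # P_i: mean router probability for expert i (differentiable).
    p = probs.mean(dim=0)
    return alpha * num_experts * torch.sum(f * p)
```

Under a perfectly uniform assignment the loss equals alpha, and it grows as routing collapses onto fewer experts. In a training step the auxiliary term would typically be added to the task loss, e.g. `loss = task_loss + total_aux_loss`, assuming `total_aux_loss` already includes its scaling.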

With aux-loss-free load balancing

```python
switch_transformer = SwitchTransformer(
    inp_dim,
    num_experts,
    num_heads,
    vocab_size,
    use_biased_gating=True,
).cuda()
```
