A PyTorch implementation of the Switch Transformer paper. See also my blog post covering the paper.
- Now supports the latest auxiliary-loss-free load balancing technique from this paper. Simply pass use_biased_gating=True when instantiating the SwitchTransformer class; everything else is taken care of.
switch_transformer = SwitchTransformer(
inp_dim,
num_experts,
num_heads,
vocab_size,
use_biased_gating=True,
).cuda()
- Clone the repo
git clone https://github.com/srishti-git1110/torch-switch-transformers.git
- Navigate to the repository directory
cd torch-switch-transformers
- Install the required dependencies
pip install -r requirements.txt
- Usage
import torch
from switch_transformers import SwitchTransformer
inp_dim = 512
num_experts = 8
num_heads = 8
vocab_size = 50000
switch_transformer = SwitchTransformer(
inp_dim,
num_experts,
num_heads,
vocab_size,
use_aux_loss=True,  # optional: the auxiliary loss is the default unless use_biased_gating=True
).cuda()
x = torch.randn(2, 1024, inp_dim).cuda()
output, total_aux_loss = switch_transformer(x)
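For intuition about what total_aux_loss measures, here is a minimal sketch of the load-balancing auxiliary loss as defined in the Switch Transformer paper: alpha * N * sum_i(f_i * P_i), where N is the number of experts, f_i is the fraction of tokens routed to expert i, and P_i is the mean router probability for expert i. It is written in plain Python to keep the arithmetic visible. The function name switch_aux_loss and its alpha argument are illustrative assumptions, not this repo's API, and the repo's own implementation may differ in detail.

```python
def switch_aux_loss(router_probs, alpha=0.01):
    """Load-balancing loss for top-1 routing: alpha * N * sum_i(f_i * P_i).

    router_probs: one list of per-expert probabilities per token.
    alpha: loss coefficient (hypothetical default, chosen for illustration).
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])

    # Top-1 routing: each token goes to its highest-probability expert.
    assigned = [
        max(range(num_experts), key=probs.__getitem__) for probs in router_probs
    ]

    # f_i: fraction of tokens dispatched to expert i.
    f = [assigned.count(i) / num_tokens for i in range(num_experts)]
    # P_i: mean router probability given to expert i.
    p = [
        sum(probs[i] for probs in router_probs) / num_tokens
        for i in range(num_experts)
    ]

    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Two tokens, two experts (toy values).
balanced = switch_aux_loss([[0.75, 0.25], [0.25, 0.75]], alpha=1.0)
collapsed = switch_aux_loss([[0.75, 0.25], [0.75, 0.25]], alpha=1.0)
print(balanced, collapsed)  # 1.0 1.5
```

The loss is smallest when tokens are spread evenly across experts, so sending every token to the same expert is penalized.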
To use auxiliary-loss-free load balancing instead, pass use_biased_gating=True:
switch_transformer = SwitchTransformer(
inp_dim,
num_experts,
num_heads,
vocab_size,
use_biased_gating=True,
).cuda()
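As a rough illustration of the idea behind use_biased_gating, the sketch below assumes the bias-based scheme in which a per-expert bias is added to the router scores only when selecting an expert, and is nudged after each batch toward under-loaded experts. The helper names route_top1 and update_bias, the update_rate value, and the toy scores are all hypothetical; this is not necessarily how the repo implements it.

```python
def route_top1(scores, bias):
    """Pick one expert per token from bias-adjusted scores.

    The bias only influences which expert is chosen; in this sketch the
    original scores would still be used as the gating weights.
    """
    choices = []
    for token_scores in scores:
        biased = [s + b for s, b in zip(token_scores, bias)]
        choices.append(max(range(len(biased)), key=biased.__getitem__))
    return choices


def update_bias(bias, choices, update_rate):
    """Lower the bias of overloaded experts and raise it for underloaded ones."""
    num_experts = len(bias)
    counts = [choices.count(i) for i in range(num_experts)]
    mean_load = len(choices) / num_experts
    new_bias = []
    for b, c in zip(bias, counts):
        if c > mean_load:
            b -= update_rate
        elif c < mean_load:
            b += update_rate
        new_bias.append(b)
    return new_bias, counts


# Four tokens, two experts; every token initially prefers expert 0.
scores = [
    [0.8125, 0.1875],
    [0.6875, 0.3125],
    [0.875, 0.125],
    [0.5625, 0.4375],
]
bias = [0.0, 0.0]
history = []
for _ in range(3):
    choices = route_top1(scores, bias)
    bias, counts = update_bias(bias, choices, update_rate=0.125)
    history.append(counts)

print(history)  # [[4, 0], [3, 1], [2, 2]]
print(bias)     # [-0.25, 0.25]
```

No extra loss term is involved: the load evens out purely through the bias updates, which is why this approach is described as auxiliary-loss-free.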