Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions megatron/model/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch import nn
from torch.nn import functional as F


class _GLUBaseModule(nn.Module):
def __init__(self, activation_fn):
super().__init__()
self.activation_fn = activation_fn

def forward(self, x):
# dim=-1 breaks in jit for pt<1.10
x1, x2 = x.chunk(2, dim=(x.ndim-1))
return x1 * self.activation_fn(x2)


class LiGLU(_GLUBaseModule):
def __init__(self):
super().__init__(nn.Identity())


class GEGLU(_GLUBaseModule):
def __init__(self):
super().__init__(F.gelu)


class ReGLU(_GLUBaseModule):
def __init__(self):
super().__init__(F.relu)


class SwiGLU(_GLUBaseModule):
def __init__(self):
super().__init__(F.silu)


liglu = torch.jit.script(LiGLU())
geglu = torch.jit.script(GEGLU())
reglu = torch.jit.script(ReGLU())
swiglu = torch.jit.script(SwiGLU())
43 changes: 43 additions & 0 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import random
import unittest

import torch
from torch.nn import functional as F

from megatron.model.activations import liglu, geglu, reglu, swiglu

from .utils import set_seed


class TestActivations(unittest.TestCase):
def setUp(self):
"""setup an input of reasonable size"""
set_seed()
self.batch_size = random.randint(2, 64)
self.seq_len = random.randint(256, 1025)
self.num_channels = random.randint(1, 384) * 2
self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels)
self.x1, self.x2 = self.x.chunk(2, dim=-1)

def test_shapes(self):
# glu should halve the last dimension
output_shape = [self.batch_size, self.seq_len, self.num_channels // 2]
for activation_fn in [liglu, geglu, reglu, swiglu]:
output = activation_fn(self.x)
self.assertEqual(list(output.shape), output_shape)

def test_liglu(self):
expected = self.x1 * self.x2
torch.testing.assert_equal(liglu(self.x), expected)

def test_geglu(self):
expected = self.x1 * F.gelu(self.x2)
torch.testing.assert_equal(geglu(self.x), expected)

def test_reglu(self):
expected = self.x1 * F.relu(self.x2)
torch.testing.assert_equal(reglu(self.x), expected)

def test_swiglu(self):
expected = self.x1 * F.silu(self.x2)
torch.testing.assert_equal(swiglu(self.x), expected)
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import random

import numpy as np
import torch


def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)