test(efficient_kan,original_kan): coverage
AdityaNG committed May 5, 2024
1 parent d3a5a6d commit 359f8d5
Showing 2 changed files with 70 additions and 14 deletions.
20 changes: 16 additions & 4 deletions kan_gpt/model.py
@@ -20,10 +20,15 @@
from kan_gpt.mingpt.utils import CfgNode as CN
from kan_gpt.settings import settings

if settings.kan.KAN_IMPLEMENTATION == "EFFICIENT_KAN":
    KAN = EFFICIENT_KAN  # type: ignore
elif settings.kan.KAN_IMPLEMENTATION == "ORIGINAL_KAN":
    KAN = ORIGINAL_KAN  # type: ignore

def get_KAN():
    if settings.kan.KAN_IMPLEMENTATION == "EFFICIENT_KAN":
        KAN = EFFICIENT_KAN  # type: ignore
    elif settings.kan.KAN_IMPLEMENTATION == "ORIGINAL_KAN":
        KAN = ORIGINAL_KAN  # type: ignore

    return KAN


# -----------------------------------------------------------------------------

@@ -60,6 +65,8 @@ class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        KAN = get_KAN()

        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = KAN(width=[config.n_embd, 3 * config.n_embd])
@@ -118,6 +125,7 @@ class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        KAN = get_KAN()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
@@ -220,6 +228,7 @@ def __init__(self, config):
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        KAN = get_KAN()
        self.lm_head = KAN(
            width=[config.n_embd, config.vocab_size], bias_trainable=False
        )
@@ -285,6 +294,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):

        total_reg = torch.tensor(0.0).to(device=x.device, dtype=torch.float32)
        size = 0
        KAN = get_KAN()
        for mod in self.modules():
            if isinstance(mod, KAN):
                total_reg += reg(mod)
@@ -294,6 +304,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
        return mean_reg

    def _init_weights(self, module):
        KAN = get_KAN()
        if isinstance(module, KAN):
            # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # if module.bias is not None:
@@ -369,6 +380,7 @@ def configure_optimizers(self, train_config):
        # regularizing weight decay
        decay = set()
        no_decay = set()
        KAN = get_KAN()
        whitelist_weight_modules = (KAN,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
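
The model.py side of this commit replaces the import-time `KAN = ...` assignment with a `get_KAN()` factory that each constructor calls, so the implementation is looked up when a module is built rather than frozen the first time `kan_gpt.model` is imported. The following is a minimal, self-contained sketch of that pattern; the class names and the direct `os.environ` read are illustrative stand-ins, not the real `kan_gpt` settings object or KAN classes.

import os


class EfficientKAN:  # stand-in for kan_gpt's EFFICIENT_KAN
    pass


class OriginalKAN:  # stand-in for kan_gpt's ORIGINAL_KAN
    pass


def get_KAN():
    # Resolve the implementation at call time; a module-level `if` would
    # pin the choice to whatever the setting was at first import.
    impl = os.environ.get("KAN_IMPLEMENTATION", "EFFICIENT_KAN")
    if impl == "EFFICIENT_KAN":
        return EfficientKAN
    elif impl == "ORIGINAL_KAN":
        return OriginalKAN
    raise ValueError(f"Unknown KAN_IMPLEMENTATION: {impl}")


if __name__ == "__main__":
    os.environ["KAN_IMPLEMENTATION"] = "ORIGINAL_KAN"
    assert get_KAN() is OriginalKAN
    os.environ["KAN_IMPLEMENTATION"] = "EFFICIENT_KAN"
    assert get_KAN() is EfficientKAN
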
64 changes: 54 additions & 10 deletions tests/test_gpt_kan.py
@@ -1,23 +1,58 @@
import os
import torch
from kan_gpt.model import GPT as KAN_GPT
import pytest
from unittest import mock

VOCAB_SIZE = 8
BLOCK_SIZE = 16
MODEL_TYPE = "gpt-pico"


def get_gpt_model() -> KAN_GPT:
@mock.patch.dict(
    os.environ, {"KAN_IMPLEMENTATION": "EFFICIENT_KAN"}, clear=True
)
def get_gpt_model_efficient():
    from kan_gpt.model import GPT as KAN_GPT

    model_config = KAN_GPT.get_default_config()
    model_config.model_type = MODEL_TYPE
    model_config.vocab_size = VOCAB_SIZE
    model_config.block_size = BLOCK_SIZE
    model = KAN_GPT(model_config)

    del KAN_GPT

    return model


@mock.patch.dict(
    os.environ, {"KAN_IMPLEMENTATION": "ORIGINAL_KAN"}, clear=True
)
def get_gpt_model_original():
    from kan_gpt.model import GPT as KAN_GPT

    model_config = KAN_GPT.get_default_config()
    model_config.model_type = MODEL_TYPE
    model_config.vocab_size = VOCAB_SIZE
    model_config.block_size = BLOCK_SIZE
    model = KAN_GPT(model_config)

    del KAN_GPT

    return model


def test_forward():
@pytest.fixture
def model(request):
    return request.param()


@pytest.mark.parametrize(
    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
)
def test_forward(model):
    with torch.no_grad():
        model = get_gpt_model()

        x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

        y, loss = model.forward(x)
@@ -29,8 +64,11 @@ def test_forward():
        ), f"Shape mismatch: {y.shape}"


def test_backward():
    model = get_gpt_model()
@pytest.mark.parametrize(
    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
)
def test_backward(model):
    model = model
    x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
    y_gt = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

@@ -55,9 +93,12 @@ def test_backward():
    assert len(grad_set) > 0, f"Tensor.grad missing"


def test_forward_batched():
@pytest.mark.parametrize(
    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
)
def test_forward_batched(model):
    with torch.no_grad():
        model = get_gpt_model()

        x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

        y, loss = model.forward(x)
@@ -69,8 +110,11 @@ def test_forward_batched():
        ), f"Shape mismatch: {y.shape}"


def test_backward_batched():
    model = get_gpt_model()
@pytest.mark.parametrize(
    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
)
def test_backward_batched(model):
    model = model
    x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
    y_gt = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

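The tests now build each model once per KAN implementation by routing the factory callables through the `model` fixture with `indirect=True`: pytest hands each parametrized value to the fixture as `request.param`, and the fixture calls it to construct a fresh model for that run. Below is a minimal, generic sketch of this indirect-parametrization pattern; the factory names and returned objects are placeholders, not the kan_gpt models.

import pytest


def build_variant_a():
    return {"variant": "a"}  # placeholder for get_gpt_model_efficient()


def build_variant_b():
    return {"variant": "b"}  # placeholder for get_gpt_model_original()


@pytest.fixture
def model(request):
    # With indirect=True, each parametrized value arrives here as
    # request.param; calling it builds a fresh object per test invocation.
    return request.param()


@pytest.mark.parametrize(
    "model", (build_variant_a, build_variant_b), indirect=True
)
def test_variant_built(model):
    assert model["variant"] in {"a", "b"}

Each test in the commit therefore runs twice, once with the EFFICIENT_KAN model and once with the ORIGINAL_KAN model, which is the coverage the commit message refers to.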

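The two factory functions are decorated with `mock.patch.dict(os.environ, {...}, clear=True)`, which installs a controlled environment for the duration of the call (with `clear=True` removing all other variables first) and restores the original environment afterwards; importing `kan_gpt.model` inside the decorated body is presumably meant to ensure the patched value is in effect while the model is constructed. A short standalone illustration of the decorator's behaviour, using an arbitrary reader function:

import os
from unittest import mock


@mock.patch.dict(
    os.environ, {"KAN_IMPLEMENTATION": "ORIGINAL_KAN"}, clear=True
)
def read_setting():
    # Inside the decorated call only the patched value is visible,
    # because clear=True empties os.environ before applying the dict.
    return os.environ["KAN_IMPLEMENTATION"]


if __name__ == "__main__":
    os.environ["KAN_IMPLEMENTATION"] = "EFFICIENT_KAN"
    assert read_setting() == "ORIGINAL_KAN"
    # Once the call returns, the original environment is restored.
    assert os.environ["KAN_IMPLEMENTATION"] == "EFFICIENT_KAN"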