diff --git a/kan_gpt/model.py b/kan_gpt/model.py
index f963830..7643a37 100644
--- a/kan_gpt/model.py
+++ b/kan_gpt/model.py
@@ -20,10 +20,15 @@
 from kan_gpt.mingpt.utils import CfgNode as CN
 from kan_gpt.settings import settings
 
-if settings.kan.KAN_IMPLEMENTATION == "EFFICIENT_KAN":
-    KAN = EFFICIENT_KAN  # type: ignore
-elif settings.kan.KAN_IMPLEMENTATION == "ORIGINAL_KAN":
-    KAN = ORIGINAL_KAN  # type: ignore
+
+def get_KAN():
+    if settings.kan.KAN_IMPLEMENTATION == "EFFICIENT_KAN":
+        KAN = EFFICIENT_KAN  # type: ignore
+    elif settings.kan.KAN_IMPLEMENTATION == "ORIGINAL_KAN":
+        KAN = ORIGINAL_KAN  # type: ignore
+
+    return KAN
+
 
 
 # -----------------------------------------------------------------------------
@@ -60,6 +65,8 @@ class CausalSelfAttention(nn.Module):
 
     def __init__(self, config):
         super().__init__()
+        KAN = get_KAN()
+
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
         self.c_attn = KAN(width=[config.n_embd, 3 * config.n_embd])
@@ -118,6 +125,7 @@ class Block(nn.Module):
 
    def __init__(self, config):
         super().__init__()
+        KAN = get_KAN()
         self.ln_1 = nn.LayerNorm(config.n_embd)
         self.attn = CausalSelfAttention(config)
         self.ln_2 = nn.LayerNorm(config.n_embd)
@@ -220,6 +228,7 @@ def __init__(self, config):
                 ln_f=nn.LayerNorm(config.n_embd),
             )
         )
+        KAN = get_KAN()
         self.lm_head = KAN(
             width=[config.n_embd, config.vocab_size], bias_trainable=False
         )
@@ -285,6 +294,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
 
         total_reg = torch.tensor(0.0).to(device=x.device, dtype=torch.float32)
         size = 0
+        KAN = get_KAN()
         for mod in self.modules():
             if isinstance(mod, KAN):
                 total_reg += reg(mod)
@@ -294,6 +304,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
         return mean_reg
 
     def _init_weights(self, module):
+        KAN = get_KAN()
         if isinstance(module, KAN):
             # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             # if module.bias is not None:
@@ -369,6 +380,7 @@ def configure_optimizers(self, train_config):
         # regularizing weight decay
         decay = set()
         no_decay = set()
+        KAN = get_KAN()
         whitelist_weight_modules = (KAN,)
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
diff --git a/tests/test_gpt_kan.py b/tests/test_gpt_kan.py
index 297adbe..a64a293 100644
--- a/tests/test_gpt_kan.py
+++ b/tests/test_gpt_kan.py
@@ -1,23 +1,58 @@
+import os
 import torch
-from kan_gpt.model import GPT as KAN_GPT
+import pytest
+from unittest import mock
 
 VOCAB_SIZE = 8
 BLOCK_SIZE = 16
 MODEL_TYPE = "gpt-pico"
 
 
-def get_gpt_model() -> KAN_GPT:
+@mock.patch.dict(
+    os.environ, {"KAN_IMPLEMENTATION": "EFFICIENT_KAN"}, clear=True
+)
+def get_gpt_model_efficient():
+    from kan_gpt.model import GPT as KAN_GPT
+
+    model_config = KAN_GPT.get_default_config()
+    model_config.model_type = MODEL_TYPE
+    model_config.vocab_size = VOCAB_SIZE
+    model_config.block_size = BLOCK_SIZE
+    model = KAN_GPT(model_config)
+
+    del KAN_GPT
+
+    return model
+
+
+@mock.patch.dict(
+    os.environ, {"KAN_IMPLEMENTATION": "ORIGINAL_KAN"}, clear=True
+)
+def get_gpt_model_original():
+    from kan_gpt.model import GPT as KAN_GPT
+
     model_config = KAN_GPT.get_default_config()
     model_config.model_type = MODEL_TYPE
     model_config.vocab_size = VOCAB_SIZE
     model_config.block_size = BLOCK_SIZE
     model = KAN_GPT(model_config)
+
+    del KAN_GPT
+
     return model
 
 
-def test_forward():
+@pytest.fixture
+def model(request):
+    return request.param()
+
+
+@pytest.mark.parametrize(
+    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
+)
+def test_forward(model):
     with torch.no_grad():
-        model = get_gpt_model()
+
         x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
 
         y, loss = model.forward(x)
@@ -29,8 +64,11 @@ def test_forward():
         ), f"Shape mismatch: {y.shape}"
 
 
-def test_backward():
-    model = get_gpt_model()
+@pytest.mark.parametrize(
+    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
+)
+def test_backward(model):
+    model = model
     x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
     y_gt = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
 
@@ -55,9 +93,12 @@ def test_backward():
     assert len(grad_set) > 0, f"Tensor.grad missing"
 
 
-def test_forward_batched():
+@pytest.mark.parametrize(
+    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
+)
+def test_forward_batched(model):
     with torch.no_grad():
-        model = get_gpt_model()
+
         x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
 
         y, loss = model.forward(x)
@@ -69,8 +110,11 @@ def test_forward_batched():
         ), f"Shape mismatch: {y.shape}"
 
 
-def test_backward_batched():
-    model = get_gpt_model()
+@pytest.mark.parametrize(
+    "model", (get_gpt_model_efficient, get_gpt_model_original), indirect=True
+)
+def test_backward_batched(model):
+    model = model
     x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
     y_gt = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
 
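The patch moves the KAN backend selection from import time into get_KAN(), driven by settings.kan.KAN_IMPLEMENTATION, and the tests select a backend by patching the KAN_IMPLEMENTATION environment variable before importing kan_gpt.model. As a rough usage sketch (not part of the patch; it assumes, as the tests imply, that the settings object reads KAN_IMPLEMENTATION from the environment when kan_gpt is first imported), the same selection could be exercised outside pytest like this:

# Illustrative sketch only, not part of the diff: pick the KAN backend via the
# KAN_IMPLEMENTATION environment variable before kan_gpt is imported, assuming
# settings reads the variable at first import (as the test mocks imply).
import os

os.environ["KAN_IMPLEMENTATION"] = "EFFICIENT_KAN"  # or "ORIGINAL_KAN"

from kan_gpt.model import GPT as KAN_GPT

# Same construction path the tests use.
model_config = KAN_GPT.get_default_config()
model_config.model_type = "gpt-pico"
model_config.vocab_size = 8
model_config.block_size = 16
model = KAN_GPT(model_config)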