diff --git a/src/lobster/model/lm_base/_lm_base.py b/src/lobster/model/lm_base/_lm_base.py index d0d2730b..833a2d36 100644 --- a/src/lobster/model/lm_base/_lm_base.py +++ b/src/lobster/model/lm_base/_lm_base.py @@ -1420,6 +1420,12 @@ def forward( ) +# @classmethod +# def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): +# # reproduce behavior before https://github.com/huggingface/transformers/pull/36963 +# return [] + + @add_start_docstrings( """LMBase Model with Conditional generatation`language modeling` head on top.""", LMBase_START_DOCSTRING ) @@ -1585,11 +1591,9 @@ def __init__(self, config): self.decoder = nn.Linear((config.n_concepts * config.concept_emb) * 2, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): - x = self.decoder(features) + # hack to maintain parameter structure + x = torch.nn.functional.linear(features, self.decoder.weight, self.bias) return x @@ -1606,15 +1610,13 @@ def __init__(self, config, input_dim, out_dim): self.bias = nn.Parameter(torch.zeros(out_dim)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) - # project back to size of vocabulary with bias - x = self.decoder(x) + + # hack to maintain parameter structure + x = torch.nn.functional.linear(x, self.decoder.weight, self.bias) return x @@ -1630,16 +1632,15 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias - x = self.decoder(x) + + # hack to maintain parameter structure + x = torch.nn.functional.linear(x, self.decoder.weight, self.bias) return x diff --git a/tests/lobster/model/test__cbm.py b/tests/lobster/model/test__cbm.py index e866ac4c..0adba6a2 100644 --- a/tests/lobster/model/test__cbm.py +++ b/tests/lobster/model/test__cbm.py @@ -3,6 +3,7 @@ import tempfile import pytest +import torch from torch import Size, Tensor from lobster.model import LobsterCBMPMLM @@ -42,3 +43,31 @@ def test_load_from_s3(self): ) assert model.config.hidden_size == 408 + + +def test_cbmlm_checkpoint(tmp_path): + print(f"{tmp_path=}") + model = LobsterCBMPMLM("MLM_mini") + + for k, v in model.named_parameters(): + torch.nn.init.normal_(v) + + model.save_pretrained(tmp_path / "checkpoint") + + model2 = LobsterCBMPMLM(str(tmp_path / "checkpoint")) + + for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()): + assert k1 == k2 + assert torch.equal(v1, v2) + assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}" + + assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias) + + input = torch.randn(2, 56) + output = model.model.lm_head.decoder(input) + output2 = model2.model.lm_head.decoder(input) + + diff = output - output2 + print(f"{diff.abs().max()=}") + + torch.testing.assert_close(output, output2) diff --git a/tests/lobster/model/test__mlm.py b/tests/lobster/model/test__mlm.py index dedda3e9..9cb5533d 100644 --- a/tests/lobster/model/test__mlm.py +++ b/tests/lobster/model/test__mlm.py @@ -100,3 +100,31 @@ def test_dynamic_masking(self, model): # ) # assert model.config.hidden_size == 384 + + +def test_mlm_checkpoint(tmp_path): + print(f"{tmp_path=}") + model = LobsterPMLM("MLM_mini") + + for k, v in model.named_parameters(): + torch.nn.init.normal_(v) + + model.save_pretrained(tmp_path / "checkpoint") + + model2 = LobsterPMLM(str(tmp_path / "checkpoint")) + + for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()): + assert k1 == k2 + assert torch.equal(v1, v2) + assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}" + + assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias) + + input = torch.randn(2, 72) + output = model.model.lm_head.decoder(input) + output2 = model2.model.lm_head.decoder(input) + + diff = output - output2 + print(f"{diff.abs().max()=}") + + torch.testing.assert_close(output, output2)