From 98b0d3274237fb415f2b60d5a83d262c2a4ee844 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Tue, 1 Jul 2025 09:59:01 -0700 Subject: [PATCH 1/4] add checkpoint tests --- tests/lobster/model/test__cbm.py | 29 +++++++++++++++++++++++++++++ tests/lobster/model/test__mlm.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/tests/lobster/model/test__cbm.py b/tests/lobster/model/test__cbm.py index e866ac4c..0adba6a2 100644 --- a/tests/lobster/model/test__cbm.py +++ b/tests/lobster/model/test__cbm.py @@ -3,6 +3,7 @@ import tempfile import pytest +import torch from torch import Size, Tensor from lobster.model import LobsterCBMPMLM @@ -42,3 +43,31 @@ def test_load_from_s3(self): ) assert model.config.hidden_size == 408 + + +def test_cbmlm_checkpoint(tmp_path): + print(f"{tmp_path=}") + model = LobsterCBMPMLM("MLM_mini") + + for k, v in model.named_parameters(): + torch.nn.init.normal_(v) + + model.save_pretrained(tmp_path / "checkpoint") + + model2 = LobsterCBMPMLM(str(tmp_path / "checkpoint")) + + for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()): + assert k1 == k2 + assert torch.equal(v1, v2) + assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}" + + assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias) + + input = torch.randn(2, 56) + output = model.model.lm_head.decoder(input) + output2 = model2.model.lm_head.decoder(input) + + diff = output - output2 + print(f"{diff.abs().max()=}") + + torch.testing.assert_close(output, output2) diff --git a/tests/lobster/model/test__mlm.py b/tests/lobster/model/test__mlm.py index dedda3e9..9cb5533d 100644 --- a/tests/lobster/model/test__mlm.py +++ b/tests/lobster/model/test__mlm.py @@ -100,3 +100,31 @@ def test_dynamic_masking(self, model): # ) # assert model.config.hidden_size == 384 + + +def test_mlm_checkpoint(tmp_path): + print(f"{tmp_path=}") + model = LobsterPMLM("MLM_mini") + + for k, v in model.named_parameters(): + torch.nn.init.normal_(v) + + model.save_pretrained(tmp_path / "checkpoint") + + model2 = LobsterPMLM(str(tmp_path / "checkpoint")) + + for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()): + assert k1 == k2 + assert torch.equal(v1, v2) + assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}" + + assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias) + + input = torch.randn(2, 72) + output = model.model.lm_head.decoder(input) + output2 = model2.model.lm_head.decoder(input) + + diff = output - output2 + print(f"{diff.abs().max()=}") + + torch.testing.assert_close(output, output2) From 2bed894a197337b28f39ad5d94531b2bff17c08d Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Tue, 1 Jul 2025 10:56:17 -0700 Subject: [PATCH 2/4] old init behavior --- src/lobster/model/lm_base/_lm_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lobster/model/lm_base/_lm_base.py b/src/lobster/model/lm_base/_lm_base.py index d0d2730b..34ea5d66 100644 --- a/src/lobster/model/lm_base/_lm_base.py +++ b/src/lobster/model/lm_base/_lm_base.py @@ -1419,6 +1419,11 @@ def forward( attentions=outputs.attentions, ) + # reproduce behavior before https://github.com/huggingface/transformers/pull/36963 + @classmethod + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): + return [] + @add_start_docstrings( """LMBase Model with Conditional generatation`language modeling` head on top.""", LMBase_START_DOCSTRING From 11199c2c17d3a642cf4d0a807b09568c0e928406 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Tue, 1 Jul 2025 10:38:37 -0700 Subject: [PATCH 3/4] fix test --- src/lobster/model/lm_base/_lm_base.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/lobster/model/lm_base/_lm_base.py b/src/lobster/model/lm_base/_lm_base.py index 34ea5d66..70514f93 100644 --- a/src/lobster/model/lm_base/_lm_base.py +++ b/src/lobster/model/lm_base/_lm_base.py @@ -1590,11 +1590,9 @@ def __init__(self, config): self.decoder = nn.Linear((config.n_concepts * config.concept_emb) * 2, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): - x = self.decoder(features) + # hack to maintain parameter structure + x = torch.nn.functional.linear(features, self.decoder.weight, self.bias) return x @@ -1611,15 +1609,13 @@ def __init__(self, config, input_dim, out_dim): self.bias = nn.Parameter(torch.zeros(out_dim)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) - # project back to size of vocabulary with bias - x = self.decoder(x) + + # hack to maintain parameter structure + x = torch.nn.functional.linear(x, self.decoder.weight, self.bias) return x @@ -1635,16 +1631,15 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias - x = self.decoder(x) + + # hack to maintain parameter structure + x = torch.nn.functional.linear(x, self.decoder.weight, self.bias) return x From 5033ad26696527a1aef800c08f7c29eec88e5b26 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Tue, 1 Jul 2025 10:58:41 -0700 Subject: [PATCH 4/4] remove old init behavior --- src/lobster/model/lm_base/_lm_base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lobster/model/lm_base/_lm_base.py b/src/lobster/model/lm_base/_lm_base.py index 70514f93..833a2d36 100644 --- a/src/lobster/model/lm_base/_lm_base.py +++ b/src/lobster/model/lm_base/_lm_base.py @@ -1419,10 +1419,11 @@ def forward( attentions=outputs.attentions, ) - # reproduce behavior before https://github.com/huggingface/transformers/pull/36963 - @classmethod - def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): - return [] + +# @classmethod +# def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): +# # reproduce behavior before https://github.com/huggingface/transformers/pull/36963 +# return [] @add_start_docstrings(