ckpt #130
Merged

27 changes: 14 additions & 13 deletions src/lobster/model/lm_base/_lm_base.py
@@ -1420,6 +1420,12 @@ def forward(
)


    # @classmethod
    # def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
    #     # reproduce behavior before https://github.com/huggingface/transformers/pull/36963
    #     return []


@add_start_docstrings(
"""LMBase Model with Conditional generatation`language modeling` head on top.""", LMBase_START_DOCSTRING
)
@@ -1585,11 +1591,9 @@ def __init__(self, config):
        self.decoder = nn.Linear((config.n_concepts * config.concept_emb) * 2, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        x = self.decoder(features)
        # hack to maintain parameter structure
        x = torch.nn.functional.linear(features, self.decoder.weight, self.bias)

        return x

@@ -1606,15 +1610,13 @@ def __init__(self, config, input_dim, out_dim):

        self.bias = nn.Parameter(torch.zeros(out_dim))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = self.decoder(x)

        # hack to maintain parameter structure
        x = torch.nn.functional.linear(x, self.decoder.weight, self.bias)

        return x

@@ -1630,16 +1632,15 @@ def __init__(self, config):
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x)

        # hack to maintain parameter structure
        x = torch.nn.functional.linear(x, self.decoder.weight, self.bias)

        return x

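For context on the diff above: each LM head used to alias its standalone bias onto the decoder (`self.decoder.bias = self.bias`) so that `resize_token_embeddings` would keep the two in sync, which leaves one tensor reachable under two parameter names. The new forward passes keep `bias` as an independent parameter and apply it explicitly with `torch.nn.functional.linear`, so the head's parameter structure stays stable across save/load. Below is a minimal, self-contained sketch of the new pattern; the class and dimension names are illustrative, not taken from the repo.

import torch
from torch import nn


class UntiedBiasLMHead(nn.Module):
    """Sketch of the pattern adopted in this PR: the bias is a separate
    parameter applied in forward(), never aliased onto the decoder module."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        # The decoder carries only the weight; the bias lives next to it.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Same result as self.decoder(features) + self.bias, but the state dict
        # only ever contains `decoder.weight` and `bias`.
        return torch.nn.functional.linear(features, self.decoder.weight, self.bias)


head = UntiedBiasLMHead(hidden_size=8, vocab_size=16)
print(sorted(head.state_dict().keys()))  # ['bias', 'decoder.weight']
print(head(torch.randn(2, 8)).shape)     # torch.Size([2, 16])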
29 changes: 29 additions & 0 deletions tests/lobster/model/test__cbm.py
@@ -3,6 +3,7 @@
import tempfile

import pytest
import torch
from torch import Size, Tensor

from lobster.model import LobsterCBMPMLM
@@ -42,3 +43,31 @@ def test_load_from_s3(self):
        )

        assert model.config.hidden_size == 408


def test_cbmlm_checkpoint(tmp_path):
    print(f"{tmp_path=}")
    model = LobsterCBMPMLM("MLM_mini")

    for k, v in model.named_parameters():
        torch.nn.init.normal_(v)

    model.save_pretrained(tmp_path / "checkpoint")

    model2 = LobsterCBMPMLM(str(tmp_path / "checkpoint"))

    for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()):
        assert k1 == k2
        assert torch.equal(v1, v2)
        assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}"

    assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias)

    input = torch.randn(2, 56)
    output = model.model.lm_head.decoder(input)
    output2 = model2.model.lm_head.decoder(input)

    diff = output - output2
    print(f"{diff.abs().max()=}")

    torch.testing.assert_close(output, output2)
28 changes: 28 additions & 0 deletions tests/lobster/model/test__mlm.py
@@ -100,3 +100,31 @@ def test_dynamic_masking(self, model):
# )

# assert model.config.hidden_size == 384


def test_mlm_checkpoint(tmp_path):
    print(f"{tmp_path=}")
    model = LobsterPMLM("MLM_mini")

    for k, v in model.named_parameters():
        torch.nn.init.normal_(v)

    model.save_pretrained(tmp_path / "checkpoint")

    model2 = LobsterPMLM(str(tmp_path / "checkpoint"))

    for (k1, v1), (k2, v2) in zip(model.named_parameters(), model2.named_parameters()):
        assert k1 == k2
        assert torch.equal(v1, v2)
        assert not torch.equal(v2, torch.zeros_like(v2)), f"{k1=}, {k2=}"

    assert torch.equal(model.model.lm_head.bias, model2.model.lm_head.bias)

    input = torch.randn(2, 72)
    output = model.model.lm_head.decoder(input)
    output2 = model2.model.lm_head.decoder(input)

    diff = output - output2
    print(f"{diff.abs().max()=}")

    torch.testing.assert_close(output, output2)
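Both new tests guard against the failure mode where a reloaded head comes back with its bias (or any other parameter) silently reset to zeros. Independent of the repo, the sketch below shows why the old aliasing pattern is awkward for serialization: the shared tensor appears under two state-dict keys that save/load code has to reconcile. All names here are illustrative.

import torch
from torch import nn


class AliasedBiasLMHead(nn.Module):
    """Sketch of the old pattern: the standalone bias is aliased onto the decoder."""

    def __init__(self, hidden_size: int = 8, vocab_size: int = 16):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # The same Parameter is now reachable as both `bias` and `decoder.bias`.
        self.decoder.bias = self.bias


head = AliasedBiasLMHead()
state = head.state_dict()
print(sorted(state.keys()))  # ['bias', 'decoder.bias', 'decoder.weight']
# Both keys point at the same underlying storage.
print(state["bias"].data_ptr() == state["decoder.bias"].data_ptr())  # True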