diff --git a/scripts/train.py b/scripts/train.py index a9b99fae6..94986a460 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -120,7 +120,7 @@ def create_transformer_layer_config( if hasattr(verifier_config, "text_config"): verifier_config = verifier_config.text_config - transformer_layer_config = config_class( + return config_class( vocab_size=verifier_config.vocab_size, hidden_size=verifier_config.hidden_size, intermediate_size=verifier_config.intermediate_size, @@ -132,9 +132,8 @@ def create_transformer_layer_config( initializer_range=verifier_config.initializer_range, rms_norm_eps=verifier_config.rms_norm_eps, head_dim=getattr(verifier_config, "head_dim", None), + tie_word_embeddings=False, ) - transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 - return transformer_layer_config def main(args: argparse.Namespace): @@ -176,8 +175,6 @@ def main(args: argparse.Namespace): args.verifier_name_or_path, args.num_layers, draft_arch=args.draft_arch ) - # Get model class from registry and create model using its factory method - if args.speculator_type not in SpeculatorModel.registry: raise ValueError( f"Unknown speculator type: {args.speculator_type}. 
" @@ -185,13 +182,18 @@ def main(args: argparse.Namespace): ) model_class = SpeculatorModel.registry[args.speculator_type] - draft_model = model_class.from_training_args( - verifier_config=transformer_layer_config, - t2d=t2d, - d2t=d2t, - draft_vocab_size=draft_vocab_size, - **vars(args), - ) + if args.from_pretrained: + draft_model = model_class.from_pretrained( + args.from_pretrained, t2d=t2d, d2t=d2t + ) + else: + draft_model = model_class.from_training_args( + verifier_config=transformer_layer_config, + t2d=t2d, + d2t=d2t, + draft_vocab_size=draft_vocab_size, + **vars(args), + ) # Setup dataloaders train_files, val_files = split_files(args.data_path, ratio=0.9) @@ -249,6 +251,12 @@ def parse_args(): default="eagle3", help="Type of speculator model to train (e.g., eagle3)", ) + parser.add_argument( + "--from-pretrained", + type=str, + default="", + help="The pretrained draft model to finetune", + ) parser.add_argument("--data-path", type=str, default="./data") parser.add_argument("--save-path", type=str, default="./checkpoints") parser.add_argument("--epochs", type=int, default=20) @@ -331,7 +339,7 @@ def parse_args(): # RUN WITH: -# torchrun --nnodes=1 --nproc_per_node= scripts/train.py +# torchrun --standalone --nproc_per_node= scripts/train.py # for FSDP training # OR # python scripts/train.py diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index 79bd8544e..e6208fa13 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -155,7 +155,7 @@ def conditional_torch_compile(func): @SpeculatorModel.register("eagle3") class Eagle3DraftModel(SpeculatorModel): config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc] - _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc] + _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc,assignment] "embed_tokens.weight", "verifier_norm.weight", 
"verifier_lm_head.weight", @@ -168,49 +168,65 @@ class Eagle3DraftModel(SpeculatorModel): "verifier_norm.weight", ] - def __init__( - self, - config: Eagle3SpeculatorConfig, - t2d: torch.Tensor | None, - d2t: torch.Tensor | None, - ): + t2d: torch.Tensor | None + d2t: torch.Tensor | None + + def __init__(self, config: Eagle3SpeculatorConfig): + # Forcibly override config settings + config.tie_word_embeddings = False + config.transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 super().__init__(config=config) + self.hidden_size = config.transformer_layer_config.hidden_size self.draft_vocab_size = config.draft_vocab_size + self.verifier_vocab_size = config.transformer_layer_config.vocab_size + + tl_config = self.config.transformer_layer_config + self._model_definitions = model_classes[tl_config.model_type] + + # VOCAB MAPPINGS + self.use_draft_vocab = self.draft_vocab_size != self.verifier_vocab_size + t2d: torch.Tensor | None = None + d2t: torch.Tensor | None = None + if self.use_draft_vocab: + t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) + d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) + + # FC LAYER + self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) - # Verify that if one mapping tensor is provided, the other is as well - if (t2d is None) != (d2t is None): - raise ValueError( - "Both t2d and d2t must be provided together, or both must be None. 
" - f"Got t2d={'provided' if t2d is not None else 'None'}, " - f"d2t={'provided' if d2t is not None else 'None'}" + # DECODER LAYERS + num_layers = tl_config.num_hidden_layers + fl_class = self._model_definitions.first_layer_class + dl_class = self._model_definitions.decoder_layer_class + layers = [ + fl_class( # first layer + tl_config, + layer_idx=0, + norm_before_residual=self.config.norm_before_residual, ) + ] + layers.extend( # remaining layers + [dl_class(tl_config, layer_idx) for layer_idx in range(1, num_layers)] + ) + self.layers = torch.nn.ModuleList(layers) - # Register buffers - they can be None - if t2d is not None: - self.register_buffer("t2d", t2d) # shape: [verifier_vocab_size], bool - if int(t2d.sum(dtype=torch.long).item()) != self.draft_vocab_size: - raise ValueError( - f"t2d has {int(t2d.sum(dtype=torch.long).item())} non-zero values, " - f"expected {self.draft_vocab_size}." - ) - else: - self.register_buffer("t2d", None) + # ROTARY EMBEDDINGS + # Create a modified config for the rotary embedding to use 2x the hidden size + modified_tl_config = copy.copy(config.transformer_layer_config) + modified_tl_config.hidden_size *= 2 + self.rotary_emb = self._model_definitions.rotary_emb_class(modified_tl_config) - if d2t is not None: - self.register_buffer("d2t", d2t) # shape: [draft_vocab_size], int offsets - if d2t.shape[0] != self.draft_vocab_size: - raise ValueError( - f"d2t.shape[0] ({d2t.shape[0]}) must match" - f" draft_vocab_size ({self.draft_vocab_size})." 
- ) - else: - self.register_buffer("d2t", None) + # LAYER NORMS + norm_class = self._model_definitions.norm_class + self.norm = norm_class( + self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps + ) + self.verifier_norm = norm_class(self.hidden_size, eps=tl_config.rms_norm_eps) + self.verifier_norm.weight.requires_grad = False - self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) - self._model_definitions = model_classes[ - config.transformer_layer_config.model_type - ] # Normalize draft path input (gpt-oss only) if config.norm_before_fc: self.input_norm = self._model_definitions.norm_class( @@ -219,55 +235,73 @@ def __init__( ) else: self.input_norm = None - self._setup_decoder_layers( - config.transformer_layer_config, config.norm_before_residual + + # TOKEN EMBEDDINGS + self.embed_tokens = torch.nn.Embedding( + self.verifier_vocab_size, + self.hidden_size, + padding_idx=tl_config.pad_token_id, ) - self.norm = self._model_definitions.norm_class( - self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps + self.embed_tokens.weight.requires_grad = self.config.embed_requires_grad + + # LM HEADS + self.lm_head = torch.nn.Linear( + self.hidden_size, self.draft_vocab_size, bias=False ) - self._setup_rotary_embedding(config.transformer_layer_config) - self._setup_embeddings_and_lm_heads( - config.speculators_config.verifier, t2d, config.embed_requires_grad + self.verifier_lm_head = torch.nn.Linear( + self.hidden_size, self.draft_vocab_size, bias=False ) + self.verifier_lm_head.weight.requires_grad = False - def _setup_decoder_layers( - self, transformer_layer_config: PretrainedConfig, norm_before_residual: bool - ): - num_hidden_layers = transformer_layer_config.num_hidden_layers - # Add first layer - layers = [ - self._model_definitions.first_layer_class( - transformer_layer_config, - layer_idx=0, - norm_before_residual=norm_before_residual, + # Initialize weights to nan + # This ensures it will be easy to detect if the 
weights are never + # loaded from the verifier model + torch.nn.init.constant_(self.lm_head.weight, torch.nan) + torch.nn.init.constant_(self.embed_tokens.weight, torch.nan) + torch.nn.init.constant_(self.verifier_lm_head.weight, torch.nan) + + def load_vocab_mappings(self, t2d: torch.Tensor | None, d2t: torch.Tensor | None): + if t2d is None and d2t is None: + # Nothing to load, return early + return + elif t2d is None or d2t is None: + raise ValueError( + "Both t2d and d2t must be provided together, or both must be None. " + f"Got t2d={'provided' if t2d is not None else 'None'}, " + f"d2t={'provided' if d2t is not None else 'None'}" ) - ] - # Add additional regular decoder layers - layers.extend( - [ - self._model_definitions.decoder_layer_class( - transformer_layer_config, layer_idx - ) - for layer_idx in range(1, num_hidden_layers) - ] - ) - self.layers = torch.nn.ModuleList(layers) - def _setup_rotary_embedding(self, transformer_layer_config: PretrainedConfig): - # Create a modified config for the rotary embedding to use 2x the hidden size - modified_config = copy.copy(transformer_layer_config) - modified_config.hidden_size = modified_config.hidden_size * 2 - self.rotary_emb = self._model_definitions.rotary_emb_class(modified_config) + if not self.use_draft_vocab: + raise RuntimeError( + "Vocab mappings were provided but are not needed because verifier " + "vocab size equals draft vocab size. Vocab mappings are only required " + "when using a reduced vocab." + ) - def _setup_embeddings_and_lm_heads( - self, - config: VerifierConfig, - t2d: torch.Tensor | None, - embed_requires_grad: bool, - ): - if config.name_or_path is None: + if t2d.shape[0] != self.verifier_vocab_size: + raise ValueError( + f"t2d.shape[0] ({t2d.shape[0]}) must match" + f" verifier_vocab_size ({self.verifier_vocab_size})." 
+ ) + if int(t2d.sum(dtype=torch.long).item()) != self.draft_vocab_size: + raise ValueError( + f"t2d has {int(t2d.sum(dtype=torch.long).item())} non-zero values, " + f"expected {self.draft_vocab_size}." + ) + + if d2t.shape[0] != self.draft_vocab_size: + raise ValueError( + f"d2t.shape[0] ({d2t.shape[0]}) must match" + f" draft_vocab_size ({self.draft_vocab_size})." + ) + + self.load_state_dict({"t2d": t2d, "d2t": d2t}, strict=False) + + def load_verifier_weights(self): # noqa: C901 + verifier_config = self.config.speculators_config.verifier + if verifier_config.name_or_path is None: raise ValueError("VerifierConfig `name_or_path` value is required.") - verifier_model_config = AutoConfig.from_pretrained(config.name_or_path) + verifier_model_config = AutoConfig.from_pretrained(verifier_config.name_or_path) # For multimodal models (Qwen3VL, etc.), extract text_config if hasattr(verifier_model_config, "text_config"): @@ -278,86 +312,71 @@ def _setup_embeddings_and_lm_heads( f"Verifier hidden size {verifier_model_config.hidden_size} does not" f" match draft hidden size {self.hidden_size}." ) - if t2d is not None and t2d.shape[0] != verifier_model_config.vocab_size: - raise ValueError( - f"t2d.shape[0] ({t2d.shape[0]}) must match" - f" verifier_vocab_size ({verifier_model_config.vocab_size})." - ) - # Load embedding and lm_head weights using suffix patterns (model-agnostic) + # Load embedding, lm_head, and norm weights using suffix names (model-agnostic) verifier_weights = load_model_layers( ["embed_tokens.weight", "lm_head.weight", "model.norm.weight"], - config.name_or_path, + verifier_config.name_or_path, ) if "embed_tokens.weight" not in verifier_weights: raise KeyError( - f"Could not find embedding weights in {config.name_or_path}. " + f"Could not find embedding weights in {verifier_config.name_or_path}. " "Expected a key ending with 'embed_tokens.weight'." 
) embed_tokens_weight = verifier_weights["embed_tokens.weight"] # Use embed_tokens as fallback for lm_head if not found (tied weights) lm_head_weight = verifier_weights.get("lm_head.weight", embed_tokens_weight) - self.verifier_norm = self._model_definitions.norm_class( - self.hidden_size, - eps=verifier_model_config.rms_norm_eps, - ) - # EMBEDDINGS - self.embed_tokens = torch.nn.Embedding( - verifier_model_config.vocab_size, - self.hidden_size, - padding_idx=verifier_model_config.pad_token_id, - ) - # shape: [verifier_vocab_size, hidden_size] - default_dtype = self.embed_tokens.weight.dtype - embed_tokens_sd = {"weight": embed_tokens_weight.to(default_dtype)} - self.embed_tokens.load_state_dict(embed_tokens_sd) - self.embed_tokens.weight.requires_grad = embed_requires_grad + # Check that embed_tokens hasn't been initialized yet (e.g. by from_pretrained) + if self.embed_tokens.weight.isnan().any(): + embed_tokens_sd = {"weight": embed_tokens_weight} + self.embed_tokens.load_state_dict(embed_tokens_sd) - # LM HEADS - self.lm_head = torch.nn.Linear( - self.hidden_size, self.draft_vocab_size, bias=False - ) - # shape: [hidden_size, draft_vocab_size] - self.verifier_lm_head = torch.nn.Linear( - self.hidden_size, self.draft_vocab_size, bias=False - ) + if self.use_draft_vocab: + if self.t2d is None or not torch.any(self.t2d).item(): + # (not torch.any(self.t2d).item) because t2d is initialized to zeros + raise ValueError( + "t2d tensor hasn't been set. 
Please call " + "`.load_vocab_mappings(t2d, d2t)` before `.load_verifier_weights()`" + ) - if t2d is not None: # Reduce to limited vocab - lm_head_weight = lm_head_weight.to(device=t2d.device, dtype=default_dtype)[ - t2d.to(torch.bool), : + lm_head_weight = lm_head_weight[ + self.t2d.to(device=lm_head_weight.device, dtype=torch.bool), : ] - else: - # Use full verifier vocab (no masking) - lm_head_weight = lm_head_weight.to(dtype=default_dtype) + if lm_head_weight.shape != self.lm_head.weight.shape: raise ValueError( f"Verifier lm head data shape " f"{lm_head_weight.shape} does not match draft " f"lm head shape {self.lm_head.weight.shape}" ) - self.lm_head.weight.data = lm_head_weight.detach().clone() - self.verifier_lm_head.weight.data = lm_head_weight.detach().clone() - self.verifier_lm_head.weight.requires_grad = False + # Check that lm_head hasn't been initialized yet (e.g. by from_pretrained) + if self.lm_head.weight.isnan().any(): + self.lm_head.load_state_dict( + {"weight": lm_head_weight.detach().clone()}, strict=False + ) + + # Always safe to overwrite verifier_lm_head + self.verifier_lm_head.load_state_dict( + {"weight": lm_head_weight.detach().clone()}, strict=False + ) if "model.norm.weight" not in verifier_weights: warnings.warn( - f"Could not find final norm weights in {config.name_or_path}. " + f"Could not find final norm weights in {verifier_config.name_or_path}. 
" "Using default initialization (weight=1.0).", UserWarning, stacklevel=2, ) else: - verifier_norm_weight = verifier_weights["model.norm.weight"] - verifier_norm_sd = {"weight": verifier_norm_weight.to(default_dtype)} + # Always safe to overwrite verifier_norm + verifier_norm_sd = {"weight": verifier_weights["model.norm.weight"]} self.verifier_norm.load_state_dict(verifier_norm_sd) - self.verifier_norm.weight.requires_grad = False - @conditional_torch_compile def forward( # noqa: C901 self, @@ -505,6 +524,8 @@ def forward( # noqa: C901 def from_training_args( cls, verifier_config: PretrainedConfig, + t2d: torch.Tensor | None = None, + d2t: torch.Tensor | None = None, **kwargs, ) -> "Eagle3DraftModel": """Create Eagle3 model from training arguments. @@ -540,8 +561,23 @@ def from_training_args( ), ), ) + model = cls(config=config) + model.load_vocab_mappings(t2d, d2t) + model.load_verifier_weights() + return model - return cls(config=config, t2d=kwargs.get("t2d"), d2t=kwargs.get("d2t")) + @classmethod + def from_pretrained( + cls, + *args, + t2d: torch.Tensor | None = None, + d2t: torch.Tensor | None = None, + **kwargs, + ) -> "Eagle3DraftModel": + model: Eagle3DraftModel = super().from_pretrained(*args, **kwargs) # type: ignore[assignment] + model.load_vocab_mappings(t2d, d2t) + model.load_verifier_weights() + return model @staticmethod def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]: diff --git a/src/speculators/train/trainer.py b/src/speculators/train/trainer.py index 332abb3ea..03b2f5a25 100644 --- a/src/speculators/train/trainer.py +++ b/src/speculators/train/trainer.py @@ -4,7 +4,10 @@ import torch import torch.distributed as dist -from torch.distributed.fsdp import FSDPModule +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, +) from torch.utils.data import DataLoader from tqdm import TqdmExperimentalWarning from tqdm.rich import tqdm @@ -88,29 +91,40 @@ def setup_model(self): # Verify model is compatible 
with training infrastructure SpeculatorModel.verify_training_compatible(self.model) - if self.is_distributed: - apply_fully_sharded(self.model) + load_checkpoint = ( + self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1 + ) - if self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1: - self.checkpointer.load_model_state_dict(self.model) - else: - for m in self.model.layers.children(): # type: ignore[union-attr] - if not isinstance(m, FSDPModule): - continue - acc = torch.accelerator.current_accelerator() - if acc is None: - m.to_empty(device="cuda") # type: ignore[attr-defined] - else: - acc_type = acc.type - m.to_empty(device=acc_type) # type: ignore[attr-defined] - for sub_module in m.modules(): # type: ignore[attr-defined] - if hasattr(sub_module, "reset_parameters"): - sub_module.reset_parameters() # type: ignore[operator] - # todo: Ensure lm_head and embed_tokens are loaded after reset - else: + if not self.is_distributed: + # Single device case self.model.to(self.local_rank) # type: ignore[arg-type] - if self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1: + if load_checkpoint: self.checkpointer.load_model_state_dict(self.model) + return + + # Distributed case + # Capture full state dict on rank 0 before FSDP sharding + full_state_dict = {} + if not load_checkpoint and dist.get_rank() == 0: + full_state_dict = self.model.state_dict() + + apply_fully_sharded(self.model) + + if load_checkpoint: + self.checkpointer.load_model_state_dict(self.model) + else: + # Broadcast full state dict from rank 0 to all ranks + set_model_state_dict( + self.model, + full_state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + strict=False, + ), + ) + del full_state_dict + dist.barrier() def setup_optimizer(self): # Setup optimizer diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index d1ddeb35e..40ee91f4c 100644 --- a/src/speculators/train/utils.py +++ 
b/src/speculators/train/utils.py @@ -67,8 +67,6 @@ def apply_fully_sharded(model: torch.nn.Module): ) for layer in model.layers: # type: ignore[union-attr] - # we apply fully_shard to each DecoderLayer - layer.to_empty(device="meta") fully_shard(layer, mp_policy=mp_policy) fully_shard(model) diff --git a/tests/e2e/vllm/test_finetuning_sanity.py b/tests/e2e/vllm/test_finetuning_sanity.py new file mode 100644 index 000000000..bfc377ba4 --- /dev/null +++ b/tests/e2e/vllm/test_finetuning_sanity.py @@ -0,0 +1,155 @@ +"""E2E test: verify finetuning modifies weights but keeps them close (bounded rel L1). + +Uses relative L1 distance: ||a-b||_1 / (||b||_1 + eps) per tensor. +All trainable tensors must have rel_l1 <= REL_L1_MAX; +at least a few must have rel_l1 > REL_L1_MIN to ensure they changed. + +To see all logs (per-tensor distances, frozen checks, etc.), run: + pytest tests/e2e/vllm/test_finetuning_sanity.py -s --log-cli-level=INFO +Without -s, pytest captures output and you won't see logger output until failure. +""" + +import logging +import subprocess +import sys +from pathlib import Path + +import pytest +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + +from speculators import Eagle3DraftModel + +logger = logging.getLogger(__name__) + + +@pytest.mark.e2e +@pytest.mark.slow +def test_finetuning_weight_sanity(tmp_path: Path): + """Verify finetuning changes weights but keeps rel L1 distance bounded (low LR).""" + # Ensure logs are visible when running with -s (no capture) + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s %(name)s %(message)s", + stream=sys.stderr, + force=True, + ) + + # Learning rate used for training (shared for log and CLI). 
+ FROZEN_KEY_PATTERNS = ("d2t", "embed_tokens.weight", "t2d") + # Relative distance thresholds: all trainable tensors must have rel_l1 <= REL_L1_MAX + # at least MIN_CHANGED tensors must have rel_l1 > REL_L1_MIN + # to ensure weights actually changed. + LR = "1e-5" + REL_L1_MAX = 0.05 + REL_L1_MIN = 1e-4 + MIN_CHANGED = 3 + EPS = 1e-12 + PRETRAINED = "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3" + DATASET = "nm-testing/sharegpt_llama3_8b_hidden_states" + + # Get initial state dict + model = Eagle3DraftModel.from_pretrained(PRETRAINED) + initial_sd = model.state_dict() + # Remove verifier weights which aren't saved in checkpoints + del initial_sd["verifier_norm.weight"] + del initial_sd["verifier_lm_head.weight"] + del model + + # Run short training with low LR for single epoch + logger.info("Downloading dataset %s", DATASET) + data_dir = snapshot_download(repo_id=DATASET, repo_type="dataset") + logger.info("Dataset at %s", data_dir) + logger.info( + "Running training (1 epoch, lr=%s, save_path=%s)", LR, tmp_path / "ckpt" + ) + result = subprocess.run( # noqa: S603 + [ + sys.executable, + "scripts/train.py", + "--from-pretrained", + PRETRAINED, + "--verifier-name-or-path", + "meta-llama/Llama-3.1-8B-Instruct", + "--data-path", + data_dir, + "--save-path", + str(tmp_path / "ckpt"), + "--log-dir", + str(tmp_path / "logs"), + "--epochs", + "2", + "--lr", + LR, + "--total-seq-len", + "2048", + "--num-workers", + "2", + ], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + logger.error( + "Training failed (returncode=%d). stderr:\n%s", + result.returncode, + result.stderr, + ) + if result.stdout: + logger.debug("Training stdout:\n%s", result.stdout) + assert result.returncode == 0, f"Training failed:\n{result.stderr}" + + logger.info( + "Training finished. 
Loading finetuned weights from %s", tmp_path / "ckpt" + ) + ckpt_dir = next((tmp_path / "ckpt").glob("*")) + finetuned_sd = {} + for f in ckpt_dir.glob("*.safetensors"): + finetuned_sd.update(load_file(str(f))) + logger.info("Loaded %d parameter tensors from checkpoint", len(finetuned_sd)) + + # Verify same keys + assert set(initial_sd.keys()) == set(finetuned_sd.keys()) + + # These tensors must remain identical (frozen / not trained) + num_changed = 0 + for key in sorted(initial_sd.keys()): + assert initial_sd[key].shape == finetuned_sd[key].shape, ( + f"Shape mismatch for {key}: " + f"initial {initial_sd[key].shape} vs finetuned {finetuned_sd[key].shape}" + ) + + if any(pat in key for pat in FROZEN_KEY_PATTERNS): + assert torch.equal(initial_sd[key], finetuned_sd[key]), ( + f"Tensor {key} must stay identical after finetuning (frozen); " + f"initial and finetuned differ" + ) + logger.info(" [frozen] %s: identical", key) + else: + # Trainable + diff = initial_sd[key] - finetuned_sd[key] + l1_norm_finetuned = finetuned_sd[key].abs().sum() + EPS + rel_l1 = (diff.abs().sum() / l1_norm_finetuned).item() + max_abs = diff.abs().max().item() + mean_abs = diff.abs().mean().item() + logger.info( + " %s: rel_l1=%.3e max|Δ|=%.3e mean|Δ|=%.3e", + key, + rel_l1, + max_abs, + mean_abs, + ) + assert rel_l1 <= REL_L1_MAX, ( + f"Tensor {key} has rel_l1={rel_l1:.4e} > {REL_L1_MAX} " + f"(weights changed too much)" + ) + + if rel_l1 >= REL_L1_MIN: + num_changed += 1 + + assert num_changed >= MIN_CHANGED, ( + f"Expected at least {MIN_CHANGED} tensors with rel_l1 > {REL_L1_MIN}, " + f"got {num_changed}." + ) diff --git a/tests/unit/train/test_setup_model.py b/tests/unit/train/test_setup_model.py new file mode 100644 index 000000000..5b1b20b6e --- /dev/null +++ b/tests/unit/train/test_setup_model.py @@ -0,0 +1,816 @@ +""" +Tests for model weight loading and initialization pathways. 
+ +Covers: +- Trainer.setup_model for single-GPU (fresh + resume) +- SingleGPUCheckpointer save/load round-trip +- from_pretrained save/load round-trip +- Weight precedence: checkpoint > pretrained > verifier > random init +- Distributed fresh init (FSDP + broadcast, mp.spawn) +- Distributed resume from checkpoint (mp.spawn) +- Distributed from_pretrained (mp.spawn) +""" + +import copy +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from safetensors import safe_open +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from transformers.models.llama.configuration_llama import LlamaConfig + +from speculators import SpeculatorsConfig, VerifierConfig +from speculators.models.eagle3 import Eagle3DraftModel, Eagle3SpeculatorConfig +from speculators.proposals.greedy import GreedyTokenProposalConfig +from speculators.train.checkpointer import ( + DistributedCheckpointer, + SingleGPUCheckpointer, +) +from speculators.train.trainer import Trainer, TrainerConfig + +# --------------------------------------------------------------------------- +# Skip decorators +# --------------------------------------------------------------------------- + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required" +) +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="2+ GPUs required", +) + +# --------------------------------------------------------------------------- +# Tiny model constants +# --------------------------------------------------------------------------- + +TINY_LLAMA_CONFIG = LlamaConfig( + vocab_size=64, + hidden_size=32, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + max_position_embeddings=32, + rms_norm_eps=1e-6, + 
tie_word_embeddings=False, + _attn_implementation="eager", +) + + +# --------------------------------------------------------------------------- +# Helpers (used by both fixtures and mp.spawn workers) +# --------------------------------------------------------------------------- + + +def _make_eagle3_config( + draft_vocab_size: int = 64, + verifier_name_or_path: str | None = None, +) -> Eagle3SpeculatorConfig: + return Eagle3SpeculatorConfig( + transformer_layer_config=copy.deepcopy(TINY_LLAMA_CONFIG), + draft_vocab_size=draft_vocab_size, + norm_before_residual=False, + embed_requires_grad=False, + speculators_config=SpeculatorsConfig( + algorithm="eagle3", + proposal_methods=[GreedyTokenProposalConfig(speculative_tokens=1)], + default_proposal_method="greedy", + verifier=VerifierConfig( + name_or_path=verifier_name_or_path, + architectures=["LlamaForCausalLM"], + ), + ), + ) + + +def _make_vocab_mappings( + verifier_vocab_size: int = 64, + draft_vocab_size: int = 32, +) -> tuple[torch.Tensor, torch.Tensor]: + """Create valid t2d and d2t tensors for testing. + + Selects the first `draft_vocab_size` tokens from the verifier vocab. + t2d: bool[verifier_vocab_size] — True for tokens included in draft vocab. + d2t: long[draft_vocab_size] — maps draft index to verifier index. 
+ """ + t2d = torch.zeros(verifier_vocab_size, dtype=torch.bool) + t2d[:draft_vocab_size] = True + d2t = torch.arange(draft_vocab_size, dtype=torch.long) + return t2d, d2t + + +def _make_tiny_model() -> Eagle3DraftModel: + """Create a tiny Eagle3 model with NaN weights filled.""" + model = Eagle3DraftModel(_make_eagle3_config()) + _fill_nan_weights(model) + return model + + +def _fill_nan_weights(model: Eagle3DraftModel): + """Replace NaN-initialized weights with deterministic values (simulates + what load_verifier_weights does).""" + with torch.no_grad(): + torch.nn.init.ones_(model.embed_tokens.weight) + torch.nn.init.ones_(model.lm_head.weight) + torch.nn.init.ones_(model.verifier_lm_head.weight) + torch.nn.init.ones_(model.verifier_norm.weight) + + +def _make_trainer_no_init( + model, + *, + is_distributed=False, + resume_from_checkpoint=False, + local_rank=0, + save_path="/tmp/test_ckpt", +): + """Create a Trainer instance bypassing __init__ to control setup order.""" + config = TrainerConfig( + lr=1e-4, + num_epochs=1, + save_path=save_path, + resume_from_checkpoint=resume_from_checkpoint, + is_distributed=is_distributed, + local_rank=local_rank, + ) + trainer = Trainer.__new__(Trainer) + trainer.model = model + trainer.config = config + trainer.local_rank = config.local_rank + trainer.is_distributed = config.is_distributed + trainer.resume_from_checkpoint = config.resume_from_checkpoint + trainer.train_loader = MagicMock(__len__=MagicMock(return_value=1)) + trainer.val_loader = None + return trainer + + +def _param_checksums(state_dict: dict[str, torch.Tensor]) -> dict[str, float]: + """Compute per-key checksums for cross-rank comparison.""" + return { + k: v.float().sum().item() + for k, v in state_dict.items() + if isinstance(v, torch.Tensor) + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def 
eagle3_config(): + return _make_eagle3_config() + + +@pytest.fixture +def tiny_model(): + """Tiny Eagle3 model on CPU with NaN weights filled.""" + return _make_tiny_model() + + +@pytest.fixture +def tiny_model_on_gpu(tiny_model): + """Tiny Eagle3 model moved to cuda:0.""" + return tiny_model.to("cuda:0") + + +@pytest.fixture +def checkpoint_dir(tmp_path, tiny_model_on_gpu): + """Save a checkpoint with trainable weights = 42.0, return the path.""" + with torch.no_grad(): + for p in tiny_model_on_gpu.parameters(): + if p.requires_grad: + p.fill_(42.0) + + ckpt_dir = tmp_path / "ckpt" + checkpointer = SingleGPUCheckpointer(ckpt_dir) + optimizer = torch.optim.AdamW(tiny_model_on_gpu.parameters(), lr=1e-4) + checkpointer.save_checkpoint(tiny_model_on_gpu, optimizer, epoch=0) + return ckpt_dir + + +@pytest.fixture +def pretrained_dir(tmp_path, tiny_model): + """Save a pretrained model with fc=66.0, lm_head=55.0, return the path.""" + with torch.no_grad(): + tiny_model.fc.weight.fill_(66.0) + tiny_model.lm_head.weight.fill_(55.0) + model_dir = tmp_path / "pretrained" + tiny_model.save_pretrained(str(model_dir)) + return model_dir + + +@pytest.fixture +def mock_checkpointer(): + """Mock checkpointer with no previous checkpoint.""" + ckpt = MagicMock() + ckpt.previous_epoch = -1 + return ckpt + + +# =================================================================== +# Single GPU — Fresh Init +# =================================================================== + + +@requires_cuda +def test_single_gpu_fresh_init(tiny_model, mock_checkpointer): + """Fresh single-GPU setup: model moved to device, weights unchanged, + no checkpoint loading.""" + state_before = {k: v.clone() for k, v in tiny_model.state_dict().items()} + + trainer = _make_trainer_no_init(tiny_model, is_distributed=False) + trainer.checkpointer = mock_checkpointer + + trainer.setup_model() + + # Weights should be unchanged (just moved to device) + for k, v in tiny_model.state_dict().items(): + assert 
torch.allclose(v.cpu().float(), state_before[k].float()), ( + f"Weight {k} changed during fresh init" + ) + + # No checkpoint loading + mock_checkpointer.load_model_state_dict.assert_not_called() + + +# =================================================================== +# Single GPU — Resume from Checkpoint +# =================================================================== + + +@requires_cuda +def test_single_gpu_resume(checkpoint_dir): + """Resume from checkpoint: checkpoint weights loaded, verifier weights + preserved (not overwritten by checkpoint since they're not saved).""" + model = _make_tiny_model() + with torch.no_grad(): + model.verifier_norm.weight.fill_(77.0) + model.verifier_lm_head.weight.fill_(88.0) + + trainer = _make_trainer_no_init( + model, + is_distributed=False, + resume_from_checkpoint=True, + save_path=str(checkpoint_dir), + ) + trainer.checkpointer = SingleGPUCheckpointer(checkpoint_dir) + trainer.setup_model() + + # Trainable weights should match checkpoint (42.0, modulo bf16 round-trip) + for name, param in model.named_parameters(): + if param.requires_grad: + assert torch.allclose(param.cpu().float(), torch.tensor(42.0), atol=0.5), ( + f"Trainable weight {name} not loaded from checkpoint" + ) + + # Verifier weights should be preserved (not in checkpoint) + assert torch.allclose( + model.verifier_norm.weight.cpu().float(), torch.tensor(77.0) + ), "verifier_norm overwritten by checkpoint" + assert torch.allclose( + model.verifier_lm_head.weight.cpu().float(), torch.tensor(88.0) + ), "verifier_lm_head overwritten by checkpoint" + + +# =================================================================== +# Checkpoint Save/Load Round-Trip +# =================================================================== + + +@requires_cuda +def test_checkpoint_save_load_round_trip(checkpoint_dir): + """SingleGPUCheckpointer round-trip: trainable weights preserved, verifier + keys not saved, expected files created.""" + # Verify files + assert 
(checkpoint_dir / "0" / "model.safetensors").exists() + assert (checkpoint_dir / "0" / "config.json").exists() + assert (checkpoint_dir / "0" / "optimizer_state_dict.pt").exists() + + # Verify verifier-only keys not in saved safetensors + with safe_open( + str(checkpoint_dir / "0" / "model.safetensors"), framework="pt" + ) as f: + saved_keys = set(f.keys()) + for key in Eagle3DraftModel._keys_to_ignore_on_save: + assert key not in saved_keys, f"{key} should not be saved" + + # Load into fresh model and verify trainable weights match + model = _make_tiny_model() + model.to("cuda:0") # type: ignore[arg-type] + checkpointer = SingleGPUCheckpointer(checkpoint_dir) + checkpointer.load_model_state_dict(model) + + for name, param in model.named_parameters(): + if param.requires_grad: + assert torch.allclose(param.cpu().float(), torch.tensor(42.0), atol=0.5), ( + f"Trainable weight {name} not preserved in round-trip" + ) + + +# =================================================================== +# from_pretrained Round-Trip +# =================================================================== + + +def test_from_pretrained_round_trip(tiny_model): + """from_pretrained round-trip: trainable weights preserved, ignored keys + not in saved files, pretrained weights take precedence over verifier.""" + # Set trainable weights to known value + with torch.no_grad(): + for p in tiny_model.parameters(): + if p.requires_grad: + p.fill_(42.0) + + trainable_names = {n for n, p in tiny_model.named_parameters() if p.requires_grad} + original_trainable = { + k: v.clone() for k, v in tiny_model.state_dict().items() if k in trainable_names + } + + with tempfile.TemporaryDirectory() as tmpdir: + tiny_model.save_pretrained(tmpdir) + + # Verify _keys_to_ignore_on_save not in saved files + with safe_open(f"{tmpdir}/model.safetensors", framework="pt") as f: + saved_keys = set(f.keys()) + for key in Eagle3DraftModel._keys_to_ignore_on_save: + assert key not in saved_keys, f"{key} should not be 
saved" + + # Load (mock load_verifier_weights to avoid HF downloads) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(tmpdir) + + # Trainable weights should match original + for k, original_v in original_trainable.items(): + loaded_v = loaded.state_dict()[k] + assert torch.allclose(loaded_v.float(), original_v.float(), atol=0.5), ( + f"Weight {k} not preserved in from_pretrained round-trip" + ) + + # lm_head was saved (it's trainable), so from_pretrained loads it. + # Even if load_verifier_weights ran, the NaN guard would keep the + # pretrained value since it's no longer NaN. + assert not loaded.lm_head.weight.isnan().any(), ( + "lm_head should have pretrained value, not NaN" + ) + + +# =================================================================== +# Weight Precedence +# =================================================================== + + +@requires_cuda +def test_weight_precedence(eagle3_config, pretrained_dir, tmp_path): + """Verify weight precedence: checkpoint > pretrained > verifier > random. 
+ + Walks through the full chain in a single test.""" + + # --- Level 5: Random init produces NaN for verifier-loaded weights --- + model = Eagle3DraftModel(eagle3_config) + assert model.embed_tokens.weight.isnan().all(), ( + "embed_tokens should be NaN after random init" + ) + assert model.lm_head.weight.isnan().all(), "lm_head should be NaN after random init" + # fc (trainable) should NOT be NaN — it's randomly initialized + assert not model.fc.weight.isnan().any(), "fc should have random init, not NaN" + + # --- Level 4: Verifier fills NaN weights --- + _fill_nan_weights(model) # simulates load_verifier_weights + assert not model.embed_tokens.weight.isnan().any(), ( + "embed_tokens should be filled by verifier" + ) + assert not model.lm_head.weight.isnan().any(), ( + "lm_head should be filled by verifier" + ) + + # --- Level 2: Pretrained weights take precedence over verifier --- + # pretrained_dir fixture saved lm_head=55.0, fc=66.0 + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(pretrained_dir)) + + assert torch.allclose(loaded.lm_head.weight.float(), torch.tensor(55.0)), ( + "pretrained lm_head should not be overwritten by verifier" + ) + assert torch.allclose(loaded.fc.weight.float(), torch.tensor(66.0)), ( + "pretrained fc should be preserved" + ) + + # --- Level 1: Checkpoint overrides everything --- + loaded.to("cuda:0") # type: ignore[arg-type] + with torch.no_grad(): + loaded.fc.weight.fill_(99.0) # checkpoint value + ckpt_dir = str(tmp_path / "ckpt") + checkpointer = SingleGPUCheckpointer(ckpt_dir) + optimizer = torch.optim.AdamW(loaded.parameters(), lr=1e-4) + checkpointer.save_checkpoint(loaded, optimizer, epoch=0) + + # Load checkpoint into a model that had pretrained value (66.0) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model3 = Eagle3DraftModel.from_pretrained(str(pretrained_dir)) + model3.to("cuda:0") # type: ignore[arg-type] + checkpointer2 = 
SingleGPUCheckpointer(ckpt_dir) + checkpointer2.load_model_state_dict(model3) + + assert torch.allclose( + model3.fc.weight.cpu().float(), torch.tensor(99.0), atol=0.5 + ), "checkpoint fc should override pretrained" + + +# =================================================================== +# Distributed helpers +# =================================================================== + + +def _dist_setup(rank, world_size): + """Initialize distributed process group for testing.""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def _dist_teardown(): + """Clean up distributed process group.""" + dist.destroy_process_group() + + +def _get_full_state_dict_rank0(model): + """Get unsharded full state dict from FSDP model (only populated on rank 0). + + All ranks must call this (it's a collective op), but only rank 0 + gets the actual tensors.""" + return get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + + +# =================================================================== +# Distributed — Fresh Init +# =================================================================== + + +def _worker_distributed_fresh_init(rank, world_size, results_dir): + """Worker for test_distributed_fresh_init.""" + _dist_setup(rank, world_size) + try: + model = _make_tiny_model() + + # Capture rank 0's pre-FSDP state dict for comparison + pre_fsdp_checksums = _param_checksums(model.state_dict()) if rank == 0 else {} + + trainer = _make_trainer_no_init(model, is_distributed=True, local_rank=rank) + trainer.checkpointer = MagicMock() + trainer.checkpointer.previous_epoch = -1 + + trainer.setup_model() + + # All ranks must call get_model_state_dict (collective op), + # but only rank 0 gets the actual tensors + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + 
has_nan = { + k: v.isnan().any().item() + for k, v in full_sd.items() + if isinstance(v, torch.Tensor) and v.is_floating_point() + } + + torch.save( + { + "pre_fsdp_checksums": pre_fsdp_checksums, + "post_fsdp_checksums": checksums, + "has_nan": has_nan, + }, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_fresh_init(tmp_path): + """Distributed fresh init: after setup_model, the gathered full state dict + matches rank 0's original pre-FSDP weights and contains no NaN values. + + This verifies that set_model_state_dict(broadcast_from_rank0=True) + correctly distributes rank 0's weights to all ranks, because + get_model_state_dict gathers shards from ALL ranks to reconstruct + the full dict on rank 0.""" + world_size = 2 + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_fresh_init, + args=(world_size, results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Post-FSDP gathered state dict should match pre-FSDP state dict from rank 0 + pre = results["pre_fsdp_checksums"] + post = results["post_fsdp_checksums"] + for key in pre: + assert key in post, f"Key {key} missing after FSDP round-trip" + assert pre[key] == pytest.approx(post[key], abs=1e-2), ( + f"Weight {key} changed during FSDP broadcast: " + f"pre={pre[key]}, post={post[key]}" + ) + + # No NaN values in any float parameter + for key, is_nan in results["has_nan"].items(): + assert not is_nan, f"Weight {key} is NaN after distributed setup" + + +# =================================================================== +# Distributed — Resume from Checkpoint +# =================================================================== + + +def _worker_distributed_resume(rank, world_size, ckpt_dir, results_dir): + """Worker for test_distributed_resume.""" + _dist_setup(rank, world_size) + try: + model = _make_tiny_model() + # Set verifier weights 
to known value before FSDP + with torch.no_grad(): + model.verifier_norm.weight.fill_(77.0) + model.verifier_lm_head.weight.fill_(88.0) + + trainer = _make_trainer_no_init( + model, + is_distributed=True, + resume_from_checkpoint=True, + local_rank=rank, + save_path=ckpt_dir, + ) + trainer.checkpointer = DistributedCheckpointer(ckpt_dir) + trainer.setup_model() + + # All ranks must call (collective op), only rank 0 gets data + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + verifier_norm_val = full_sd["verifier_norm.weight"].float().mean().item() + verifier_lm_head_val = ( + full_sd["verifier_lm_head.weight"].float().mean().item() + ) + + torch.save( + { + "checksums": checksums, + "verifier_norm_val": verifier_norm_val, + "verifier_lm_head_val": verifier_lm_head_val, + }, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_resume(checkpoint_dir, tmp_path): + """Distributed resume: checkpoint weights loaded correctly, verifier + weights preserved (not overwritten by checkpoint).""" + world_size = min(torch.cuda.device_count(), 2) + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_resume, + args=(world_size, str(checkpoint_dir), results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Verifier weights should be preserved (not in checkpoint) + assert results["verifier_norm_val"] == pytest.approx(77.0, abs=0.1), ( + "verifier_norm overwritten by checkpoint" + ) + assert results["verifier_lm_head_val"] == pytest.approx(88.0, abs=0.1), ( + "verifier_lm_head overwritten by checkpoint" + ) + + +# =================================================================== +# Distributed — from_pretrained +# =================================================================== + + +def _worker_distributed_from_pretrained(rank, world_size, model_dir, 
results_dir): + """Worker for test_distributed_from_pretrained.""" + _dist_setup(rank, world_size) + try: + # Load model from pretrained (mock verifier loading) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model = Eagle3DraftModel.from_pretrained(model_dir) + _fill_nan_weights(model) # fill verifier weights post-load + + trainer = _make_trainer_no_init(model, is_distributed=True, local_rank=rank) + trainer.checkpointer = MagicMock() + trainer.checkpointer.previous_epoch = -1 + + trainer.setup_model() + + # All ranks must call (collective op), only rank 0 gets data + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + fc_val = full_sd["fc.weight"].float().mean().item() + + torch.save( + {"checksums": checksums, "fc_val": fc_val}, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_from_pretrained(pretrained_dir, tmp_path): + """Model loaded via from_pretrained should have correct weights after FSDP + setup, with pretrained weight values preserved through the broadcast.""" + world_size = min(torch.cuda.device_count(), 2) + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_from_pretrained, + args=(world_size, str(pretrained_dir), results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Pretrained fc weight should be preserved through FSDP setup + assert results["fc_val"] == pytest.approx(66.0, abs=0.5), ( + "Pretrained fc weight not preserved through FSDP broadcast" + ) + + +# =================================================================== +# Vocab Mapping Loading (t2d / d2t) +# =================================================================== + +DRAFT_VOCAB_SIZE = 32 # < TINY_LLAMA_CONFIG.vocab_size (64) + + +@pytest.fixture +def draft_vocab_config(): + """Eagle3 config with draft_vocab_size < verifier_vocab_size.""" 
+ return _make_eagle3_config(draft_vocab_size=DRAFT_VOCAB_SIZE) + + +@pytest.fixture +def vocab_mappings(): + """Valid (t2d, d2t) pair for verifier_vocab=64, draft_vocab=32.""" + return _make_vocab_mappings( + verifier_vocab_size=TINY_LLAMA_CONFIG.vocab_size, + draft_vocab_size=DRAFT_VOCAB_SIZE, + ) + + +def test_load_vocab_mappings(draft_vocab_config, vocab_mappings): + """load_vocab_mappings stores t2d/d2t buffers correctly.""" + t2d, d2t = vocab_mappings + model = Eagle3DraftModel(draft_vocab_config) + + # Before loading: buffers exist but are zeros + assert model.t2d is not None + assert not model.t2d.any(), "t2d should be all zeros before loading" + assert model.d2t is not None + assert (model.d2t == 0).all(), "d2t should be all zeros before loading" + + model.load_vocab_mappings(t2d, d2t) + + # After loading: buffers match inputs + assert torch.equal(model.t2d, t2d), "t2d not loaded correctly" + assert torch.equal(model.d2t, d2t), "d2t not loaded correctly" + + +def test_load_vocab_mappings_validation(draft_vocab_config, vocab_mappings): + """load_vocab_mappings raises on invalid inputs.""" + t2d, d2t = vocab_mappings + model = Eagle3DraftModel(draft_vocab_config) + + # Only one of t2d/d2t provided + with pytest.raises(ValueError, match="Both t2d and d2t must be provided"): + model.load_vocab_mappings(t2d, None) + with pytest.raises(ValueError, match="Both t2d and d2t must be provided"): + model.load_vocab_mappings(None, d2t) + + # Wrong t2d shape + with pytest.raises(ValueError, match="t2d.shape"): + model.load_vocab_mappings(torch.ones(10, dtype=torch.bool), d2t) + + # Wrong d2t shape + with pytest.raises(ValueError, match="d2t.shape"): + model.load_vocab_mappings(t2d, torch.zeros(10, dtype=torch.long)) + + # Wrong number of True values in t2d + bad_t2d = torch.ones(TINY_LLAMA_CONFIG.vocab_size, dtype=torch.bool) + with pytest.raises(ValueError, match="non-zero values"): + model.load_vocab_mappings(bad_t2d, d2t) + + +def 
test_load_vocab_mappings_not_needed(): + """load_vocab_mappings raises when vocab sizes match (no mapping needed).""" + config = _make_eagle3_config(draft_vocab_size=64) # same as verifier + model = Eagle3DraftModel(config) + t2d, d2t = _make_vocab_mappings(verifier_vocab_size=64, draft_vocab_size=64) + + with pytest.raises(RuntimeError, match="not needed"): + model.load_vocab_mappings(t2d, d2t) + + +def test_from_training_args_loads_vocab_mappings(vocab_mappings): + """from_training_args passes t2d/d2t through to load_vocab_mappings.""" + t2d, d2t = vocab_mappings + + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model = Eagle3DraftModel.from_training_args( + verifier_config=copy.deepcopy(TINY_LLAMA_CONFIG), + t2d=t2d, + d2t=d2t, + draft_vocab_size=DRAFT_VOCAB_SIZE, + norm_before_residual=False, + ttt_steps=1, + verifier_name_or_path="dummy", + ) + + assert model.t2d is not None, "t2d is None after from_training_args" + assert model.d2t is not None, "d2t is None after from_training_args" + assert torch.equal(model.t2d, t2d), "t2d not loaded via from_training_args" + assert torch.equal(model.d2t, d2t), "d2t not loaded via from_training_args" + + +def test_from_pretrained_loads_vocab_mappings_from_kwargs( + tmp_path, draft_vocab_config, vocab_mappings +): + """from_pretrained loads t2d/d2t passed as kwargs.""" + t2d, d2t = vocab_mappings + + # Save a model without vocab mappings in the safetensors + model = Eagle3DraftModel(draft_vocab_config) + _fill_nan_weights(model) + model_dir = tmp_path / "pretrained_no_vocab" + model.save_pretrained(str(model_dir)) + + # Load with t2d/d2t passed as kwargs + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(model_dir), t2d=t2d, d2t=d2t) + + assert loaded.t2d is not None, "t2d is None after from_pretrained" + assert loaded.d2t is not None, "d2t is None after from_pretrained" + assert torch.equal(loaded.t2d, t2d), "t2d not loaded from kwargs in 
from_pretrained" + assert torch.equal(loaded.d2t, d2t), "d2t not loaded from kwargs in from_pretrained" + + +def test_from_pretrained_loads_vocab_mappings_from_saved( + tmp_path, draft_vocab_config, vocab_mappings +): + """from_pretrained loads t2d/d2t from saved safetensors when not passed + as kwargs.""" + t2d, d2t = vocab_mappings + + # Save model WITH vocab mappings loaded + model = Eagle3DraftModel(draft_vocab_config) + _fill_nan_weights(model) + model.load_vocab_mappings(t2d, d2t) + model_dir = tmp_path / "pretrained_with_vocab" + model.save_pretrained(str(model_dir)) + + # Verify t2d/d2t are in the saved safetensors + with safe_open(str(model_dir / "model.safetensors"), framework="pt") as f: + saved_keys = set(f.keys()) + assert "t2d" in saved_keys, "t2d should be saved in safetensors" + assert "d2t" in saved_keys, "d2t should be saved in safetensors" + + # Load WITHOUT passing t2d/d2t — should come from safetensors + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(model_dir)) + + assert loaded.t2d is not None, "t2d is None after from_pretrained" + assert loaded.d2t is not None, "d2t is None after from_pretrained" + assert torch.equal(loaded.t2d, t2d), "t2d not loaded from saved safetensors" + assert torch.equal(loaded.d2t, d2t), "d2t not loaded from saved safetensors"