From 5164d836b48151830c55a21253556bdf8a6de37c Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Mon, 2 Mar 2026 16:21:12 +0000 Subject: [PATCH 1/7] Separate model initialize and verifier weight loading and fix from_pretrained Signed-off-by: Fynn Schmitt-Ulms --- scripts/train.py | 36 ++-- src/speculators/model.py | 3 + src/speculators/models/eagle3/core.py | 285 +++++++++++++++----------- 3 files changed, 186 insertions(+), 138 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index a9b99fae6..c752a458d 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -132,6 +132,7 @@ def create_transformer_layer_config( initializer_range=verifier_config.initializer_range, rms_norm_eps=verifier_config.rms_norm_eps, head_dim=getattr(verifier_config, "head_dim", None), + tie_word_embeddings=False, ) transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 return transformer_layer_config @@ -176,8 +177,6 @@ def main(args: argparse.Namespace): args.verifier_name_or_path, args.num_layers, draft_arch=args.draft_arch ) - # Get model class from registry and create model using its factory method - if args.speculator_type not in SpeculatorModel.registry: raise ValueError( f"Unknown speculator type: {args.speculator_type}. " @@ -185,13 +184,18 @@ def main(args: argparse.Namespace): ) model_class = SpeculatorModel.registry[args.speculator_type] - draft_model = model_class.from_training_args( - verifier_config=transformer_layer_config, - t2d=t2d, - d2t=d2t, - draft_vocab_size=draft_vocab_size, - **vars(args), - ) + if args.from_pretrained: + draft_model = model_class.from_pretrained( + args.from_pretrained, t2d=t2d, d2t=d2t + ) + else: + draft_model = model_class.from_training_args( + verifier_config=transformer_layer_config, + t2d=t2d, + d2t=d2t, + draft_vocab_size=draft_vocab_size, + **vars(args), + ) # Setup dataloaders train_files, val_files = split_files(args.data_path, ratio=0.9) @@ -249,6 +253,12 @@ def parse_args(): default="eagle3", help="Type of speculator model to train (e.g., eagle3)", ) + parser.add_argument( + "--from-pretrained", + type=str, + default="", + help="The pretrained draft model to finetune", + ) parser.add_argument("--data-path", type=str, default="./data") parser.add_argument("--save-path", type=str, default="./checkpoints") parser.add_argument("--epochs", type=int, default=20) @@ -298,12 +308,6 @@ def parse_args(): default=True, help="Toggle normalization before residual connections (default: True)", ) - parser.add_argument( - "--embed-requires-grad", - action=argparse.BooleanOptionalAction, - default=False, - help="Whether to train embedding layer weights (default: False)", - ) # Dataloader parameters parser.add_argument( "--num-workers", type=int, default=12, help="Number of dataloader workers" @@ -331,7 +335,7 @@ def parse_args(): # RUN WITH: -# torchrun --nnodes=1 --nproc_per_node= scripts/train.py +# torchrun --standalone --nproc_per_node= scripts/train.py # for FSDP training # OR # python scripts/train.py diff --git a/src/speculators/model.py b/src/speculators/model.py index 8fd80fb76..50646f78f 100644 --- a/src/speculators/model.py +++ b/src/speculators/model.py @@ -141,6 +141,9 @@ def from_pretrained( "provided to load a SpeculatorModel." ) + config.tie_word_embeddings = False + config.transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 + if cls is SpeculatorModel: # generic call to from_pretrained on this class, need to resolve the # specific model class to use for loading based on the config and registry diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index 79bd8544e..ec2051d62 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -168,49 +168,60 @@ class Eagle3DraftModel(SpeculatorModel): "verifier_norm.weight", ] - def __init__( - self, - config: Eagle3SpeculatorConfig, - t2d: torch.Tensor | None, - d2t: torch.Tensor | None, - ): + def __init__(self, config: Eagle3SpeculatorConfig): super().__init__(config=config) self.hidden_size = config.transformer_layer_config.hidden_size self.draft_vocab_size = config.draft_vocab_size + self.verifier_vocab_size = config.transformer_layer_config.vocab_size + + tl_config = self.config.transformer_layer_config + self._model_definitions = model_classes[tl_config.model_type] + + # VOCAB MAPPINGS + self.use_draft_vocab = self.draft_vocab_size != self.verifier_vocab_size + t2d = None + d2t = None + if self.use_draft_vocab: + # Use NaNs as placeholder so that it's clear if these aren't updated + # todo(fynn): NaNs might not work with the dtypes here + t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) + d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) + + # FC LAYER + self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) - # Verify that if one mapping tensor is provided, the other is as well - if (t2d is None) != (d2t is None): - raise ValueError( - "Both t2d and d2t must be provided together, or both must be None. " - f"Got t2d={'provided' if t2d is not None else 'None'}, " - f"d2t={'provided' if d2t is not None else 'None'}" + # DECODER LAYERS + num_layers = tl_config.num_hidden_layers + fl_class = self._model_definitions.first_layer_class + dl_class = self._model_definitions.decoder_layer_class + layers = [ + fl_class( # first layer + tl_config, + layer_idx=0, + norm_before_residual=self.config.norm_before_residual, ) + ] + layers.extend( # remaining layers + [dl_class(tl_config, layer_idx) for layer_idx in range(1, num_layers)] + ) + self.layers = torch.nn.ModuleList(layers) - # Register buffers - they can be None - if t2d is not None: - self.register_buffer("t2d", t2d) # shape: [verifier_vocab_size], bool - if int(t2d.sum(dtype=torch.long).item()) != self.draft_vocab_size: - raise ValueError( - f"t2d has {int(t2d.sum(dtype=torch.long).item())} non-zero values, " - f"expected {self.draft_vocab_size}." - ) - else: - self.register_buffer("t2d", None) + # ROTARY EMBEDDINGS + # Create a modified config for the rotary embedding to use 2x the hidden size + modified_tl_config = copy.copy(config.transformer_layer_config) + modified_tl_config.hidden_size *= 2 + self.rotary_emb = self._model_definitions.rotary_emb_class(modified_tl_config) - if d2t is not None: - self.register_buffer("d2t", d2t) # shape: [draft_vocab_size], int offsets - if d2t.shape[0] != self.draft_vocab_size: - raise ValueError( - f"d2t.shape[0] ({d2t.shape[0]}) must match" - f" draft_vocab_size ({self.draft_vocab_size})." - ) - else: - self.register_buffer("d2t", None) + # LAYER NORMS + norm_class = self._model_definitions.norm_class + self.norm = norm_class( + self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps + ) + self.verifier_norm = norm_class(self.hidden_size, eps=tl_config.rms_norm_eps) + self.verifier_norm.weight.requires_grad = False - self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) - self._model_definitions = model_classes[ - config.transformer_layer_config.model_type - ] # Normalize draft path input (gpt-oss only) if config.norm_before_fc: self.input_norm = self._model_definitions.norm_class( @@ -219,55 +230,63 @@ def __init__( ) else: self.input_norm = None - self._setup_decoder_layers( - config.transformer_layer_config, config.norm_before_residual + + # TOKEN EMBEDDINGS + self.embed_tokens = torch.nn.Embedding( + self.verifier_vocab_size, + self.hidden_size, + padding_idx=tl_config.pad_token_id, ) - self.norm = self._model_definitions.norm_class( - self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps + self.embed_tokens.weight.requires_grad = self.config.embed_requires_grad + + # LM HEADS + self.lm_head = torch.nn.Linear( + self.hidden_size, self.draft_vocab_size, bias=False ) - self._setup_rotary_embedding(config.transformer_layer_config) - self._setup_embeddings_and_lm_heads( - config.speculators_config.verifier, t2d, config.embed_requires_grad + self.verifier_lm_head = torch.nn.Linear( + self.hidden_size, self.draft_vocab_size, bias=False ) + self.verifier_lm_head.weight.requires_grad = False - def _setup_decoder_layers( - self, transformer_layer_config: PretrainedConfig, norm_before_residual: bool - ): - num_hidden_layers = transformer_layer_config.num_hidden_layers - # Add first layer - layers = [ - self._model_definitions.first_layer_class( - transformer_layer_config, - layer_idx=0, - norm_before_residual=norm_before_residual, + # Initialize weights to nan + # This ensures it will be easy to detect if the weights are never + # loaded from the verifier model + torch.nn.init.constant_(self.lm_head.weight, torch.nan) + torch.nn.init.constant_(self.embed_tokens.weight, torch.nan) + torch.nn.init.constant_(self.verifier_lm_head.weight, torch.nan) + + def load_vocab_mappings(self, t2d: torch.Tensor, d2t: torch.Tensor): + if not self.use_draft_vocab: + raise RuntimeError( + "Vocab mappings were provided but are not needed because verifier " + "vocab size equals draft vocab size. Vocab mappings are only required " + "when using a reduced vocab." ) - ] - # Add additional regular decoder layers - layers.extend( - [ - self._model_definitions.decoder_layer_class( - transformer_layer_config, layer_idx - ) - for layer_idx in range(1, num_hidden_layers) - ] - ) - self.layers = torch.nn.ModuleList(layers) - def _setup_rotary_embedding(self, transformer_layer_config: PretrainedConfig): - # Create a modified config for the rotary embedding to use 2x the hidden size - modified_config = copy.copy(transformer_layer_config) - modified_config.hidden_size = modified_config.hidden_size * 2 - self.rotary_emb = self._model_definitions.rotary_emb_class(modified_config) + if t2d.shape[0] != self.verifier_vocab_size: + raise ValueError( + f"t2d.shape[0] ({t2d.shape[0]}) must match" + f" verifier_vocab_size ({self.verifier_vocab_size})." + ) + if int(t2d.sum(dtype=torch.long).item()) != self.draft_vocab_size: + raise ValueError( + f"t2d has {int(t2d.sum(dtype=torch.long).item())} non-zero values, " + f"expected {self.draft_vocab_size}." + ) - def _setup_embeddings_and_lm_heads( - self, - config: VerifierConfig, - t2d: torch.Tensor | None, - embed_requires_grad: bool, - ): - if config.name_or_path is None: + if d2t.shape[0] != self.draft_vocab_size: + raise ValueError( + f"d2t.shape[0] ({d2t.shape[0]}) must match" + f" draft_vocab_size ({self.draft_vocab_size})." + ) + + self.load_state_dict({"t2d": t2d, "d2t": d2t}, strict=False) + + def load_verifier_weights(self): + verifier_config = self.config.speculators_config.verifier + if verifier_config.name_or_path is None: raise ValueError("VerifierConfig `name_or_path` value is required.") - verifier_model_config = AutoConfig.from_pretrained(config.name_or_path) + verifier_model_config = AutoConfig.from_pretrained(verifier_config.name_or_path) # For multimodal models (Qwen3VL, etc.), extract text_config if hasattr(verifier_model_config, "text_config"): @@ -278,86 +297,71 @@ def _setup_embeddings_and_lm_heads( f"Verifier hidden size {verifier_model_config.hidden_size} does not" f" match draft hidden size {self.hidden_size}." ) - if t2d is not None and t2d.shape[0] != verifier_model_config.vocab_size: - raise ValueError( - f"t2d.shape[0] ({t2d.shape[0]}) must match" - f" verifier_vocab_size ({verifier_model_config.vocab_size})." - ) - # Load embedding and lm_head weights using suffix patterns (model-agnostic) + # Load embedding, lm_head, and norm weights using suffix names (model-agnostic) verifier_weights = load_model_layers( ["embed_tokens.weight", "lm_head.weight", "model.norm.weight"], - config.name_or_path, + verifier_config.name_or_path, ) if "embed_tokens.weight" not in verifier_weights: raise KeyError( - f"Could not find embedding weights in {config.name_or_path}. " + f"Could not find embedding weights in {verifier_config.name_or_path}. " "Expected a key ending with 'embed_tokens.weight'." ) embed_tokens_weight = verifier_weights["embed_tokens.weight"] # Use embed_tokens as fallback for lm_head if not found (tied weights) lm_head_weight = verifier_weights.get("lm_head.weight", embed_tokens_weight) - self.verifier_norm = self._model_definitions.norm_class( - self.hidden_size, - eps=verifier_model_config.rms_norm_eps, - ) - # EMBEDDINGS - self.embed_tokens = torch.nn.Embedding( - verifier_model_config.vocab_size, - self.hidden_size, - padding_idx=verifier_model_config.pad_token_id, - ) - # shape: [verifier_vocab_size, hidden_size] - default_dtype = self.embed_tokens.weight.dtype - embed_tokens_sd = {"weight": embed_tokens_weight.to(default_dtype)} - self.embed_tokens.load_state_dict(embed_tokens_sd) - self.embed_tokens.weight.requires_grad = embed_requires_grad + # Check that embed_tokens hasn't been initialized yet (e.g. by from_pretrained) + if self.embed_tokens.weight.isnan().any(): + embed_tokens_sd = {"weight": embed_tokens_weight} + self.embed_tokens.load_state_dict(embed_tokens_sd) - # LM HEADS - self.lm_head = torch.nn.Linear( - self.hidden_size, self.draft_vocab_size, bias=False - ) - # shape: [hidden_size, draft_vocab_size] - self.verifier_lm_head = torch.nn.Linear( - self.hidden_size, self.draft_vocab_size, bias=False - ) + if self.use_draft_vocab: + if self.t2d is None or not torch.any(self.t2d).item: + # (not torch.any(self.t2d).item) because t2d is initialized to zeros + raise ValueError( + "t2d tensor hasn't been set. Please call " + "`.load_vocab_mappings(t2d, d2t)` before `.load_verifier_weights()`" + ) - if t2d is not None: # Reduce to limited vocab - lm_head_weight = lm_head_weight.to(device=t2d.device, dtype=default_dtype)[ - t2d.to(torch.bool), : + lm_head_weight = lm_head_weight[ + self.t2d.to(device=lm_head_weight.device, dtype=torch.bool), : ] - else: - # Use full verifier vocab (no masking) - lm_head_weight = lm_head_weight.to(dtype=default_dtype) + if lm_head_weight.shape != self.lm_head.weight.shape: raise ValueError( f"Verifier lm head data shape " f"{lm_head_weight.shape} does not match draft " f"lm head shape {self.lm_head.weight.shape}" ) - self.lm_head.weight.data = lm_head_weight.detach().clone() - self.verifier_lm_head.weight.data = lm_head_weight.detach().clone() - self.verifier_lm_head.weight.requires_grad = False + # Check that lm_head hasn't been initialized yet (e.g. by from_pretrained) + if self.lm_head.weight.isnan().any(): + self.lm_head.load_state_dict( + {"weight": lm_head_weight.detach().clone()}, strict=False + ) + + # Always safe to overwrite verifier_lm_head + self.verifier_lm_head.load_state_dict( + {"weight": lm_head_weight.detach().clone()}, strict=False + ) if "model.norm.weight" not in verifier_weights: warnings.warn( - f"Could not find final norm weights in {config.name_or_path}. " + f"Could not find final norm weights in {verifier_config.name_or_path}. " "Using default initialization (weight=1.0).", UserWarning, stacklevel=2, ) else: - verifier_norm_weight = verifier_weights["model.norm.weight"] - verifier_norm_sd = {"weight": verifier_norm_weight.to(default_dtype)} + # Always safe to overwrite verifier_norm + verifier_norm_sd = {"weight": verifier_weights["model.norm.weight"]} self.verifier_norm.load_state_dict(verifier_norm_sd) - self.verifier_norm.weight.requires_grad = False - @conditional_torch_compile def forward( # noqa: C901 self, @@ -505,6 +509,8 @@ def forward( # noqa: C901 def from_training_args( cls, verifier_config: PretrainedConfig, + t2d: torch.Tensor | None = None, + d2t: torch.Tensor | None = None, **kwargs, ) -> "Eagle3DraftModel": """Create Eagle3 model from training arguments. @@ -541,7 +547,42 @@ def from_training_args( ), ) - return cls(config=config, t2d=kwargs.get("t2d"), d2t=kwargs.get("d2t")) + model = cls(config=config) + # Verify that if one mapping tensor is provided, the other is as well + if (t2d is None) != (d2t is None): + raise ValueError( + "Both t2d and d2t must be provided together, or both must be None. " + f"Got t2d={'provided' if t2d is not None else 'None'}, " + f"d2t={'provided' if d2t is not None else 'None'}" + ) + elif t2d is not None: + model.load_vocab_mappings(t2d, d2t) + + model.load_verifier_weights() + return model + + @classmethod + def from_pretrained( + cls, + *args, + t2d: torch.Tensor | None = None, + d2t: torch.Tensor | None = None, + **kwargs, + ) -> "Eagle3DraftModel": + model: Eagle3DraftModel = super().from_pretrained(*args, **kwargs) + + # Verify that if one mapping tensor is provided, the other is as well + if (t2d is None) != (d2t is None): + raise ValueError( + "Both t2d and d2t must be provided together, or both must be None. " + f"Got t2d={'provided' if t2d is not None else 'None'}, " + f"d2t={'provided' if d2t is not None else 'None'}" + ) + elif t2d is not None: + model.load_vocab_mappings(t2d, d2t) + + model.load_verifier_weights() + return model @staticmethod def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]: From 8ad4d1b349dba297a57e3ecf1fcdc04a8d0ee4c5 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 4 Mar 2026 23:41:27 +0000 Subject: [PATCH 2/7] Improve handling of fully_shard There were some potentially robustness issues with the previous implementation. In particular, we were resetting parameters on the model after loading the verifier weights. Although we were mostly avoiding the verifier weights there were some risks with this approach. This commit simplifies the logic and adds an additional broadcast step from rank0 to ensure all ranks have the same weight copy before training begins. Signed-off-by: Fynn Schmitt-Ulms --- src/speculators/models/eagle3/core.py | 56 +++++++++++---------------- src/speculators/train/trainer.py | 56 +++++++++++++++++---------- src/speculators/train/utils.py | 2 - 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index ec2051d62..e5fd514da 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -155,7 +155,7 @@ def conditional_torch_compile(func): @SpeculatorModel.register("eagle3") class Eagle3DraftModel(SpeculatorModel): config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc] - _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc] + _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc,assignment] "embed_tokens.weight", "verifier_norm.weight", "verifier_lm_head.weight", @@ -179,15 +179,15 @@ def __init__(self, config: Eagle3SpeculatorConfig): # VOCAB MAPPINGS self.use_draft_vocab = self.draft_vocab_size != self.verifier_vocab_size - t2d = None - d2t = None + self.t2d: torch.Tensor | None = None + self.d2t: torch.Tensor | None = None if self.use_draft_vocab: # Use NaNs as placeholder so that it's clear if these aren't updated # todo(fynn): NaNs might not work with the dtypes here - t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) - d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) - self.register_buffer("t2d", t2d) - self.register_buffer("d2t", d2t) + self.t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) + self.d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) + self.register_buffer("t2d", self.t2d) + self.register_buffer("d2t", self.d2t) # FC LAYER self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) @@ -255,7 +255,17 @@ def __init__(self, config: Eagle3SpeculatorConfig): torch.nn.init.constant_(self.embed_tokens.weight, torch.nan) torch.nn.init.constant_(self.verifier_lm_head.weight, torch.nan) - def load_vocab_mappings(self, t2d: torch.Tensor, d2t: torch.Tensor): + def load_vocab_mappings(self, t2d: torch.Tensor | None, d2t: torch.Tensor | None): + if t2d is None and d2t is None: + # Nothing to load, return early + return + elif t2d is None or d2t is None: + raise ValueError( + "Both t2d and d2t must be provided together, or both must be None. " + f"Got t2d={'provided' if t2d is not None else 'None'}, " + f"d2t={'provided' if d2t is not None else 'None'}" + ) + if not self.use_draft_vocab: raise RuntimeError( "Vocab mappings were provided but are not needed because verifier " @@ -282,7 +292,7 @@ def load_vocab_mappings(self, t2d: torch.Tensor, d2t: torch.Tensor): self.load_state_dict({"t2d": t2d, "d2t": d2t}, strict=False) - def load_verifier_weights(self): + def load_verifier_weights(self): # noqa: C901 verifier_config = self.config.speculators_config.verifier if verifier_config.name_or_path is None: raise ValueError("VerifierConfig `name_or_path` value is required.") @@ -320,7 +330,7 @@ def load_verifier_weights(self): self.embed_tokens.load_state_dict(embed_tokens_sd) if self.use_draft_vocab: - if self.t2d is None or not torch.any(self.t2d).item: + if self.t2d is None or not torch.any(self.t2d).item(): # (not torch.any(self.t2d).item) because t2d is initialized to zeros raise ValueError( "t2d tensor hasn't been set. Please call " @@ -546,18 +556,8 @@ def from_training_args( ), ), ) - model = cls(config=config) - # Verify that if one mapping tensor is provided, the other is as well - if (t2d is None) != (d2t is None): - raise ValueError( - "Both t2d and d2t must be provided together, or both must be None. " - f"Got t2d={'provided' if t2d is not None else 'None'}, " - f"d2t={'provided' if d2t is not None else 'None'}" - ) - elif t2d is not None: - model.load_vocab_mappings(t2d, d2t) - + model.load_vocab_mappings(t2d, d2t) model.load_verifier_weights() return model @@ -569,18 +569,8 @@ def from_pretrained( d2t: torch.Tensor | None = None, **kwargs, ) -> "Eagle3DraftModel": - model: Eagle3DraftModel = super().from_pretrained(*args, **kwargs) - - # Verify that if one mapping tensor is provided, the other is as well - if (t2d is None) != (d2t is None): - raise ValueError( - "Both t2d and d2t must be provided together, or both must be None. " - f"Got t2d={'provided' if t2d is not None else 'None'}, " - f"d2t={'provided' if d2t is not None else 'None'}" - ) - elif t2d is not None: - model.load_vocab_mappings(t2d, d2t) - + model: Eagle3DraftModel = super().from_pretrained(*args, **kwargs) # type: ignore[assignment] + model.load_vocab_mappings(t2d, d2t) model.load_verifier_weights() return model diff --git a/src/speculators/train/trainer.py b/src/speculators/train/trainer.py index 332abb3ea..03b2f5a25 100644 --- a/src/speculators/train/trainer.py +++ b/src/speculators/train/trainer.py @@ -4,7 +4,10 @@ import torch import torch.distributed as dist -from torch.distributed.fsdp import FSDPModule +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, +) from torch.utils.data import DataLoader from tqdm import TqdmExperimentalWarning from tqdm.rich import tqdm @@ -88,29 +91,40 @@ def setup_model(self): # Verify model is compatible with training infrastructure SpeculatorModel.verify_training_compatible(self.model) - if self.is_distributed: - apply_fully_sharded(self.model) + load_checkpoint = ( + self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1 + ) - if self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1: - self.checkpointer.load_model_state_dict(self.model) - else: - for m in self.model.layers.children(): # type: ignore[union-attr] - if not isinstance(m, FSDPModule): - continue - acc = torch.accelerator.current_accelerator() - if acc is None: - m.to_empty(device="cuda") # type: ignore[attr-defined] - else: - acc_type = acc.type - m.to_empty(device=acc_type) # type: ignore[attr-defined] - for sub_module in m.modules(): # type: ignore[attr-defined] - if hasattr(sub_module, "reset_parameters"): - sub_module.reset_parameters() # type: ignore[operator] - # todo: Ensure lm_head and embed_tokens are loaded after reset - else: + if not self.is_distributed: + # Single device case self.model.to(self.local_rank) # type: ignore[arg-type] - if self.resume_from_checkpoint and self.checkpointer.previous_epoch != -1: + if load_checkpoint: self.checkpointer.load_model_state_dict(self.model) + return + + # Distributed case + # Capture full state dict on rank 0 before FSDP sharding + full_state_dict = {} + if not load_checkpoint and dist.get_rank() == 0: + full_state_dict = self.model.state_dict() + + apply_fully_sharded(self.model) + + if load_checkpoint: + self.checkpointer.load_model_state_dict(self.model) + else: + # Broadcast full state dict from rank 0 to all ranks + set_model_state_dict( + self.model, + full_state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + strict=False, + ), + ) + del full_state_dict + dist.barrier() def setup_optimizer(self): # Setup optimizer diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index d1ddeb35e..40ee91f4c 100644 --- a/src/speculators/train/utils.py +++ b/src/speculators/train/utils.py @@ -67,8 +67,6 @@ def apply_fully_sharded(model: torch.nn.Module): ) for layer in model.layers: # type: ignore[union-attr] - # we apply fully_shard to each DecoderLayer - layer.to_empty(device="meta") fully_shard(layer, mp_policy=mp_policy) fully_shard(model) From 950970a797f8ddc8ca54e35b7ede2b85f4ca9c6a Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 4 Mar 2026 23:41:27 +0000 Subject: [PATCH 3/7] Add tests for setup model logic Test correctness of different load from pretrained/checkpoint/fresh init on single gpu and multi-gpu setups. Signed-off-by: Fynn Schmitt-Ulms --- tests/unit/train/test_setup_model.py | 649 +++++++++++++++++++++++++++ 1 file changed, 649 insertions(+) create mode 100644 tests/unit/train/test_setup_model.py diff --git a/tests/unit/train/test_setup_model.py b/tests/unit/train/test_setup_model.py new file mode 100644 index 000000000..83ab16ad8 --- /dev/null +++ b/tests/unit/train/test_setup_model.py @@ -0,0 +1,649 @@ +""" +Tests for model weight loading and initialization pathways. + +Covers: +- Trainer.setup_model for single-GPU (fresh + resume) +- SingleGPUCheckpointer save/load round-trip +- from_pretrained save/load round-trip +- Weight precedence: checkpoint > pretrained > verifier > random init +- Distributed fresh init (FSDP + broadcast, mp.spawn) +- Distributed resume from checkpoint (mp.spawn) +- Distributed from_pretrained (mp.spawn) +""" + +import copy +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from safetensors import safe_open +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from transformers.models.llama.configuration_llama import LlamaConfig + +from speculators import SpeculatorsConfig, VerifierConfig +from speculators.models.eagle3 import Eagle3DraftModel, Eagle3SpeculatorConfig +from speculators.proposals.greedy import GreedyTokenProposalConfig +from speculators.train.checkpointer import ( + DistributedCheckpointer, + SingleGPUCheckpointer, +) +from speculators.train.trainer import Trainer, TrainerConfig + +# --------------------------------------------------------------------------- +# Skip decorators +# --------------------------------------------------------------------------- + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required" +) +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="2+ GPUs required", +) + +# --------------------------------------------------------------------------- +# Tiny model constants +# --------------------------------------------------------------------------- + +TINY_LLAMA_CONFIG = LlamaConfig( + vocab_size=64, + hidden_size=32, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + max_position_embeddings=32, + rms_norm_eps=1e-6, + tie_word_embeddings=False, + _attn_implementation="eager", +) + + +# --------------------------------------------------------------------------- +# Helpers (used by both fixtures and mp.spawn workers) +# --------------------------------------------------------------------------- + + +def _make_eagle3_config() -> Eagle3SpeculatorConfig: + return Eagle3SpeculatorConfig( + transformer_layer_config=copy.deepcopy(TINY_LLAMA_CONFIG), + draft_vocab_size=64, + norm_before_residual=False, + embed_requires_grad=False, + speculators_config=SpeculatorsConfig( + algorithm="eagle3", + proposal_methods=[GreedyTokenProposalConfig(speculative_tokens=1)], + default_proposal_method="greedy", + verifier=VerifierConfig( + name_or_path=None, + architectures=["LlamaForCausalLM"], + ), + ), + ) + + +def _make_tiny_model() -> Eagle3DraftModel: + """Create a tiny Eagle3 model with NaN weights filled.""" + model = Eagle3DraftModel(_make_eagle3_config()) + _fill_nan_weights(model) + return model + + +def _fill_nan_weights(model: Eagle3DraftModel): + """Replace NaN-initialized weights with deterministic values (simulates + what load_verifier_weights does).""" + with torch.no_grad(): + torch.nn.init.ones_(model.embed_tokens.weight) + torch.nn.init.ones_(model.lm_head.weight) + torch.nn.init.ones_(model.verifier_lm_head.weight) + torch.nn.init.ones_(model.verifier_norm.weight) + + +def _make_trainer_no_init( + model, + *, + is_distributed=False, + resume_from_checkpoint=False, + local_rank=0, + save_path="/tmp/test_ckpt", +): + """Create a Trainer instance bypassing __init__ to control setup order.""" + config = TrainerConfig( + lr=1e-4, + num_epochs=1, + save_path=save_path, + resume_from_checkpoint=resume_from_checkpoint, + is_distributed=is_distributed, + local_rank=local_rank, + ) + trainer = Trainer.__new__(Trainer) + trainer.model = model + trainer.config = config + trainer.local_rank = config.local_rank + trainer.is_distributed = config.is_distributed + trainer.resume_from_checkpoint = config.resume_from_checkpoint + trainer.train_loader = MagicMock(__len__=MagicMock(return_value=1)) + trainer.val_loader = None + return trainer + + +def _param_checksums(state_dict: dict[str, torch.Tensor]) -> dict[str, float]: + """Compute per-key checksums for cross-rank comparison.""" + return { + k: v.float().sum().item() + for k, v in state_dict.items() + if isinstance(v, torch.Tensor) + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def eagle3_config(): + return _make_eagle3_config() + + +@pytest.fixture +def tiny_model(): + """Tiny Eagle3 model on CPU with NaN weights filled.""" + return _make_tiny_model() + + +@pytest.fixture +def tiny_model_on_gpu(tiny_model): + """Tiny Eagle3 model moved to cuda:0.""" + return tiny_model.to("cuda:0") + + +@pytest.fixture +def checkpoint_dir(tmp_path, tiny_model_on_gpu): + """Save a checkpoint with trainable weights = 42.0, return the path.""" + with torch.no_grad(): + for p in tiny_model_on_gpu.parameters(): + if p.requires_grad: + p.fill_(42.0) + + ckpt_dir = tmp_path / "ckpt" + checkpointer = SingleGPUCheckpointer(ckpt_dir) + optimizer = torch.optim.AdamW(tiny_model_on_gpu.parameters(), lr=1e-4) + checkpointer.save_checkpoint(tiny_model_on_gpu, optimizer, epoch=0) + return ckpt_dir + + +@pytest.fixture +def pretrained_dir(tmp_path, tiny_model): + """Save a pretrained model with fc=66.0, lm_head=55.0, return the path.""" + with torch.no_grad(): + tiny_model.fc.weight.fill_(66.0) + tiny_model.lm_head.weight.fill_(55.0) + model_dir = tmp_path / "pretrained" + tiny_model.save_pretrained(str(model_dir)) + return model_dir + + +@pytest.fixture +def mock_checkpointer(): + """Mock checkpointer with no previous checkpoint.""" + ckpt = MagicMock() + ckpt.previous_epoch = -1 + return ckpt + + +# =================================================================== +# Single GPU — Fresh Init +# =================================================================== + + +@requires_cuda +def test_single_gpu_fresh_init(tiny_model, mock_checkpointer): + """Fresh single-GPU setup: model moved to device, weights unchanged, + no checkpoint loading.""" + state_before = {k: v.clone() for k, v in tiny_model.state_dict().items()} + + trainer = _make_trainer_no_init(tiny_model, is_distributed=False) + trainer.checkpointer = mock_checkpointer + + trainer.setup_model() + + # Weights should be unchanged (just moved to device) + for k, v in tiny_model.state_dict().items(): + assert torch.allclose(v.cpu().float(), state_before[k].float()), ( + f"Weight {k} changed during fresh init" + ) + + # No checkpoint loading + mock_checkpointer.load_model_state_dict.assert_not_called() + + +# =================================================================== +# Single GPU — Resume from Checkpoint +# =================================================================== + + +@requires_cuda +def test_single_gpu_resume(checkpoint_dir): + """Resume from checkpoint: checkpoint weights loaded, verifier weights + preserved (not overwritten by checkpoint since they're not saved).""" + model = _make_tiny_model() + with torch.no_grad(): + model.verifier_norm.weight.fill_(77.0) + model.verifier_lm_head.weight.fill_(88.0) + + trainer = _make_trainer_no_init( + model, + is_distributed=False, + resume_from_checkpoint=True, + save_path=str(checkpoint_dir), + ) + trainer.checkpointer = SingleGPUCheckpointer(checkpoint_dir) + trainer.setup_model() + + # Trainable weights should match checkpoint (42.0, modulo bf16 round-trip) + for name, param in model.named_parameters(): + if param.requires_grad: + assert torch.allclose(param.cpu().float(), torch.tensor(42.0), atol=0.5), ( + f"Trainable weight {name} not loaded from checkpoint" + ) + + # Verifier weights should be preserved (not in checkpoint) + assert torch.allclose( + model.verifier_norm.weight.cpu().float(), torch.tensor(77.0) + ), "verifier_norm overwritten by checkpoint" + assert torch.allclose( + model.verifier_lm_head.weight.cpu().float(), torch.tensor(88.0) + ), "verifier_lm_head overwritten by checkpoint" + + +# =================================================================== +# Checkpoint Save/Load Round-Trip +# =================================================================== + + +@requires_cuda +def test_checkpoint_save_load_round_trip(checkpoint_dir): + """SingleGPUCheckpointer round-trip: trainable weights preserved, verifier + keys not saved, expected files created.""" + # Verify files + assert (checkpoint_dir / "0" / "model.safetensors").exists() + assert (checkpoint_dir / "0" / "config.json").exists() + assert (checkpoint_dir / "0" / "optimizer_state_dict.pt").exists() + + # Verify verifier-only keys not in saved safetensors + with safe_open( + str(checkpoint_dir / "0" / "model.safetensors"), framework="pt" + ) as f: + saved_keys = set(f.keys()) + for key in Eagle3DraftModel._keys_to_ignore_on_save: + assert key not in saved_keys, f"{key} should not be saved" + + # Load into fresh model and verify trainable weights match + model = _make_tiny_model() + model.to("cuda:0") # type: ignore[arg-type] + checkpointer = SingleGPUCheckpointer(checkpoint_dir) + checkpointer.load_model_state_dict(model) + + for name, param in model.named_parameters(): + if param.requires_grad: + assert torch.allclose(param.cpu().float(), torch.tensor(42.0), atol=0.5), ( + f"Trainable weight {name} not preserved in round-trip" + ) + + +# =================================================================== +# from_pretrained Round-Trip +# =================================================================== + + +def test_from_pretrained_round_trip(tiny_model): + """from_pretrained round-trip: trainable weights preserved, ignored keys + not in saved files, pretrained weights take precedence over verifier.""" + # Set trainable weights to known value + with torch.no_grad(): + for p in tiny_model.parameters(): + if p.requires_grad: + p.fill_(42.0) + + trainable_names = {n for n, p in tiny_model.named_parameters() if p.requires_grad} + original_trainable = { + k: v.clone() for k, v in tiny_model.state_dict().items() if k in trainable_names + } + + with tempfile.TemporaryDirectory() as tmpdir: + tiny_model.save_pretrained(tmpdir) + + # Verify _keys_to_ignore_on_save not in saved files + with safe_open(f"{tmpdir}/model.safetensors", framework="pt") as f: + saved_keys = set(f.keys()) + for key in Eagle3DraftModel._keys_to_ignore_on_save: + assert key not in saved_keys, f"{key} should not be saved" + + # Load (mock load_verifier_weights to avoid HF downloads) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(tmpdir) + + # Trainable weights should match original + for k, original_v in original_trainable.items(): + loaded_v = loaded.state_dict()[k] + assert torch.allclose(loaded_v.float(), original_v.float(), atol=0.5), ( + f"Weight {k} not preserved in from_pretrained round-trip" + ) + + # lm_head was saved (it's trainable), so from_pretrained loads it. + # Even if load_verifier_weights ran, the NaN guard would keep the + # pretrained value since it's no longer NaN. + assert not loaded.lm_head.weight.isnan().any(), ( + "lm_head should have pretrained value, not NaN" + ) + + +# =================================================================== +# Weight Precedence +# =================================================================== + + +@requires_cuda +def test_weight_precedence(eagle3_config, pretrained_dir, tmp_path): + """Verify weight precedence: checkpoint > pretrained > verifier > random. + + Walks through the full chain in a single test.""" + + # --- Level 5: Random init produces NaN for verifier-loaded weights --- + model = Eagle3DraftModel(eagle3_config) + assert model.embed_tokens.weight.isnan().all(), ( + "embed_tokens should be NaN after random init" + ) + assert model.lm_head.weight.isnan().all(), "lm_head should be NaN after random init" + # fc (trainable) should NOT be NaN — it's randomly initialized + assert not model.fc.weight.isnan().any(), "fc should have random init, not NaN" + + # --- Level 4: Verifier fills NaN weights --- + _fill_nan_weights(model) # simulates load_verifier_weights + assert not model.embed_tokens.weight.isnan().any(), ( + "embed_tokens should be filled by verifier" + ) + assert not model.lm_head.weight.isnan().any(), ( + "lm_head should be filled by verifier" + ) + + # --- Level 2: Pretrained weights take precedence over verifier --- + # pretrained_dir fixture saved lm_head=55.0, fc=66.0 + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(pretrained_dir)) + + assert torch.allclose(loaded.lm_head.weight.float(), torch.tensor(55.0)), ( + "pretrained lm_head should not be overwritten by verifier" + ) + assert torch.allclose(loaded.fc.weight.float(), torch.tensor(66.0)), ( + "pretrained fc should be preserved" + ) + + # --- Level 1: Checkpoint overrides everything --- + loaded.to("cuda:0") # type: ignore[arg-type] + with torch.no_grad(): + loaded.fc.weight.fill_(99.0) # checkpoint value + ckpt_dir = str(tmp_path / "ckpt") + checkpointer = SingleGPUCheckpointer(ckpt_dir) + optimizer = torch.optim.AdamW(loaded.parameters(), lr=1e-4) + checkpointer.save_checkpoint(loaded, optimizer, epoch=0) + + # Load checkpoint into a model that had pretrained value (66.0) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model3 = Eagle3DraftModel.from_pretrained(str(pretrained_dir)) + model3.to("cuda:0") # type: ignore[arg-type] + checkpointer2 = SingleGPUCheckpointer(ckpt_dir) + checkpointer2.load_model_state_dict(model3) + + assert torch.allclose( + model3.fc.weight.cpu().float(), torch.tensor(99.0), atol=0.5 + ), "checkpoint fc should override pretrained" + + +# =================================================================== +# Distributed helpers +# =================================================================== + + +def _dist_setup(rank, world_size): + """Initialize distributed process group for testing.""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def _dist_teardown(): + """Clean up distributed process group.""" + dist.destroy_process_group() + + +def _get_full_state_dict_rank0(model): + """Get unsharded full state dict from FSDP model (only populated on rank 0). + + All ranks must call this (it's a collective op), but only rank 0 + gets the actual tensors.""" + return get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + + +# =================================================================== +# Distributed — Fresh Init +# =================================================================== + + +def _worker_distributed_fresh_init(rank, world_size, results_dir): + """Worker for test_distributed_fresh_init.""" + _dist_setup(rank, world_size) + try: + model = _make_tiny_model() + + # Capture rank 0's pre-FSDP state dict for comparison + pre_fsdp_checksums = _param_checksums(model.state_dict()) if rank == 0 else {} + + trainer = _make_trainer_no_init(model, is_distributed=True, local_rank=rank) + trainer.checkpointer = MagicMock() + trainer.checkpointer.previous_epoch = -1 + + trainer.setup_model() + + # All ranks must call get_model_state_dict (collective op), + # but only rank 0 gets the actual tensors + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + has_nan = { + k: v.isnan().any().item() + for k, v in full_sd.items() + if isinstance(v, torch.Tensor) and v.is_floating_point() + } + + torch.save( + { + "pre_fsdp_checksums": pre_fsdp_checksums, + "post_fsdp_checksums": checksums, + "has_nan": has_nan, + }, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_fresh_init(tmp_path): + """Distributed fresh init: after setup_model, the gathered full state dict + matches rank 0's original pre-FSDP weights and contains no NaN values. + + This verifies that set_model_state_dict(broadcast_from_rank0=True) + correctly distributes rank 0's weights to all ranks, because + get_model_state_dict gathers shards from ALL ranks to reconstruct + the full dict on rank 0.""" + world_size = 2 + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_fresh_init, + args=(world_size, results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Post-FSDP gathered state dict should match pre-FSDP state dict from rank 0 + pre = results["pre_fsdp_checksums"] + post = results["post_fsdp_checksums"] + for key in pre: + assert key in post, f"Key {key} missing after FSDP round-trip" + assert pre[key] == pytest.approx(post[key], abs=1e-2), ( + f"Weight {key} changed during FSDP broadcast: " + f"pre={pre[key]}, post={post[key]}" + ) + + # No NaN values in any float parameter + for key, is_nan in results["has_nan"].items(): + assert not is_nan, f"Weight {key} is NaN after distributed setup" + + +# =================================================================== +# Distributed — Resume from Checkpoint +# =================================================================== + + +def _worker_distributed_resume(rank, world_size, ckpt_dir, results_dir): + """Worker for test_distributed_resume.""" + _dist_setup(rank, world_size) + try: + model = _make_tiny_model() + # Set verifier weights to known value before FSDP + with torch.no_grad(): + model.verifier_norm.weight.fill_(77.0) + model.verifier_lm_head.weight.fill_(88.0) + + trainer = _make_trainer_no_init( + model, + is_distributed=True, + resume_from_checkpoint=True, + local_rank=rank, + save_path=ckpt_dir, + ) + trainer.checkpointer = DistributedCheckpointer(ckpt_dir) + trainer.setup_model() + + # All ranks must call (collective op), only rank 0 gets data + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + verifier_norm_val = full_sd["verifier_norm.weight"].float().mean().item() + verifier_lm_head_val = ( + full_sd["verifier_lm_head.weight"].float().mean().item() + ) + + torch.save( + { + "checksums": checksums, + "verifier_norm_val": verifier_norm_val, + "verifier_lm_head_val": verifier_lm_head_val, + }, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_resume(checkpoint_dir, tmp_path): + """Distributed resume: checkpoint weights loaded correctly, verifier + weights preserved (not overwritten by checkpoint).""" + world_size = min(torch.cuda.device_count(), 2) + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_resume, + args=(world_size, str(checkpoint_dir), results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Verifier weights should be preserved (not in checkpoint) + assert results["verifier_norm_val"] == pytest.approx(77.0, abs=0.1), ( + "verifier_norm overwritten by checkpoint" + ) + assert results["verifier_lm_head_val"] == pytest.approx(88.0, abs=0.1), ( + "verifier_lm_head overwritten by checkpoint" + ) + + +# =================================================================== +# Distributed — from_pretrained +# =================================================================== + + +def _worker_distributed_from_pretrained(rank, world_size, model_dir, results_dir): + """Worker for test_distributed_from_pretrained.""" + _dist_setup(rank, world_size) + try: + # Load model from pretrained (mock verifier loading) + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model = Eagle3DraftModel.from_pretrained(model_dir) + _fill_nan_weights(model) # fill verifier weights post-load + + trainer = _make_trainer_no_init(model, is_distributed=True, local_rank=rank) + trainer.checkpointer = MagicMock() + trainer.checkpointer.previous_epoch = -1 + + trainer.setup_model() + + # All ranks must call (collective op), only rank 0 gets data + full_sd = _get_full_state_dict_rank0(model) + + if rank == 0: + checksums = _param_checksums(full_sd) + fc_val = full_sd["fc.weight"].float().mean().item() + + torch.save( + {"checksums": checksums, "fc_val": fc_val}, + results_dir / "results.pt", + ) + finally: + _dist_teardown() + + +@requires_multi_gpu +def test_distributed_from_pretrained(pretrained_dir, tmp_path): + """Model loaded via from_pretrained should have correct weights after FSDP + setup, with pretrained weight values preserved through the broadcast.""" + world_size = min(torch.cuda.device_count(), 2) + results_dir = tmp_path / "results" + results_dir.mkdir() + + mp.spawn( + _worker_distributed_from_pretrained, + args=(world_size, str(pretrained_dir), results_dir), + nprocs=world_size, + join=True, + ) + + results = torch.load(results_dir / "results.pt", weights_only=False) + + # Pretrained fc weight should be preserved through FSDP setup + assert results["fc_val"] == pytest.approx(66.0, abs=0.5), ( + "Pretrained fc weight not preserved through FSDP broadcast" + ) From c3d33b358fc1c59e045074c081c933a9b0d4e10d Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 5 Mar 2026 00:43:45 +0000 Subject: [PATCH 4/7] Fix vocab mapping buffer initialization and add tests Signed-off-by: Fynn Schmitt-Ulms --- src/speculators/models/eagle3/core.py | 21 ++-- tests/unit/train/test_setup_model.py | 173 +++++++++++++++++++++++++- 2 files changed, 183 insertions(+), 11 deletions(-) diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index e5fd514da..e6208fa13 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -168,8 +168,15 @@ class Eagle3DraftModel(SpeculatorModel): "verifier_norm.weight", ] + t2d: torch.Tensor | None + d2t: torch.Tensor | None + def __init__(self, config: Eagle3SpeculatorConfig): + # Forcibly override config settings + config.tie_word_embeddings = False + config.transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 super().__init__(config=config) + self.hidden_size = config.transformer_layer_config.hidden_size self.draft_vocab_size = config.draft_vocab_size self.verifier_vocab_size = config.transformer_layer_config.vocab_size @@ -179,15 +186,13 @@ def __init__(self, config: Eagle3SpeculatorConfig): # VOCAB MAPPINGS self.use_draft_vocab = self.draft_vocab_size != self.verifier_vocab_size - self.t2d: torch.Tensor | None = None - self.d2t: torch.Tensor | None = None + t2d: torch.Tensor | None = None + d2t: torch.Tensor | None = None if self.use_draft_vocab: - # Use NaNs as placeholder so that it's clear if these aren't updated - # todo(fynn): NaNs might not work with the dtypes here - self.t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) - self.d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) - self.register_buffer("t2d", self.t2d) - self.register_buffer("d2t", self.d2t) + t2d = torch.zeros((self.verifier_vocab_size,), dtype=torch.bool) + d2t = torch.zeros((self.draft_vocab_size,), dtype=torch.long) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) # FC LAYER self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False) diff --git a/tests/unit/train/test_setup_model.py b/tests/unit/train/test_setup_model.py index 83ab16ad8..5b1b20b6e 100644 --- a/tests/unit/train/test_setup_model.py +++ b/tests/unit/train/test_setup_model.py @@ -72,10 +72,13 @@ # --------------------------------------------------------------------------- -def _make_eagle3_config() -> Eagle3SpeculatorConfig: +def _make_eagle3_config( + draft_vocab_size: int = 64, + verifier_name_or_path: str | None = None, +) -> Eagle3SpeculatorConfig: return Eagle3SpeculatorConfig( transformer_layer_config=copy.deepcopy(TINY_LLAMA_CONFIG), - draft_vocab_size=64, + draft_vocab_size=draft_vocab_size, norm_before_residual=False, embed_requires_grad=False, speculators_config=SpeculatorsConfig( @@ -83,13 +86,29 @@ def _make_eagle3_config() -> Eagle3SpeculatorConfig: proposal_methods=[GreedyTokenProposalConfig(speculative_tokens=1)], default_proposal_method="greedy", verifier=VerifierConfig( - name_or_path=None, + name_or_path=verifier_name_or_path, architectures=["LlamaForCausalLM"], ), ), ) +def _make_vocab_mappings( + verifier_vocab_size: int = 64, + draft_vocab_size: int = 32, +) -> tuple[torch.Tensor, torch.Tensor]: + """Create valid t2d and d2t tensors for testing. + + Selects the first `draft_vocab_size` tokens from the verifier vocab. + t2d: bool[verifier_vocab_size] — True for tokens included in draft vocab. + d2t: long[draft_vocab_size] — maps draft index to verifier index. + """ + t2d = torch.zeros(verifier_vocab_size, dtype=torch.bool) + t2d[:draft_vocab_size] = True + d2t = torch.arange(draft_vocab_size, dtype=torch.long) + return t2d, d2t + + def _make_tiny_model() -> Eagle3DraftModel: """Create a tiny Eagle3 model with NaN weights filled.""" model = Eagle3DraftModel(_make_eagle3_config()) @@ -647,3 +666,151 @@ def test_distributed_from_pretrained(pretrained_dir, tmp_path): assert results["fc_val"] == pytest.approx(66.0, abs=0.5), ( "Pretrained fc weight not preserved through FSDP broadcast" ) + + +# =================================================================== +# Vocab Mapping Loading (t2d / d2t) +# =================================================================== + +DRAFT_VOCAB_SIZE = 32 # < TINY_LLAMA_CONFIG.vocab_size (64) + + +@pytest.fixture +def draft_vocab_config(): + """Eagle3 config with draft_vocab_size < verifier_vocab_size.""" + return _make_eagle3_config(draft_vocab_size=DRAFT_VOCAB_SIZE) + + +@pytest.fixture +def vocab_mappings(): + """Valid (t2d, d2t) pair for verifier_vocab=64, draft_vocab=32.""" + return _make_vocab_mappings( + verifier_vocab_size=TINY_LLAMA_CONFIG.vocab_size, + draft_vocab_size=DRAFT_VOCAB_SIZE, + ) + + +def test_load_vocab_mappings(draft_vocab_config, vocab_mappings): + """load_vocab_mappings stores t2d/d2t buffers correctly.""" + t2d, d2t = vocab_mappings + model = Eagle3DraftModel(draft_vocab_config) + + # Before loading: buffers exist but are zeros + assert model.t2d is not None + assert not model.t2d.any(), "t2d should be all zeros before loading" + assert model.d2t is not None + assert (model.d2t == 0).all(), "d2t should be all zeros before loading" + + model.load_vocab_mappings(t2d, d2t) + + # After loading: buffers match inputs + assert torch.equal(model.t2d, t2d), "t2d not loaded correctly" + assert torch.equal(model.d2t, d2t), "d2t not loaded correctly" + + +def test_load_vocab_mappings_validation(draft_vocab_config, vocab_mappings): + """load_vocab_mappings raises on invalid inputs.""" + t2d, d2t = vocab_mappings + model = Eagle3DraftModel(draft_vocab_config) + + # Only one of t2d/d2t provided + with pytest.raises(ValueError, match="Both t2d and d2t must be provided"): + model.load_vocab_mappings(t2d, None) + with pytest.raises(ValueError, match="Both t2d and d2t must be provided"): + model.load_vocab_mappings(None, d2t) + + # Wrong t2d shape + with pytest.raises(ValueError, match="t2d.shape"): + model.load_vocab_mappings(torch.ones(10, dtype=torch.bool), d2t) + + # Wrong d2t shape + with pytest.raises(ValueError, match="d2t.shape"): + model.load_vocab_mappings(t2d, torch.zeros(10, dtype=torch.long)) + + # Wrong number of True values in t2d + bad_t2d = torch.ones(TINY_LLAMA_CONFIG.vocab_size, dtype=torch.bool) + with pytest.raises(ValueError, match="non-zero values"): + model.load_vocab_mappings(bad_t2d, d2t) + + +def test_load_vocab_mappings_not_needed(): + """load_vocab_mappings raises when vocab sizes match (no mapping needed).""" + config = _make_eagle3_config(draft_vocab_size=64) # same as verifier + model = Eagle3DraftModel(config) + t2d, d2t = _make_vocab_mappings(verifier_vocab_size=64, draft_vocab_size=64) + + with pytest.raises(RuntimeError, match="not needed"): + model.load_vocab_mappings(t2d, d2t) + + +def test_from_training_args_loads_vocab_mappings(vocab_mappings): + """from_training_args passes t2d/d2t through to load_vocab_mappings.""" + t2d, d2t = vocab_mappings + + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + model = Eagle3DraftModel.from_training_args( + verifier_config=copy.deepcopy(TINY_LLAMA_CONFIG), + t2d=t2d, + d2t=d2t, + draft_vocab_size=DRAFT_VOCAB_SIZE, + norm_before_residual=False, + ttt_steps=1, + verifier_name_or_path="dummy", + ) + + assert model.t2d is not None, "t2d is None after from_training_args" + assert model.d2t is not None, "d2t is None after from_training_args" + assert torch.equal(model.t2d, t2d), "t2d not loaded via from_training_args" + assert torch.equal(model.d2t, d2t), "d2t not loaded via from_training_args" + + +def test_from_pretrained_loads_vocab_mappings_from_kwargs( + tmp_path, draft_vocab_config, vocab_mappings +): + """from_pretrained loads t2d/d2t passed as kwargs.""" + t2d, d2t = vocab_mappings + + # Save a model without vocab mappings in the safetensors + model = Eagle3DraftModel(draft_vocab_config) + _fill_nan_weights(model) + model_dir = tmp_path / "pretrained_no_vocab" + model.save_pretrained(str(model_dir)) + + # Load with t2d/d2t passed as kwargs + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(model_dir), t2d=t2d, d2t=d2t) + + assert loaded.t2d is not None, "t2d is None after from_pretrained" + assert loaded.d2t is not None, "d2t is None after from_pretrained" + assert torch.equal(loaded.t2d, t2d), "t2d not loaded from kwargs in from_pretrained" + assert torch.equal(loaded.d2t, d2t), "d2t not loaded from kwargs in from_pretrained" + + +def test_from_pretrained_loads_vocab_mappings_from_saved( + tmp_path, draft_vocab_config, vocab_mappings +): + """from_pretrained loads t2d/d2t from saved safetensors when not passed + as kwargs.""" + t2d, d2t = vocab_mappings + + # Save model WITH vocab mappings loaded + model = Eagle3DraftModel(draft_vocab_config) + _fill_nan_weights(model) + model.load_vocab_mappings(t2d, d2t) + model_dir = tmp_path / "pretrained_with_vocab" + model.save_pretrained(str(model_dir)) + + # Verify t2d/d2t are in the saved safetensors + with safe_open(str(model_dir / "model.safetensors"), framework="pt") as f: + saved_keys = set(f.keys()) + assert "t2d" in saved_keys, "t2d should be saved in safetensors" + assert "d2t" in saved_keys, "d2t should be saved in safetensors" + + # Load WITHOUT passing t2d/d2t — should come from safetensors + with patch.object(Eagle3DraftModel, "load_verifier_weights"): + loaded = Eagle3DraftModel.from_pretrained(str(model_dir)) + + assert loaded.t2d is not None, "t2d is None after from_pretrained" + assert loaded.d2t is not None, "d2t is None after from_pretrained" + assert torch.equal(loaded.t2d, t2d), "t2d not loaded from saved safetensors" + assert torch.equal(loaded.d2t, d2t), "d2t not loaded from saved safetensors" From 5e0e2395f27e173d1da17a5f7f28e78b017bc774 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 5 Mar 2026 00:43:45 +0000 Subject: [PATCH 5/7] Update Eagle3DraftModel config override handling Signed-off-by: Fynn Schmitt-Ulms --- scripts/train.py | 4 +--- src/speculators/model.py | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index c752a458d..2e69a7ef4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -120,7 +120,7 @@ def create_transformer_layer_config( if hasattr(verifier_config, "text_config"): verifier_config = verifier_config.text_config - transformer_layer_config = config_class( + return config_class( vocab_size=verifier_config.vocab_size, hidden_size=verifier_config.hidden_size, intermediate_size=verifier_config.intermediate_size, @@ -134,8 +134,6 @@ def create_transformer_layer_config( head_dim=getattr(verifier_config, "head_dim", None), tie_word_embeddings=False, ) - transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 - return transformer_layer_config def main(args: argparse.Namespace): diff --git a/src/speculators/model.py b/src/speculators/model.py index 50646f78f..8fd80fb76 100644 --- a/src/speculators/model.py +++ b/src/speculators/model.py @@ -141,9 +141,6 @@ def from_pretrained( "provided to load a SpeculatorModel." ) - config.tie_word_embeddings = False - config.transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001 - if cls is SpeculatorModel: # generic call to from_pretrained on this class, need to resolve the # specific model class to use for loading based on the config and registry From 465ca8dfc0847d4e72723260286067bd3b1b857e Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 5 Mar 2026 19:23:54 +0000 Subject: [PATCH 6/7] Restore `--embed-requires-grad` option Signed-off-by: Fynn Schmitt-Ulms --- scripts/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/train.py b/scripts/train.py index 2e69a7ef4..94986a460 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -306,6 +306,12 @@ def parse_args(): default=True, help="Toggle normalization before residual connections (default: True)", ) + parser.add_argument( + "--embed-requires-grad", + action=argparse.BooleanOptionalAction, + default=False, + help="Whether to train embedding layer weights (default: False)", + ) # Dataloader parameters parser.add_argument( "--num-workers", type=int, default=12, help="Number of dataloader workers" From c09fab9f9b7817fcb209770cd46b01c743be4729 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Mon, 9 Mar 2026 17:45:00 +0000 Subject: [PATCH 7/7] Add from_pretrained finetuning e2e sanity test Signed-off-by: Fynn Schmitt-Ulms --- tests/e2e/vllm/test_finetuning_sanity.py | 155 +++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 tests/e2e/vllm/test_finetuning_sanity.py diff --git a/tests/e2e/vllm/test_finetuning_sanity.py b/tests/e2e/vllm/test_finetuning_sanity.py new file mode 100644 index 000000000..bfc377ba4 --- /dev/null +++ b/tests/e2e/vllm/test_finetuning_sanity.py @@ -0,0 +1,155 @@ +"""E2E test: verify finetuning modifies weights but keeps them close (bounded rel L1). + +Uses relative L1 distance: ||a-b||_1 / (||b||_1 + eps) per tensor. +All trainable tensors must have rel_l1 <= REL_L1_MAX; +at least a few must have rel_l1 > REL_L1_MIN to ensure they changed. + +To see all logs (per-tensor distances, frozen checks, etc.), run: + pytest tests/e2e/vllm/test_finetuning_weight_sanity.py -s --log-cli-level=INFO +Without -s, pytest captures output and you won't see logger output until failure. +""" + +import logging +import subprocess +import sys +from pathlib import Path + +import pytest +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + +from speculators import Eagle3DraftModel + +logger = logging.getLogger(__name__) + + +@pytest.mark.e2e +@pytest.mark.slow +def test_finetuning_weight_sanity(tmp_path: Path): + """Verify finetuning changes weights but keeps rel L1 distance bounded (low LR).""" + # Ensure logs are visible when running with -s (no capture) + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s %(name)s %(message)s", + stream=sys.stderr, + force=True, + ) + + # Learning rate used for training (shared for log and CLI). + FROZEN_KEY_PATTERNS = ("d2t", "embed_tokens.weight", "t2d") + # Relative distance thresholds: all trainable tensors must have rel_l1 <= REL_L1_MAX + # at least MIN_CHANGED tensors must have rel_l1 > REL_L1_MIN + # to ensure weights actually changed. + LR = "1e-5" + REL_L1_MAX = 0.05 + REL_L1_MIN = 1e-4 + MIN_CHANGED = 3 + EPS = 1e-12 + PRETRAINED = "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3" + DATASET = "nm-testing/sharegpt_llama3_8b_hidden_states" + + # Get initial state dict + model = Eagle3DraftModel.from_pretrained(PRETRAINED) + initial_sd = model.state_dict() + # Remove verifier weights which aren't saved in checkpoints + del initial_sd["verifier_norm.weight"] + del initial_sd["verifier_lm_head.weight"] + del model + + # Run short training with low LR for single epoch + logger.info("Downloading dataset %s", DATASET) + data_dir = snapshot_download(repo_id=DATASET, repo_type="dataset") + logger.info("Dataset at %s", data_dir) + logger.info( + "Running training (1 epoch, lr=%s, save_path=%s)", LR, tmp_path / "ckpt" + ) + result = subprocess.run( # noqa: S603 + [ + sys.executable, + "scripts/train.py", + "--from-pretrained", + PRETRAINED, + "--verifier-name-or-path", + "meta-llama/Llama-3.1-8B-Instruct", + "--data-path", + data_dir, + "--save-path", + str(tmp_path / "ckpt"), + "--log-dir", + str(tmp_path / "logs"), + "--epochs", + "2", + "--lr", + LR, + "--total-seq-len", + "2048", + "--num-workers", + "2", + ], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + logger.error( + "Training failed (returncode=%d). stderr:\n%s", + result.returncode, + result.stderr, + ) + if result.stdout: + logger.debug("Training stdout:\n%s", result.stdout) + assert result.returncode == 0, f"Training failed:\n{result.stderr}" + + logger.info( + "Training finished. Loading finetuned weights from %s", tmp_path / "ckpt" + ) + ckpt_dir = next((tmp_path / "ckpt").glob("*")) + finetuned_sd = {} + for f in ckpt_dir.glob("*.safetensors"): + finetuned_sd.update(load_file(str(f))) + logger.info("Loaded %d parameter tensors from checkpoint", len(finetuned_sd)) + + # Verify same keys + assert set(initial_sd.keys()) == set(finetuned_sd.keys()) + + # These tensors must remain identical (frozen / not trained) + num_changed = 0 + for key in sorted(initial_sd.keys()): + assert initial_sd[key].shape == finetuned_sd[key].shape, ( + f"Shape mismatch for {key}: " + f"initial {initial_sd[key].shape} vs finetuned {finetuned_sd[key].shape}" + ) + + if any(pat in key for pat in FROZEN_KEY_PATTERNS): + assert torch.equal(initial_sd[key], finetuned_sd[key]), ( + f"Tensor {key} must stay identical after finetuning (frozen); " + f"initial and finetuned differ" + ) + logger.info(" [frozen] %s: identical", key) + else: + # Trainable + diff = initial_sd[key] - finetuned_sd[key] + l1_norm_finetuned = finetuned_sd[key].abs().sum() + EPS + rel_l1 = (diff.abs().sum() / l1_norm_finetuned).item() + max_abs = diff.abs().max().item() + mean_abs = diff.abs().mean().item() + logger.info( + " %s: rel_l1=%.3e max|Δ|=%.3e mean|Δ|=%.3e", + key, + rel_l1, + max_abs, + mean_abs, + ) + assert rel_l1 <= REL_L1_MAX, ( + f"Tensor {key} has rel_l1={rel_l1:.4e} > {REL_L1_MAX} " + f"(weights changed too much)" + ) + + if rel_l1 >= REL_L1_MIN: + num_changed += 1 + + assert num_changed >= MIN_CHANGED, ( + f"Expected at least {MIN_CHANGED} tensors with rel_l1 > {REL_L1_MIN}, " + f"got {num_changed}." + )