diff --git a/src/speculators/models/dflash/core.py b/src/speculators/models/dflash/core.py index 7661bcf96..4aa084154 100644 --- a/src/speculators/models/dflash/core.py +++ b/src/speculators/models/dflash/core.py @@ -3,10 +3,7 @@ import torch from torch import nn from torch.nn.attention.flex_attention import create_block_mask -from transformers import ( - AutoConfig, - PretrainedConfig, -) +from transformers import PretrainedConfig from transformers.models.qwen3.modeling_qwen3 import ( Qwen3RMSNorm, Qwen3RotaryEmbedding, @@ -21,6 +18,7 @@ get_base_indices_for_anchored_blocks, select_anchors, ) +from speculators.models.utils import resolve_target_layer_ids @SpeculatorModel.register("dflash") @@ -63,25 +61,12 @@ def __init__( ] ) - verifier_name_or_path = config.speculators_config.verifier.name_or_path - if verifier_name_or_path is None: - raise ValueError("Verifier name_or_path must be set in speculators_config") - verifier_config = AutoConfig.from_pretrained(verifier_name_or_path) - if hasattr(verifier_config, "text_config"): - verifier_config = verifier_config.text_config - num_verifier_layers = verifier_config.num_hidden_layers - - if config.aux_hidden_state_layer_ids is not None: - self.target_layer_ids = config.aux_hidden_state_layer_ids - else: - # Eagle3 defaults; write back so they are persisted in config.json - self.target_layer_ids = [ - 2, - num_verifier_layers // 2, - num_verifier_layers - 3, - ] - # set defaults to config if not provided - vLLM will fail otherwise - config.aux_hidden_state_layer_ids = self.target_layer_ids + if config.aux_hidden_state_layer_ids is None: + raise ValueError( + "aux_hidden_state_layer_ids must be set in DFlashSpeculatorConfig. " + "Use DFlashDraftModel.from_training_args() to resolve defaults." + ) + self.target_layer_ids = config.aux_hidden_state_layer_ids self.norm = Qwen3RMSNorm( config.transformer_layer_config.hidden_size, @@ -140,12 +125,17 @@ def from_training_args( GreedyTokenProposalConfig, ) + target_layer_ids = resolve_target_layer_ids( + kwargs.get("target_layer_ids"), + kwargs["verifier_name_or_path"], + ) + config = DFlashSpeculatorConfig( transformer_layer_config=verifier_config, draft_vocab_size=kwargs["draft_vocab_size"], block_size=kwargs.get("block_size", 8), max_anchors=kwargs.get("max_anchors", 3072), - aux_hidden_state_layer_ids=kwargs.get("target_layer_ids"), + aux_hidden_state_layer_ids=target_layer_ids, mask_token_id=kwargs.get("mask_token_id"), speculators_config=SpeculatorsConfig( algorithm="dflash", diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index 16a67c54c..ff7142844 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -15,6 +15,7 @@ extend_mask_for_draft_tokens, ) from speculators.models.eagle3.model_definitions import model_classes +from speculators.models.utils import resolve_target_layer_ids from speculators.proposals.greedy import GreedyTokenProposalConfig from speculators.utils.loading import load_model_layers @@ -427,19 +428,10 @@ def from_training_args( Returns: Initialized Eagle3DraftModel """ - target_layer_ids = kwargs.get("target_layer_ids") - if target_layer_ids is None: - unmodified_verifier_config = AutoConfig.from_pretrained( - kwargs["verifier_name_or_path"] - ) - num_target_layers = unmodified_verifier_config.num_hidden_layers - target_layer_ids = [2, num_target_layers // 2, num_target_layers - 3] - warnings.warn( - "--target-layer-ids is not explicitly set. Setting target " - f"layers to {target_layer_ids}. If custom target layers were used " - "when launching vllm datagen, please set them explicitly.", - stacklevel=2, - ) + target_layer_ids = resolve_target_layer_ids( + kwargs.get("target_layer_ids"), + kwargs["verifier_name_or_path"], + ) config = Eagle3SpeculatorConfig( transformer_layer_config=verifier_config, diff --git a/src/speculators/models/utils.py b/src/speculators/models/utils.py new file mode 100644 index 000000000..0ebb8abca --- /dev/null +++ b/src/speculators/models/utils.py @@ -0,0 +1,33 @@ +import warnings + +from transformers import AutoConfig, PretrainedConfig + + +def get_verifier_config(verifier_name_or_path: str) -> PretrainedConfig: + verifier_config = AutoConfig.from_pretrained(verifier_name_or_path) + if hasattr(verifier_config, "text_config"): + verifier_config = verifier_config.text_config + return verifier_config + + +DEFAULT_TARGET_LAYER_IDS_WARNING = ( + "--target-layer-ids is not explicitly set. Setting target " + "layers to {target_layer_ids}. If custom target layers were used " + "when launching vllm datagen, please set them explicitly." +) + + +def resolve_target_layer_ids( + target_layer_ids: list[int] | None, + verifier_name_or_path: str, +) -> list[int]: + if target_layer_ids is not None: + return target_layer_ids + + num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers + target_layer_ids = [2, num_layers // 2, num_layers - 3] + warnings.warn( + DEFAULT_TARGET_LAYER_IDS_WARNING.format(target_layer_ids=target_layer_ids), + stacklevel=3, + ) + return target_layer_ids