diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py index e3e3e682a..7d7899308 100644 --- a/scripts/launch_vllm.py +++ b/scripts/launch_vllm.py @@ -2,6 +2,7 @@ import json import os import sys +import warnings def parse_args(): @@ -9,7 +10,7 @@ def parse_args(): description="Launch vLLM for hidden states extraction", usage=( "launch_vllm.py [-h] MODEL [--hidden-states-path HIDDEN_STATES_PATH] " - "[--layers LAYERS [LAYERS ...]] -- *VLLM_ARGS" + "[--target-layer-ids TARGET_LAYER_IDS [TARGET_LAYER_IDS ...]] -- *VLLM_ARGS" ), ) parser.add_argument( @@ -22,12 +23,13 @@ def parse_args(): help="The directory to save hidden states to. Default '/tmp/hidden_states'.", ) parser.add_argument( - "--layers", + "--target-layer-ids", type=int, nargs="+", help=( - "(Optional) A (space separated) list of integer layer ids. Default layers " - "[2, num_hidden_layers // 2, num_hidden_layers - 3, num_hidden_layers]." + "(Optional) A (space separated) list of integer layer ids. Defaults to " + "[2, num_hidden_layers // 2, num_hidden_layers - 3, num_hidden_layers]. " + "Note: if set, you must also pass the same value into the training process" ), ) parser.add_argument( @@ -43,10 +45,15 @@ def main(): if "--" in vllm_args: vllm_args.remove("--") - if args.layers: - layers = args.layers + if args.target_layer_ids: + target_layer_ids = args.target_layer_ids + warnings.warn( + f"Using custom target layer ids {args.target_layer_ids}. These " + "must also be explicitly passed into the training script.", + stacklevel=2, + ) else: - # Import here so that it isn't required if layers passed explicitly + # Import here so that it isn't required if target_layer_ids passed explicitly from transformers import AutoConfig # noqa: PLC0415 config = AutoConfig.from_pretrained(args.model) @@ -54,13 +61,18 @@ def main(): config = config.text_config num_hidden_layers = config.num_hidden_layers - layers = [2, num_hidden_layers // 2, num_hidden_layers - 3, num_hidden_layers] + target_layer_ids = [ + 2, + num_hidden_layers // 2, + num_hidden_layers - 3, + num_hidden_layers, + ] speculative_config = { "method": "extract_hidden_states", "num_speculative_tokens": 1, "draft_model_config": { - "hf_config": {"eagle_aux_hidden_state_layer_ids": layers} + "hf_config": {"eagle_aux_hidden_state_layer_ids": target_layer_ids} }, } kv_transfer_config = { diff --git a/scripts/train.py b/scripts/train.py index 1490bad40..7c7bf3189 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -437,7 +437,16 @@ def parse_args(): help="Architecture for draft decoder layers. Defaults to 'llama'. " "Note: only 'llama' is currently supported in vLLM for inference.", ) - + parser.add_argument( + "--target-layer-ids", + type=int, + nargs="+", + help=( + "(Optional) A (space separated) list of integer layer ids. Defaults to" + "[2, num_hidden_layers // 2, num_hidden_layers - 3, num_hidden_layers]. " + "Note: must be set explicitly if custom values were used to launch vllm" + ), + ) parser.add_argument( "--token-freq-path", type=str, diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index 0f07e0da1..15077e6fd 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -548,12 +548,27 @@ def from_training_args( Returns: Initialized Eagle3DraftModel """ + target_layer_ids = kwargs.get("target_layer_ids") + if target_layer_ids is None: + unmodified_verifier_config = AutoConfig.from_pretrained( + kwargs["verifier_name_or_path"] + ) + num_target_layers = unmodified_verifier_config.num_hidden_layers + target_layer_ids = [2, num_target_layers // 2, num_target_layers - 3] + warnings.warn( + "--target-layer-ids is not explicitly set. Setting target " + f"layers to {target_layer_ids}. If custom target layers were used " + "when launching vllm datagen, please set them explicitly.", + stacklevel=2, + ) + config = Eagle3SpeculatorConfig( transformer_layer_config=verifier_config, draft_vocab_size=kwargs["draft_vocab_size"], norm_before_residual=kwargs["norm_before_residual"], norm_before_fc=kwargs.get("norm_before_fc", False), embed_requires_grad=kwargs.get("embed_requires_grad", False), + eagle_aux_hidden_state_layer_ids=target_layer_ids, speculators_config=SpeculatorsConfig( algorithm="eagle3", proposal_methods=[ diff --git a/tests/unit/train/test_setup_model.py b/tests/unit/train/test_setup_model.py index d82f2ae9e..eeb3982d0 100644 --- a/tests/unit/train/test_setup_model.py +++ b/tests/unit/train/test_setup_model.py @@ -761,7 +761,7 @@ def test_from_training_args_loads_vocab_mappings(vocab_mappings): draft_vocab_size=DRAFT_VOCAB_SIZE, norm_before_residual=False, ttt_steps=1, - verifier_name_or_path="dummy", + verifier_name_or_path="nm-testing/tinysmokellama-3.2", ) assert model.t2d is not None, "t2d is None after from_training_args"