Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions src/speculators/models/dflash/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import torch
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from transformers import (
AutoConfig,
PretrainedConfig,
)
from transformers import PretrainedConfig
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
Expand All @@ -21,6 +18,7 @@
get_base_indices_for_anchored_blocks,
select_anchors,
)
from speculators.models.utils import resolve_target_layer_ids


@SpeculatorModel.register("dflash")
Expand Down Expand Up @@ -63,25 +61,12 @@ def __init__(
]
)

verifier_name_or_path = config.speculators_config.verifier.name_or_path
if verifier_name_or_path is None:
raise ValueError("Verifier name_or_path must be set in speculators_config")
verifier_config = AutoConfig.from_pretrained(verifier_name_or_path)
if hasattr(verifier_config, "text_config"):
verifier_config = verifier_config.text_config
num_verifier_layers = verifier_config.num_hidden_layers

if config.aux_hidden_state_layer_ids is not None:
self.target_layer_ids = config.aux_hidden_state_layer_ids
else:
# Eagle3 defaults; write back so they are persisted in config.json
self.target_layer_ids = [
2,
num_verifier_layers // 2,
num_verifier_layers - 3,
]
# set defaults to config if not provided - vLLM will fail otherwise
config.aux_hidden_state_layer_ids = self.target_layer_ids
if config.aux_hidden_state_layer_ids is None:
raise ValueError(
"aux_hidden_state_layer_ids must be set in DFlashSpeculatorConfig. "
"Use DFlashDraftModel.from_training_args() to resolve defaults."
)
self.target_layer_ids = config.aux_hidden_state_layer_ids
Comment thread
fynnsu marked this conversation as resolved.

self.norm = Qwen3RMSNorm(
config.transformer_layer_config.hidden_size,
Expand Down Expand Up @@ -140,12 +125,17 @@ def from_training_args(
GreedyTokenProposalConfig,
)

target_layer_ids = resolve_target_layer_ids(
kwargs.get("target_layer_ids"),
kwargs["verifier_name_or_path"],
)

config = DFlashSpeculatorConfig(
transformer_layer_config=verifier_config,
draft_vocab_size=kwargs["draft_vocab_size"],
block_size=kwargs.get("block_size", 8),
max_anchors=kwargs.get("max_anchors", 3072),
aux_hidden_state_layer_ids=kwargs.get("target_layer_ids"),
aux_hidden_state_layer_ids=target_layer_ids,
mask_token_id=kwargs.get("mask_token_id"),
speculators_config=SpeculatorsConfig(
algorithm="dflash",
Expand Down
18 changes: 5 additions & 13 deletions src/speculators/models/eagle3/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
extend_mask_for_draft_tokens,
)
from speculators.models.eagle3.model_definitions import model_classes
from speculators.models.utils import resolve_target_layer_ids
from speculators.proposals.greedy import GreedyTokenProposalConfig
from speculators.utils.loading import load_model_layers

Expand Down Expand Up @@ -427,19 +428,10 @@ def from_training_args(
Returns:
Initialized Eagle3DraftModel
"""
target_layer_ids = kwargs.get("target_layer_ids")
if target_layer_ids is None:
unmodified_verifier_config = AutoConfig.from_pretrained(
kwargs["verifier_name_or_path"]
)
num_target_layers = unmodified_verifier_config.num_hidden_layers
target_layer_ids = [2, num_target_layers // 2, num_target_layers - 3]
warnings.warn(
"--target-layer-ids is not explicitly set. Setting target "
f"layers to {target_layer_ids}. If custom target layers were used "
"when launching vllm datagen, please set them explicitly.",
stacklevel=2,
)
target_layer_ids = resolve_target_layer_ids(
kwargs.get("target_layer_ids"),
kwargs["verifier_name_or_path"],
)

config = Eagle3SpeculatorConfig(
transformer_layer_config=verifier_config,
Expand Down
33 changes: 33 additions & 0 deletions src/speculators/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import warnings

from transformers import AutoConfig, PretrainedConfig


def get_verifier_config(verifier_name_or_path: str) -> PretrainedConfig:
    """Load a verifier config, preferring its nested ``text_config``.

    Args:
        verifier_name_or_path: Identifier handed unchanged to
            ``AutoConfig.from_pretrained``.

    Returns:
        The ``text_config`` attribute of the loaded config when the loaded
        config has one; otherwise the loaded config itself.
    """
    loaded = AutoConfig.from_pretrained(verifier_name_or_path)
    # When the loaded config exposes a ``text_config`` attribute, hand back
    # that nested object instead of the outer config.
    return loaded.text_config if hasattr(loaded, "text_config") else loaded


# Warning template used by resolve_target_layer_ids when no explicit layer ids
# are supplied; the ``{target_layer_ids}`` placeholder is filled in with
# ``str.format`` before the warning is emitted.
DEFAULT_TARGET_LAYER_IDS_WARNING = (
"--target-layer-ids is not explicitly set. Setting target "
"layers to {target_layer_ids}. If custom target layers were used "
"when launching vllm datagen, please set them explicitly."
)


def resolve_target_layer_ids(
target_layer_ids: list[int] | None,
verifier_name_or_path: str,
) -> list[int]:
if target_layer_ids is not None:
return target_layer_ids

num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers
target_layer_ids = [2, num_layers // 2, num_layers - 3]
warnings.warn(
DEFAULT_TARGET_LAYER_IDS_WARNING.format(target_layer_ids=target_layer_ids),
stacklevel=3,
)
return target_layer_ids
Loading