diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py
index 4cd927d82..fb837a3b3 100644
--- a/scripts/data_generation_offline.py
+++ b/scripts/data_generation_offline.py
@@ -416,7 +416,20 @@ def main():
logger.info("EAGLE Offline Data Generation")
- dataset = load_from_disk(args.preprocessed_data)
+ # `prepare_data.py` persists dataset formatting metadata. If we load it as-is,
+ # only torch-formatted training columns may be exposed, hiding multimodal
+ # metadata columns like `messages_json`/`mm_file`. That would prevent
+ # `build_client_item` from forwarding `multi_modal_data` to vLLM and can
+ # produce degenerate image placeholder hidden states.
+ dataset = load_from_disk(args.preprocessed_data).with_format(None)
+
+ columns = dataset.column_names
+ logger.info("Loaded dataset columns: %s", columns)
+ if "messages_json" not in columns and "messages" not in columns:
+ logger.warning(
+ "Dataset has no `messages_json`/`messages` column. "
+ "If this dataset is multimodal, vLLM requests may miss multi_modal_data."
+ )
try:
asyncio.run(generate_and_save_hidden_states(args, dataset))
diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py
index 0d310c240..0f298c308 100644
--- a/scripts/launch_vllm.py
+++ b/scripts/launch_vllm.py
@@ -5,6 +5,113 @@
import warnings
+def unwrap_verifier_configs(config):
+ """Return multimodal/container config and the text backbone config."""
+ multimodal_config = getattr(config, "thinker_config", config)
+ text_config = multimodal_config
+ if hasattr(text_config, "text_config"):
+ text_config = text_config.text_config
+ return multimodal_config, text_config
+
+
+def get_deepstack_visual_indexes(multimodal_config) -> list[int]:
+ """Get DeepStack layer indexes when present on the verifier."""
+ vision_config = getattr(multimodal_config, "vision_config", None)
+ deepstack_layers = getattr(vision_config, "deepstack_visual_indexes", None)
+ if deepstack_layers is None:
+ deepstack_layers = getattr(multimodal_config, "deepstack_visual_indexes", None)
+ return list(deepstack_layers or [])
+
+
+def deduplicate_layer_ids(layer_ids: list[int]) -> list[int]:
+ """Deduplicate layer ids while preserving user-specified order."""
+ seen: set[int] = set()
+ result: list[int] = []
+ for layer_id in layer_ids:
+ if layer_id not in seen:
+ seen.add(layer_id)
+ result.append(layer_id)
+ return result
+
+
+def validate_layer_ids(
+ layer_ids: list[int], num_hidden_layers: int, option_name: str
+) -> list[int]:
+ """Validate vLLM post-layer ids and preserve the effective order."""
+ validated = deduplicate_layer_ids(layer_ids)
+ invalid = [
+ layer_id
+ for layer_id in validated
+ if layer_id < 0 or layer_id > num_hidden_layers
+ ]
+ if invalid:
+ raise ValueError(
+ f"{option_name} contains invalid layer ids {invalid}. "
+ f"Expected ids in [0, {num_hidden_layers}], where 0 is the "
+ "embedding output and num_hidden_layers is the final decoder output."
+ )
+ return validated
+
+
+def get_default_target_layer_ids(multimodal_config, num_hidden_layers: int) -> list[int]:
+ """Return default auxiliary layer ids used by training."""
+ deepstack_layers = set(get_deepstack_visual_indexes(multimodal_config))
+ candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+ return [
+ layer_id - 1 if layer_id in deepstack_layers else layer_id
+ for layer_id in candidate_layer_ids
+ ]
+
+
+def resolve_layer_ids(args, multimodal_config, num_hidden_layers: int):
+ """Resolve training layer ids and exact vLLM extraction layer ids.
+
+ vLLM's extract_hidden_states path stores exactly the configured
+ eagle_aux_hidden_state_layer_ids. DFlash data loading consumes all but the
+ last stored slice as training auxiliary hidden states and treats the last
+ slice as the verifier/final reference state. Therefore, when
+ --include-last-layer is enabled, the final layer is forced to the end of the
+ extraction list but is not reported as a training target layer id.
+ """
+ if args.target_layer_ids:
+ cli_layer_ids = validate_layer_ids(
+ list(args.target_layer_ids), num_hidden_layers, "--target-layer-ids"
+ )
+ # If users pass the final layer explicitly, keep DFlash layout correct by
+ # moving it to the last extracted slice. This avoids a costly runtime
+ # tensor reorder and preserves the loader's [:, :-1] / [:, -1] split.
+ target_layer_ids = [
+ layer_id for layer_id in cli_layer_ids if layer_id != num_hidden_layers
+ ]
+ if not target_layer_ids and args.include_last_layer:
+ raise ValueError(
+ "--target-layer-ids must contain at least one non-final auxiliary "
+ "layer when --include-last-layer is enabled."
+ )
+ source = "custom"
+ else:
+ target_layer_ids = validate_layer_ids(
+ get_default_target_layer_ids(multimodal_config, num_hidden_layers),
+ num_hidden_layers,
+ "default target layer ids",
+ )
+ source = "default"
+
+ extraction_layer_ids = list(target_layer_ids)
+ if args.include_last_layer:
+ extraction_layer_ids.append(num_hidden_layers)
+ extraction_layer_ids = validate_layer_ids(
+ extraction_layer_ids, num_hidden_layers, "resolved extraction layer ids"
+ )
+ if not extraction_layer_ids:
+ raise ValueError(
+ "At least one vLLM extraction layer id must be selected. Pass one or "
+ "more --target-layer-ids values or keep --include-last-layer enabled."
+ )
+
+ return target_layer_ids, extraction_layer_ids, source
+
+
def parse_args():
parser = argparse.ArgumentParser(
description="Launch vLLM for hidden states extraction",
@@ -24,12 +131,17 @@ def parse_args():
)
parser.add_argument(
"--target-layer-ids",
+ "--layer-ids",
+ dest="target_layer_ids",
type=int,
nargs="+",
help=(
- "(Optional) A (space separated) list of integer layer ids. Defaults to "
- "[2, num_hidden_layers // 2, num_hidden_layers - 3]. "
- "Note: if set, you must also pass the same value into the training process"
+ "Auxiliary post-layer ids to extract for training. Alias: --layer-ids. "
+ "vLLM layer ids are in [0, num_hidden_layers], where 0 is the "
+ "embedding output and num_hidden_layers is the final decoder output. "
+ "When --include-last-layer is enabled, num_hidden_layers is appended "
+ "to the vLLM extraction ids and should not be passed to training. "
+ "Defaults to [2, num_hidden_layers // 2, num_hidden_layers - 3]."
),
)
parser.add_argument(
@@ -37,8 +149,9 @@ def parse_args():
action=argparse.BooleanOptionalAction,
default=True,
help=(
- "Append the last layer (num_hidden_layers) to "
- "target_layer_ids for verifier hidden states extraction. Default: True"
+ "For DFlash models, append the last layer (num_hidden_layers) to the "
+ "vLLM extraction ids as the final verifier/reference slice. "
+ "Default: True"
),
)
parser.add_argument(
@@ -56,34 +169,58 @@ def main():
from transformers import AutoConfig # noqa: PLC0415
- config = AutoConfig.from_pretrained(args.model)
- if hasattr(config, "text_config"):
- config = config.text_config
+ raw_config = AutoConfig.from_pretrained(args.model)
+ multimodal_config, config = unwrap_verifier_configs(raw_config)
num_hidden_layers = config.num_hidden_layers
if args.target_layer_ids:
- target_layer_ids = args.target_layer_ids
- if args.include_last_layer and num_hidden_layers not in target_layer_ids:
- target_layer_ids.append(num_hidden_layers)
+ training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
+ args, multimodal_config, num_hidden_layers
+ )
warnings.warn(
- f"Using custom target layer ids {target_layer_ids}. These "
- "must also be explicitly passed into the training script.",
+ "Using custom target layer ids. Pass "
+ f"{training_target_layer_ids} to the training script; vLLM will "
+ f"extract {extraction_layer_ids}.",
stacklevel=2,
)
else:
- target_layer_ids = [
- 2,
- num_hidden_layers // 2,
- num_hidden_layers - 3,
- num_hidden_layers,
- ]
+ training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
+ args, multimodal_config, num_hidden_layers
+ )
+
+ print(
+ "Layer ids: "
+ f"source={layer_id_source}, training_target_layer_ids="
+ f"{training_target_layer_ids}, extraction_layer_ids={extraction_layer_ids}"
+ )
+
+ # Build overrides for ExtractHiddenStatesConfig.
+ # For nested multimodal configs, promote text-backbone fields to top level
+ # so vLLM can resolve a valid text config.
+ hf_config_overrides: dict = {"eagle_aux_hidden_state_layer_ids": extraction_layer_ids}
+ if config is not raw_config:
+ # Promote nested text-backbone fields; drop conflicting wrapper fields.
+ _text_cfg_dict = config.to_dict()
+ for _k in ("architectures", "model_type", "auto_map", "torch_dtype"):
+ _text_cfg_dict.pop(_k, None)
+
+ # Clear nested selector attrs so HF falls back to promoted top-level
+ # text fields instead of stale dict payloads.
+ for _nested_text_attr in (
+ "text_config",
+ "text_encoder",
+ "decoder",
+ "generator",
+ ):
+ _text_cfg_dict[_nested_text_attr] = None
+
+ # kwargs override model_dict, exposing text attrs at top level.
+ hf_config_overrides = {**_text_cfg_dict, **hf_config_overrides}
speculative_config = {
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
- "draft_model_config": {
- "hf_config": {"eagle_aux_hidden_state_layer_ids": target_layer_ids}
- },
+ "draft_model_config": {"hf_config": hf_config_overrides},
}
kv_transfer_config = {
"kv_connector": "ExampleHiddenStatesConnector",
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
index 5c4624103..6eb433cba 100644
--- a/scripts/prepare_data.py
+++ b/scripts/prepare_data.py
@@ -190,6 +190,7 @@ def main():
turn_dropout=args.turn_dropout,
minimum_valid_tokens=args.minimum_valid_tokens,
trust_remote_code=args.trust_remote_code,
+ multimodal_output_dir=output,
)
log.info("Done preparing data")
diff --git a/scripts/train.py b/scripts/train.py
index a1e4e6de5..573e7a162 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,4 +1,5 @@
import argparse
+import json
import logging
import random
import warnings
@@ -65,6 +66,36 @@ def set_seed(seed: int, deterministic: bool = False):
torch.backends.cudnn.benchmark = False
+def unwrap_verifier_text_config(verifier_config: PretrainedConfig) -> PretrainedConfig:
+ """Unwrap multimodal verifier configs to their text backbone config."""
+ if hasattr(verifier_config, "thinker_config"):
+ verifier_config = verifier_config.thinker_config
+ if hasattr(verifier_config, "text_config"):
+ verifier_config = verifier_config.text_config
+ return verifier_config
+
+
+def get_default_draft_intermediate_size(verifier_config: PretrainedConfig) -> int:
+ """Infer a dense drafter FFN width from dense or MoE verifier configs."""
+ moe_ffn = getattr(verifier_config, "moe_intermediate_size", None)
+ if moe_ffn is not None:
+ experts_per_tok = int(getattr(verifier_config, "num_experts_per_tok", 1))
+ active = int(moe_ffn) * max(experts_per_tok, 1)
+ shared_ffn = getattr(verifier_config, "shared_expert_intermediate_size", None)
+ if shared_ffn is not None:
+ active += int(shared_ffn)
+ return active
+
+ dense_ffn = getattr(verifier_config, "intermediate_size", None)
+ if dense_ffn is not None:
+ return int(dense_ffn)
+
+ raise AttributeError(
+ f"Cannot infer draft intermediate_size from {type(verifier_config).__name__}; "
+ "pass --draft-intermediate-size explicitly."
+ )
+
+
def setup_dataloader(
dataset: BaseDataset,
world_size: int,
@@ -112,8 +143,18 @@ def create_transformer_layer_config( # noqa: C901
num_layers: int,
draft_arch: str,
hidden_act: str | None,
+ # --- DFlash Hybrid Attention sliding window support ---
sliding_window: int,
sliding_window_indices: list[int],
+ # --- Qwen3.6 Hybrid Attention support ---
+ draft_intermediate_size: int | None = None,
+ draft_num_attention_heads: int | None = None,
+ draft_num_key_value_heads: int | None = None,
+ draft_head_dim: int | None = None,
+ draft_rope_scaling: dict | None = None,
+ draft_rope_theta: float | None = None,
+ draft_max_position_embeddings: int | None = None,
+ mrope_full_head_hack: bool = True,
) -> PretrainedConfig:
if draft_arch not in DRAFT_ARCH_CONFIGS:
raise ValueError(
@@ -130,11 +171,9 @@ def create_transformer_layer_config( # noqa: C901
)
config_class = DRAFT_ARCH_CONFIGS[draft_arch]
- verifier_config = AutoConfig.from_pretrained(verifier_name_or_path)
-
- # For multimodal models (Qwen3VL, etc.), extract text_config
- if hasattr(verifier_config, "text_config"):
- verifier_config = verifier_config.text_config
+ verifier_config = unwrap_verifier_text_config(
+ AutoConfig.from_pretrained(verifier_name_or_path, trust_remote_code=True)
+ )
hidden_act = (
hidden_act
@@ -147,56 +186,142 @@ def create_transformer_layer_config( # noqa: C901
"nor 'hidden_activation'"
)
- head_dim = getattr(verifier_config, "head_dim", None)
- num_attention_heads = verifier_config.num_attention_heads
- num_key_value_heads = verifier_config.num_key_value_heads
+ if draft_intermediate_size is None:
+ draft_intermediate_size = get_default_draft_intermediate_size(verifier_config)
- if (
- head_dim
- and verifier_config.hidden_size % num_attention_heads != 0
- and verifier_config.hidden_size % head_dim == 0
- ):
- num_attention_heads = verifier_config.hidden_size // head_dim
- if num_attention_heads % num_key_value_heads != 0:
- num_key_value_heads = num_attention_heads
+ n_heads = draft_num_attention_heads or verifier_config.num_attention_heads
+ n_kv = draft_num_key_value_heads or verifier_config.num_key_value_heads
+ hd = draft_head_dim or getattr(verifier_config, "head_dim", None)
+ hidden_size = verifier_config.hidden_size
+ resolved_head_dim = hd or hidden_size // n_heads
+ if n_heads % n_kv != 0:
+ raise ValueError(
+ f"Invalid GQA ratio: num_attention_heads({n_heads}) must be divisible "
+ f"by num_key_value_heads({n_kv})."
+ )
+ if resolved_head_dim <= 0:
+ raise ValueError(f"Invalid head_dim({resolved_head_dim}); must be positive.")
+
+ rope_kwargs: dict = {}
+ verifier_rope_theta = getattr(verifier_config, "rope_theta", None)
+ if verifier_rope_theta is None:
+ for src_name in ("rope_parameters", "rope_scaling"):
+ src = getattr(verifier_config, src_name, None)
+ if isinstance(src, dict) and src.get("rope_theta") is not None:
+ verifier_rope_theta = src["rope_theta"]
+ logger.info(
+ "Drafter rope_theta recovered from verifier "
+ f"{src_name}.rope_theta = {verifier_rope_theta}."
+ )
+ break
+ if verifier_rope_theta is not None:
+ rope_kwargs["rope_theta"] = verifier_rope_theta
+
+ verifier_rope_scaling = getattr(verifier_config, "rope_scaling", None)
+ verifier_rope_parameters = getattr(verifier_config, "rope_parameters", None)
+ if verifier_rope_scaling is not None:
+ rope_kwargs["rope_scaling"] = dict(verifier_rope_scaling)
+ elif isinstance(verifier_rope_parameters, dict):
+ rope_kwargs["rope_scaling"] = dict(verifier_rope_parameters)
+
+ if draft_rope_scaling is not None:
+ cli_rope_scaling = dict(draft_rope_scaling)
+ cli_rope_theta = cli_rope_scaling.pop("rope_theta", None)
+ if cli_rope_theta is not None:
+ rope_kwargs["rope_theta"] = cli_rope_theta
+ rope_kwargs["rope_scaling"] = cli_rope_scaling
+ logger.info(f"Drafter rope_scaling overridden via CLI: {cli_rope_scaling}")
+
+ if draft_rope_theta is not None:
+ rope_kwargs["rope_theta"] = float(draft_rope_theta)
+ logger.info(
+ "Drafter rope_theta overridden via --draft-rope-theta: "
+ f"{rope_kwargs['rope_theta']}"
+ )
+
+ rope_scaling_dict = rope_kwargs.get("rope_scaling")
+ if isinstance(rope_scaling_dict, dict) and "mrope_section" in rope_scaling_dict:
+ if draft_rope_theta is not None:
+ rope_scaling_dict["rope_theta"] = float(draft_rope_theta)
+ elif "rope_theta" in rope_kwargs and "rope_theta" not in rope_scaling_dict:
+ rope_scaling_dict["rope_theta"] = rope_kwargs["rope_theta"]
+
+ inherited_partial = float(rope_scaling_dict.get("partial_rotary_factor", 1.0))
+ if mrope_full_head_hack and inherited_partial < 1.0:
+ old_section = list(rope_scaling_dict["mrope_section"])
+ inv = 1.0 / inherited_partial
+ if abs(inv - round(inv)) > 1e-6:
+ raise ValueError(
+ "mrope_full_head_hack cannot rescale mrope_section because "
+ f"1/partial_rotary_factor={inv} is not an integer."
+ )
+ scale = int(round(inv))
+ new_section = [int(x) * scale for x in old_section]
+ if 2 * sum(new_section) != resolved_head_dim:
+ raise ValueError(
+ "mrope_full_head_hack rescaling produced inconsistent "
+ f"mrope_section {new_section}: 2*sum={2 * sum(new_section)} "
+ f"but head_dim={resolved_head_dim}."
+ )
+ rope_scaling_dict["mrope_section"] = new_section
+ rope_scaling_dict["partial_rotary_factor"] = 1.0
+ logger.warning(
+ "MRoPE full-head hack applied: partial_rotary_factor "
+ f"{inherited_partial} -> 1.0, mrope_section {old_section} -> "
+ f"{new_section}."
+ )
+ elif not mrope_full_head_hack and inherited_partial < 1.0:
+ logger.warning(
+ "mrope_full_head_hack=False with partial_rotary_factor="
+ f"{inherited_partial} < 1.0 can cause HF trainer / vLLM "
+ "partial-rotation mismatch."
+ )
+
+ max_pos = (
+ int(draft_max_position_embeddings)
+ if draft_max_position_embeddings is not None
+ else verifier_config.max_position_embeddings
+ )
+ # --- DFlash sliding window layer_types ---
if sliding_window_indices and (
min(sliding_window_indices) < 0 or max(sliding_window_indices) >= num_layers
):
raise ValueError(
- "Sliding window indices must be validate draft layer ids "
+ "Sliding window indices must be valid draft layer ids "
"in range [0, num_layers)."
)
layer_types = [
- "sliding_attention" if i in sliding_window_indices else "full_attention"
+ "sliding_attention" if i in (sliding_window_indices or []) else "full_attention"
for i in range(num_layers)
]
config = config_class(
vocab_size=verifier_config.vocab_size,
- hidden_size=verifier_config.hidden_size,
- intermediate_size=verifier_config.intermediate_size,
+ hidden_size=hidden_size,
+ intermediate_size=draft_intermediate_size,
num_hidden_layers=num_layers,
- num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
+ num_attention_heads=n_heads,
+ num_key_value_heads=n_kv,
hidden_act=hidden_act,
- max_position_embeddings=verifier_config.max_position_embeddings,
+ max_position_embeddings=max_pos,
initializer_range=verifier_config.initializer_range,
rms_norm_eps=verifier_config.rms_norm_eps,
- head_dim=head_dim,
+ head_dim=resolved_head_dim,
tie_word_embeddings=False,
sliding_window=sliding_window,
layer_types=layer_types,
)
- # New rope parameters definition introduced in transformers 5.0
- if version.parse(transformers.__version__) >= version.parse("5.0.0"):
- if hasattr(verifier_config, "rope_parameters"):
- config.rope_parameters = deepcopy(verifier_config.rope_parameters)
- else:
- if hasattr(verifier_config, "rope_scaling"):
- config.rope_scaling = deepcopy(verifier_config.rope_scaling)
- config.rope_theta = getattr(verifier_config, "rope_theta", 10000.0)
+ # RoPE: use transformers >= 5.0 rope_parameters path
+ if hasattr(verifier_config, "rope_parameters"):
+ config.rope_parameters = deepcopy(verifier_config.rope_parameters)
+ # Apply CLI overrides into rope_parameters
+ if rope_kwargs.get("rope_scaling"):
+ config.rope_parameters = deepcopy(rope_kwargs["rope_scaling"])
+ if rope_kwargs.get("rope_theta") is not None:
+ if isinstance(config.rope_parameters, dict):
+ config.rope_parameters["rope_theta"] = rope_kwargs["rope_theta"]
return config
@@ -267,9 +392,9 @@ def parse_vocab_mappings(args: argparse.Namespace):
"None. Using full verifier vocab"
)
# When vocab mapping is not provided, use the full verifier vocab
- verifier_config = AutoConfig.from_pretrained(args.verifier_name_or_path)
- if hasattr(verifier_config, "text_config"):
- verifier_config = verifier_config.text_config
+ verifier_config = unwrap_verifier_text_config(
+ AutoConfig.from_pretrained(args.verifier_name_or_path, trust_remote_code=True)
+ )
return None, None, verifier_config.vocab_size
@@ -305,6 +430,14 @@ def main(args: argparse.Namespace):
num_layers=args.num_layers,
draft_arch=args.draft_arch,
hidden_act=args.draft_hidden_act,
+ draft_intermediate_size=args.draft_intermediate_size,
+ draft_num_attention_heads=args.draft_num_attention_heads,
+ draft_num_key_value_heads=args.draft_num_key_value_heads,
+ draft_head_dim=args.draft_head_dim,
+ draft_rope_scaling=args.draft_rope_scaling,
+ draft_rope_theta=args.draft_rope_theta,
+ draft_max_position_embeddings=args.draft_max_position_embeddings,
+ mrope_full_head_hack=args.draft_mrope_full_head_hack,
sliding_window=args.sliding_window,
sliding_window_indices=args.sliding_window_indices,
)
@@ -371,6 +504,7 @@ def main(args: argparse.Namespace):
transform=noise_transform,
split_ratio=0.9,
model=args.verifier_name_or_path,
+ verifier_name_or_path=args.verifier_name_or_path,
hidden_states_dtype=hidden_states_dtype,
request_timeout=args.request_timeout,
max_retries=args.max_retries,
@@ -384,6 +518,7 @@ def main(args: argparse.Namespace):
on_generate=args.on_generate,
split_ratio=-0.1,
model=args.verifier_name_or_path,
+ verifier_name_or_path=args.verifier_name_or_path,
hidden_states_dtype=hidden_states_dtype,
request_timeout=args.request_timeout,
max_retries=args.max_retries,
@@ -591,6 +726,58 @@ def parse_args():
"vLLM deployment. If another function is desired, set as a string or leave "
"as None to automatically fall back to the verifier's activation function.",
)
+ parser.add_argument(
+ "--draft-intermediate-size",
+ type=int,
+ default=None,
+ help="Override draft FFN intermediate size; useful for MoE verifiers.",
+ )
+ parser.add_argument(
+ "--draft-num-attention-heads",
+ type=int,
+ default=None,
+ help="Override the drafter's num_attention_heads.",
+ )
+ parser.add_argument(
+ "--draft-num-key-value-heads",
+ type=int,
+ default=None,
+ help="Override the drafter's num_key_value_heads.",
+ )
+ parser.add_argument(
+ "--draft-head-dim",
+ type=int,
+ default=None,
+ help="Override the drafter's per-head hidden dimension.",
+ )
+ parser.add_argument(
+ "--draft-rope-scaling",
+ type=lambda s: json.loads(s) if s else None,
+ default=None,
+ help="JSON RoPE scaling dict to apply to the drafter during training.",
+ )
+ parser.add_argument(
+ "--draft-rope-theta",
+ type=float,
+ default=None,
+ help="Override the drafter's RoPE frequency base.",
+ )
+ parser.add_argument(
+ "--draft-max-position-embeddings",
+ type=int,
+ default=None,
+ help="Override the drafter's max_position_embeddings.",
+ )
+ parser.add_argument(
+ "--draft-mrope-full-head-hack",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help=(
+ "For MRoPE configs with partial_rotary_factor < 1, rescale "
+ "mrope_section and set partial_rotary_factor=1.0 so HF training "
+ "and vLLM inference use equivalent full-head rotary semantics."
+ ),
+ )
parser.add_argument(
"--target-layer-ids",
type=int,
diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py
index 770724cbe..ba51259a3 100644
--- a/src/speculators/data_generation/preprocessing.py
+++ b/src/speculators/data_generation/preprocessing.py
@@ -1,17 +1,21 @@
import bisect
+import inspect
+import json
import random
import re
from collections.abc import Callable
from contextlib import nullcontext
from pathlib import Path
from re import Pattern
-from typing import cast
+from typing import Any, cast
import torch
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets, load_dataset
from packaging.version import Version
+from safetensors.torch import save_file
from transformers import (
+ AutoConfig,
AutoProcessor,
BatchEncoding,
BatchFeature,
@@ -35,6 +39,18 @@
ProcessorLike = PreTrainedTokenizerBase | ProcessorMixin
+MULTIMODAL_SIDECAR_DIR = "multimodal"
+MULTIMODAL_MEDIA_TYPES = {"image", "video", "audio"}
+MULTIMODAL_ENCODER_KEYS = (
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ "second_per_grids",
+ "input_features",
+ "feature_attention_mask",
+ "audio_feature_lengths",
+)
def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: int = 0):
@@ -73,6 +89,196 @@ def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: in
log.info(highlighted)
+def _normalize_media_value(value: Any) -> Any:
+ if isinstance(value, Path):
+ return str(value)
+ filename = getattr(value, "filename", None)
+ if filename:
+ return str(filename)
+ return value
+
+
+def _to_json_compatible(value: Any) -> Any:
+ if isinstance(value, dict):
+ return {k: _to_json_compatible(v) for k, v in value.items()}
+ if isinstance(value, list):
+ return [_to_json_compatible(v) for v in value]
+ if isinstance(value, tuple):
+ return [_to_json_compatible(v) for v in value]
+ if isinstance(value, Path):
+ return str(value)
+ filename = getattr(value, "filename", None)
+ if filename:
+ return str(filename)
+ return value
+
+
+def _normalize_content_segment(segment: Any) -> dict[str, Any]:
+ if isinstance(segment, str):
+ return {"type": "text", "text": segment}
+ if not isinstance(segment, dict):
+ return {"type": "text", "text": str(segment)}
+
+ seg_type = str(segment.get("type", "text"))
+ normalized = {"type": seg_type}
+ if seg_type == "text":
+ normalized["text"] = (
+ segment.get("text")
+ or segment.get("value")
+ or segment.get("content")
+ or ""
+ )
+ elif seg_type in MULTIMODAL_MEDIA_TYPES:
+ normalized[seg_type] = _normalize_media_value(
+ segment.get(seg_type)
+ or segment.get("value")
+ or segment.get("url")
+ or segment.get("path")
+ or segment.get("source")
+ )
+ else:
+ normalized["text"] = segment.get("text") or segment.get("value") or ""
+
+ for key, value in segment.items():
+ if key in normalized or key in {"content", "value", "url", "path", "source"}:
+ continue
+ normalized[key] = _to_json_compatible(value)
+ return normalized
+
+
+def _normalize_turn_content(content: Any) -> str | list[dict[str, Any]]:
+ if isinstance(content, list):
+ return [_normalize_content_segment(seg) for seg in content]
+ return content if isinstance(content, str) else str(content or "")
+
+
+def _has_multimodal_segments(content: Any) -> bool:
+ return isinstance(content, list) and any(
+ isinstance(seg, dict) and seg.get("type") in MULTIMODAL_MEDIA_TYPES
+ for seg in content
+ )
+
+
+def _is_multimodal_conversation(conv: list[dict[str, Any]]) -> bool:
+ return any(_has_multimodal_segments(turn.get("content")) for turn in conv)
+
+
+def _serialize_messages(messages: list[dict[str, Any]]) -> str:
+ return json.dumps(_to_json_compatible(messages), ensure_ascii=False)
+
+
+def _sanitize_sidecar_prefix(prefix: str) -> str:
+ return re.sub(r"[^a-zA-Z0-9_.-]+", "_", prefix).strip("_") or "dataset"
+
+
+def _build_sidecar_path(
+ multimodal_output_dir: str | Path,
+ sample_idx: int,
+ sidecar_prefix: str,
+) -> Path:
+ base_dir = Path(multimodal_output_dir) / MULTIMODAL_SIDECAR_DIR
+ base_dir.mkdir(parents=True, exist_ok=True)
+ safe_prefix = _sanitize_sidecar_prefix(sidecar_prefix)
+ return base_dir / f"{safe_prefix}_{sample_idx}.safetensors"
+
+
+def _maybe_strip_batch_dim(value: Any) -> torch.Tensor:
+ tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
+ tensor = tensor.detach().cpu().contiguous()
+ if tensor.ndim > 0 and tensor.shape[0] == 1:
+ tensor = tensor.squeeze(0)
+ return tensor
+
+
+def _save_multimodal_sidecar(
+ encoded: dict[str, Any],
+ multimodal_output_dir: str | Path,
+ sample_idx: int,
+ sidecar_prefix: str,
+) -> str:
+ sidecar_path = _build_sidecar_path(multimodal_output_dir, sample_idx, sidecar_prefix)
+ payload = {
+ key: _maybe_strip_batch_dim(encoded[key])
+ for key in MULTIMODAL_ENCODER_KEYS
+ if key in encoded
+ }
+ if not payload:
+ raise ValueError("Multimodal sample did not produce sidecar tensor fields")
+ save_file(payload, sidecar_path)
+ return str(sidecar_path.relative_to(Path(multimodal_output_dir)))
+
+
+def _build_multimodal_loss_mask(
+ input_ids: torch.Tensor,
+ base_loss_mask: torch.Tensor,
+ placeholder_token_ids: tuple[int, ...],
+) -> torch.Tensor:
+ loss_mask = base_loss_mask.to(dtype=torch.long).clone()
+ valid_ids = [tid for tid in placeholder_token_ids if tid >= 0]
+ if not valid_ids:
+ return loss_mask
+ placeholder_tensor = torch.as_tensor(
+ valid_ids, dtype=input_ids.dtype, device=input_ids.device
+ )
+ loss_mask.masked_fill_(torch.isin(input_ids, placeholder_tensor), 0)
+ return loss_mask
+
+
+def _mask_has_positive(mask: Any) -> bool:
+ mask_tensor = _maybe_strip_batch_dim(mask)
+ return bool(mask_tensor.numel() > 0 and torch.count_nonzero(mask_tensor).item() > 0)
+
+
+_PROCESSOR_KW_CACHE: dict[int, set[str]] = {}
+
+
+def _processor_kwargs(processor: Any) -> set[str]:
+ key = id(processor)
+ cached = _PROCESSOR_KW_CACHE.get(key)
+ if cached is not None:
+ return cached
+ try:
+ sig = inspect.signature(processor.apply_chat_template)
+ names = {
+ name
+ for name, param in sig.parameters.items()
+ if param.kind is not inspect.Parameter.VAR_KEYWORD
+ }
+ except (TypeError, ValueError):
+ names = set()
+ _PROCESSOR_KW_CACHE[key] = names
+ return names
+
+
+def _conversation_use_audio_in_video(conv: list[dict[str, Any]]) -> bool:
+ for turn in conv:
+ content = turn.get("content")
+ if not isinstance(content, list):
+ continue
+ for seg in content:
+ if isinstance(seg, dict) and seg.get("use_audio_in_video"):
+ return True
+ return False
+
+
+def _as_processor_content_blocks(content: Any) -> list[dict[str, Any]]:
+ if isinstance(content, list):
+ return [
+ seg if isinstance(seg, dict) else {"type": "text", "text": str(seg)}
+ for seg in content
+ ]
+ return [
+ {"type": "text", "text": content if isinstance(content, str) else str(content or "")}
+ ]
+
+
+def _conversation_for_processor(conv: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ return [
+ {**turn, "content": _as_processor_content_blocks(turn.get("content", ""))}
+ for turn in conv
+ ]
+
+
def _normalize_conversation(
conv: list[dict],
turn_dropout: bool = False,
@@ -92,7 +298,7 @@ def _normalize_conversation(
normalized = []
for i, turn in enumerate(conv):
role = turn.get("from", turn.get("role", ""))
- content = turn.get("value") or turn.get("content") or ""
+ content = _normalize_turn_content(turn.get("value") or turn.get("content") or "")
# Map various role names to standard user/assistant
if role in ("human", "user"):
@@ -165,10 +371,12 @@ def _adapt_part_for_vllm(part: str | dict):
for modality in ("image", "video", "audio"):
if part_type == modality:
- if local_path := part.get("path"):
- file_url = f"file://{Path(local_path).absolute()}"
- return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}}
- if url := part.get("url"):
+ media_value = part.get(modality) or part.get("path") or part.get("url")
+ if isinstance(media_value, str):
+ if media_value.startswith(("http://", "https://", "file://")):
+ url = media_value
+ else:
+ url = f"file://{Path(media_value).absolute()}"
return {"type": f"{modality}_url", f"{modality}_url": {"url": url}}
if part.get("base64"):
@@ -178,13 +386,6 @@ def _adapt_part_for_vllm(part: str | dict):
f"the {modality} when saving the preprocessed dataset, "
f"please express {modality} inputs using file paths or URLs."
)
- if part.get(modality):
- expr = {"type": modality, modality: "..."}
- raise ValueError(
- f"Content part {expr} is not supported. To avoid copying "
- f"the {modality} when saving the preprocessed dataset, "
- f"please express {modality} inputs using file paths or URLs."
- )
expr = {"type": modality} | {k: "..." for k in part if k != "type"}
raise NotImplementedError(f"Unknown content part: {expr}")
@@ -234,7 +435,7 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool:
return False
# Verify the mask is not all zeros
- return any(m == 1 for m in mask)
+ return _mask_has_positive(mask)
except (TypeError, ValueError, KeyError, AttributeError) as e:
log.warning(f"An error occurred when trying to return assistant mask: {e}")
return False
@@ -289,11 +490,8 @@ def _detect_assistant_pattern(processor: ProcessorLike) -> str:
else:
role_marker = prefix
- # Strip ... blocks from the role marker. Thinking model
- # templates wrap assistant content in these tags, but the test messages
- # can produce empty blocks (e.g. "\n\n\n") with reasoning models,
- # which then get baked into the regex as literals. Removing them ensures
- # that reasoning stays within the assistant content group.
+ # Remove optional ... wrappers from role markers so
+ # regex capture focuses on assistant content boundaries.
role_marker = re.sub(r".*?\s*", "", role_marker, flags=re.DOTALL)
# Determine the stable TURN-LEVEL suffix
@@ -382,6 +580,104 @@ def _create_loss_mask_from_offsets(
return loss_mask
+def _content_text(content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ return "".join(
+ str(seg.get("text", "")) if isinstance(seg, dict) else str(seg)
+ for seg in content
+ )
+ return str(content or "")
+
+
+def _find_token_subsequence(
+ haystack: torch.Tensor,
+ needle: torch.Tensor,
+ start: int,
+) -> int | None:
+ if needle.numel() == 0 or haystack.numel() < needle.numel():
+ return None
+ for idx in range(start, int(haystack.numel() - needle.numel()) + 1):
+ if torch.equal(haystack[idx : idx + needle.numel()], needle):
+ return idx
+ return None
+
+
+def _loss_mask_from_assistant_token_spans(
+ input_ids: torch.Tensor,
+ normalized_conv: list[dict[str, Any]],
+ tokenizer: PreTrainedTokenizerBase,
+) -> torch.Tensor | None:
+ loss_mask = torch.zeros_like(input_ids, dtype=torch.long)
+ cursor = 0
+ matches_found = 0
+ for turn in normalized_conv:
+ if turn.get("role") != "assistant":
+ continue
+ text = _content_text(turn.get("content", ""))
+ if not text:
+ continue
+ tokenized = tokenizer(text, add_special_tokens=False)
+ token_ids = tokenized.get("input_ids", [])
+ if not token_ids:
+ continue
+ needle = torch.as_tensor(token_ids, dtype=torch.long, device=input_ids.device)
+ span_start = _find_token_subsequence(input_ids, needle, cursor)
+ if span_start is None:
+ log.warning("Could not align assistant content tokens in processor input_ids")
+ continue
+ span_end = span_start + int(needle.numel())
+ loss_mask[span_start:span_end] = 1
+ cursor = span_end
+ matches_found += 1
+ return loss_mask if matches_found else None
+
+
+def _loss_mask_from_ids_fallback(
+ input_ids: torch.Tensor,
+ normalized_conv: list[dict[str, Any]],
+ tokenizer: PreTrainedTokenizerBase,
+ assistant_pattern: str | Pattern[str],
+ placeholder_token_ids: tuple[int, ...] = (),
+) -> torch.Tensor:
+ formatted_raw_any = tokenizer.apply_chat_template(
+ normalized_conv,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ formatted_raw = formatted_raw_any if isinstance(formatted_raw_any, str) else ""
+ encoding = tokenizer(
+ formatted_raw,
+ return_offsets_mapping=True,
+ add_special_tokens=False,
+ )
+ mask_text = _create_loss_mask_from_offsets(
+ formatted_raw, encoding["offset_mapping"], assistant_pattern
+ ).to(torch.long)
+
+ target_len = int(input_ids.shape[0])
+ if placeholder_token_ids:
+ placeholder_tensor = torch.as_tensor(
+ placeholder_token_ids, dtype=input_ids.dtype, device=input_ids.device
+ )
+ is_placeholder = torch.isin(input_ids, placeholder_tensor)
+ if bool(is_placeholder.any()):
+ aligned = torch.zeros(target_len, dtype=torch.long)
+ text_positions = (~is_placeholder).nonzero(as_tuple=True)[0]
+ copy_len = min(int(text_positions.shape[0]), int(mask_text.shape[0]))
+ if copy_len > 0:
+ aligned.index_copy_(0, text_positions[:copy_len], mask_text[:copy_len])
+ return aligned
+
+ if mask_text.shape[0] == target_len:
+ return mask_text
+ if mask_text.shape[0] > target_len:
+ return mask_text[:target_len]
+ pad = torch.zeros(target_len - mask_text.shape[0], dtype=torch.long)
+ return torch.cat([mask_text, pad], dim=0)
+
+
def _get_input_ids_loss_mask(
normalized_conv: list[dict],
processor: ProcessorLike,
@@ -485,63 +781,170 @@ def _preprocess_batch(
assistant_pattern: str | Pattern[str] | None,
turn_dropout: bool = False,
minimum_valid_tokens: int | None = None,
+ *,
+ indices: list[int] | None = None,
+ placeholder_token_ids: tuple[int, ...] = (),
+ multimodal_output_dir: str | Path | None = None,
+ sidecar_prefix: str = "dataset",
) -> dict[str, list]:
"""Process a batch of conversations into tokenized format with loss masks."""
- results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []}
- conversations: list[dict] = examples.get("conversations", [])
-
- # MM inputs must use Chat Completions API
- if isinstance(processor, ProcessorMixin):
+ results: dict[str, list] = {
+ "input_ids": [],
+ "loss_mask": [],
+ "seq_len": [],
+ "messages_json": [],
+ "mm_file": [],
+ "use_audio_in_video": [],
+ }
+ include_messages = isinstance(processor, ProcessorMixin)
+ if include_messages:
results["messages"] = []
+ conversations: list[dict] = examples.get("conversations", [])
if not conversations:
log.warning(f"No conversations key found. Keys: {list(examples.keys())}")
return results
for idx, conv in enumerate(conversations):
+ sample_idx = indices[idx] if indices is not None else idx
if not conv or not isinstance(conv, list):
+ log.warning(
+ f"[DROP sample_idx={sample_idx}] reason=empty_or_non_list_conversation "
+ f"type={type(conv).__name__}"
+ )
continue
- # Normalize to standard format with optional turn dropout
normalized_conv = _normalize_conversation(conv, turn_dropout)
if not normalized_conv:
+ log.warning(
+ f"[DROP sample_idx={sample_idx}] reason=normalized_conversation_empty "
+ f"raw_turns={len(conv)}"
+ )
continue
+ is_multimodal = include_messages and _is_multimodal_conversation(normalized_conv)
+ messages = _adapt_conv_for_vllm(normalized_conv) if include_messages else []
+ messages_json = _serialize_messages(messages) if is_multimodal else ""
+ mm_file = ""
+ use_audio_in_video = int(_conversation_use_audio_in_video(normalized_conv))
+
try:
- input_ids, loss_mask = _get_input_ids_loss_mask(
- normalized_conv,
- processor,
- max_length=max_length,
- assistant_pattern=assistant_pattern,
- conv_idx=idx,
+ if is_multimodal:
+ allowed = _processor_kwargs(processor)
+ call_kwargs: dict[str, Any] = dict(
+ tokenize=True,
+ add_generation_prompt=False,
+ return_dict=True,
+ return_tensors="pt",
+ processor_kwargs={},
+ )
+ for key in (
+ "load_audio",
+ "load_image",
+ "load_video",
+ "load_audios",
+ "load_images",
+ "load_videos",
+ ):
+ if key in allowed:
+ call_kwargs[key] = True
+ supports_mask = (
+ assistant_pattern is None
+ and "return_assistant_tokens_mask" in allowed
+ )
+ if supports_mask:
+ call_kwargs["return_assistant_tokens_mask"] = True
+
+ encoded_any = processor.apply_chat_template(
+ _conversation_for_processor(normalized_conv), **call_kwargs
+ )
+ encoded = cast("dict[str, Any]", encoded_any)
+ input_ids = _maybe_strip_batch_dim(encoded["input_ids"]).to(torch.long)
+
+ mask_key = None
+ if supports_mask:
+ if "assistant_masks" in encoded:
+ mask_key = "assistant_masks"
+ elif "assistant_mask" in encoded:
+ mask_key = "assistant_mask"
+ if mask_key is not None:
+ candidate_loss_mask = _maybe_strip_batch_dim(encoded[mask_key]).to(
+ torch.long
+ )
+ base_loss_mask = candidate_loss_mask if _mask_has_positive(candidate_loss_mask) else None
+ else:
+ base_loss_mask = None
+
+ if base_loss_mask is None:
+ if assistant_pattern is None:
+ assistant_pattern = _detect_assistant_pattern(processor)
+ base_loss_mask = _loss_mask_from_assistant_token_spans(
+ input_ids, normalized_conv, get_tokenizer(processor)
+ )
+ if base_loss_mask is None:
+ base_loss_mask = _loss_mask_from_ids_fallback(
+ input_ids,
+ normalized_conv,
+ get_tokenizer(processor),
+ assistant_pattern,
+ placeholder_token_ids,
+ )
+
+ loss_mask = _build_multimodal_loss_mask(
+ input_ids, base_loss_mask, placeholder_token_ids
+ )
+ if len(input_ids) > max_length:
+ log.warning(
+ f"[DROP sample_idx={sample_idx}] reason=overlength_multimodal "
+ f"len(input_ids)={len(input_ids)} max_length={max_length}"
+ )
+ continue
+ if multimodal_output_dir is not None:
+ mm_file = _save_multimodal_sidecar(
+ encoded, multimodal_output_dir, sample_idx, sidecar_prefix
+ )
+ else:
+ input_ids, loss_mask = _get_input_ids_loss_mask(
+ normalized_conv,
+ processor,
+ max_length=max_length,
+ assistant_pattern=assistant_pattern,
+ conv_idx=idx,
+ )
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+
+ assert len(input_ids) == len(loss_mask), (
+ f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}"
)
+
+ if minimum_valid_tokens is not None:
+ num_valid_tokens = int(loss_mask.sum().item())
+ if num_valid_tokens < minimum_valid_tokens:
+ log.warning(
+ f"[DROP sample_idx={sample_idx}] reason=too_few_valid_tokens "
+ f"num_valid_tokens={num_valid_tokens} "
+ f"minimum_valid_tokens={minimum_valid_tokens} "
+ f"len(input_ids)={len(input_ids)} is_multimodal={is_multimodal}"
+ )
+ continue
+
+ results["input_ids"].append(input_ids)
+ results["loss_mask"].append(loss_mask.to(torch.long))
+ results["seq_len"].append(len(input_ids))
+ if include_messages:
+ results["messages"].append(messages)
+ results["messages_json"].append(messages_json)
+ results["mm_file"].append(mm_file)
+ results["use_audio_in_video"].append(use_audio_in_video)
except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e:
log.error(
- f"Failed to process conversation {idx} "
+ f"[DROP sample_idx={sample_idx}] reason=exception "
+ f"exc_type={type(e).__name__} "
f"(assistant_pattern={assistant_pattern is not None}): {e}"
)
continue
- # Assert shapes match
- assert len(input_ids) == len(loss_mask), (
- f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}"
- )
-
- # Filtering samples out with too few valid tokens
- if minimum_valid_tokens is not None:
- num_valid_tokens = int(loss_mask.sum().item())
- if num_valid_tokens < minimum_valid_tokens:
- continue
-
- # Append to results
- results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long))
- results["loss_mask"].append(loss_mask)
- results["seq_len"].append(len(input_ids))
-
- if "messages" in results:
- results["messages"].append(_adapt_conv_for_vllm(normalized_conv))
-
return results
@@ -553,6 +956,9 @@ def build_eagle3_dataset(
assistant_pattern: str | Pattern[str] | None = None,
turn_dropout: bool = False,
minimum_valid_tokens: int | None = None,
+ placeholder_token_ids: tuple[int, ...] = (),
+ multimodal_output_dir: str | Path | None = None,
+ sidecar_prefix: str = "dataset",
) -> HFDataset:
"""Build EAGLE3 dataset by tokenizing conversations and creating loss masks.
@@ -580,6 +986,11 @@ def build_eagle3_dataset(
assistant_pattern = _detect_assistant_pattern(processor)
log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...")
+ if multimodal_output_dir is not None:
+ (Path(multimodal_output_dir) / MULTIMODAL_SIDECAR_DIR).mkdir(
+ parents=True, exist_ok=True
+ )
+
original_cols = dataset.column_names
# Avoid CPU contention for MM processing:
@@ -590,22 +1001,27 @@ def build_eagle3_dataset(
else nullcontext()
):
dataset = dataset.map(
- lambda examples: _preprocess_batch(
+ lambda examples, indices: _preprocess_batch(
examples,
processor,
max_length,
assistant_pattern,
turn_dropout,
minimum_valid_tokens,
+ indices=indices,
+ placeholder_token_ids=placeholder_token_ids,
+ multimodal_output_dir=multimodal_output_dir,
+ sidecar_prefix=sidecar_prefix,
),
batched=True,
+ with_indices=True,
num_proc=num_proc,
- batch_size=1000,
+ batch_size=400,
remove_columns=original_cols,
keep_in_memory=True, # skip caching
)
- dataset.set_format(type="torch")
+ dataset.set_format(type="torch", columns=["input_ids", "loss_mask", "seq_len"])
return dataset
@@ -667,6 +1083,7 @@ def load_and_preprocess_dataset(
turn_dropout: bool = False,
minimum_valid_tokens: int | None = None,
trust_remote_code: bool = False,
+ multimodal_output_dir: Path | str | None = None,
) -> tuple[HFDataset, ProcessorLike]:
"""Load, tokenize, and preprocess a dataset for EAGLE3 training.
@@ -710,6 +1127,20 @@ def load_and_preprocess_dataset(
"Please use a model with a pre-configured chat template."
)
+ placeholder_token_ids: tuple[int, ...] = ()
+ verifier_config = AutoConfig.from_pretrained(
+ target_model_path, trust_remote_code=trust_remote_code
+ )
+ multimodal_config = getattr(verifier_config, "thinker_config", verifier_config)
+ if isinstance(processor, ProcessorMixin):
+ placeholder_token_ids = tuple(
+ int(token_id)
+ for attr in ("image_token_id", "video_token_id", "audio_token_id")
+ if (token_id := getattr(multimodal_config, attr, None)) is not None
+ )
+ if multimodal_output_dir is None:
+ multimodal_output_dir = Path(token_freq_path).parent
+
processed_datasets = []
for train_data_path in train_data_paths:
log.subsection(f"Processing {train_data_path}")
@@ -742,6 +1173,9 @@ def load_and_preprocess_dataset(
assistant_pattern=assistant_pattern,
turn_dropout=turn_dropout,
minimum_valid_tokens=minimum_valid_tokens,
+ placeholder_token_ids=placeholder_token_ids,
+ multimodal_output_dir=multimodal_output_dir,
+ sidecar_prefix=train_data_path,
)
if minimum_valid_tokens is not None:
log.info(f"Kept {len(preprocessed_dataset)} samples after filtering")
diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py
index 6549444a9..c97d2b625 100644
--- a/src/speculators/data_generation/vllm_client.py
+++ b/src/speculators/data_generation/vllm_client.py
@@ -94,7 +94,15 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs):
def extract_output(
response: Completion | ChatCompletion,
token_ids: list[int],
-) -> str:
+ *,
+ trust_server_token_ids: bool = False,
+) -> str | tuple[str, list[int]]:
+ """Extract hidden-states path (and optionally server token ids).
+
+ By default, token ids must match exactly. Set
+ ``trust_server_token_ids=True`` to accept server tokenization and return
+ ``(path, prompt_token_ids)``.
+ """
if isinstance(response, Completion):
prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None)
else:
@@ -103,25 +111,36 @@ def extract_output(
if prompt_token_ids is None:
raise InvalidResponseError("Response missing prompt_token_ids")
- if prompt_token_ids != token_ids:
+ if not trust_server_token_ids and prompt_token_ids != token_ids:
raise InvalidResponseError(
- f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}"
+ f"Prompt token IDs mismatch: expected {len(token_ids)} tokens, "
+ f"got {len(prompt_token_ids)} tokens"
)
kv_transfer_params = getattr(response, "kv_transfer_params", None)
if kv_transfer_params is None:
raise InvalidResponseError("Response missing kv_transfer_params")
- return kv_transfer_params.get("hidden_states_path")
+ path = kv_transfer_params.get("hidden_states_path")
+
+ if trust_server_token_ids:
+ return path, prompt_token_ids
+ return path
class ClientItem(TypedDict):
input_ids: list[int]
"""The input token IDs."""
+ multi_modal_data: NotRequired[dict[str, list]]
+ """Per-modality URL/path lists (image/video/audio) forwarded to vLLM via
+ ``extra_body.multi_modal_data`` so the server can attach media features
+ without re-rendering the chat template."""
+
messages: NotRequired[list[ChatCompletionMessageParam]]
- """If provided, pass `messages` to Chat Completions API
- instead of passing `token_ids` to Completions API."""
+ """Optional fallback. Only consumed when ``use_chat_completions=True`` is
+ passed to :func:`generate_hidden_states`. The default token-id path
+ ignores this field."""
async def _poll_lock_async(fd, poll_interval):
@@ -172,35 +191,43 @@ async def generate_hidden_states_async(
client_item: ClientItem,
*,
timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
-) -> str:
- """
- Runs decode w/ max_tokens 1 to generate hidden states and returns path to
- hidden states file.
-
- Args:
- client: The async OpenAI client.
- model: The model ID.
- client_item: Inputs to send via the client.
- timeout: Timeout in seconds for each request attempt. None for no timeout.
+ use_chat_completions: bool = False,
+) -> str | tuple[str, list[int]]:
+ """Generate hidden states asynchronously and return the safetensors path.
+
+ For multimodal samples, set ``use_chat_completions=True`` so vLLM runs
+ its vision encoder. In that mode the return value is a tuple
+ ``(path, server_token_ids)`` — the caller must use ``server_token_ids``
+ as the authoritative input_ids (they are positionally aligned with the
+ hidden states).
+
+ For text-only samples the default Completions path is used (strict
+ token-id parity check) and a plain ``str`` path is returned.
"""
token_ids = client_item["input_ids"]
messages = client_item.get("messages")
+ is_mm_chat = use_chat_completions and messages is not None
+
coro: Coroutine[Any, Any, Completion | ChatCompletion]
- if messages is None:
- coro = client.completions.create(
+ if is_mm_chat:
+ coro = client.chat.completions.create(
model=model,
- prompt=token_ids,
+ messages=messages,
max_tokens=1,
- extra_body={"return_token_ids": True},
+ extra_body={"add_generation_prompt": False, "return_token_ids": True},
timeout=timeout,
)
else:
- coro = client.chat.completions.create(
+ extra_body: dict[str, Any] = {"return_token_ids": True}
+ mm = client_item.get("multi_modal_data")
+ if mm:
+ extra_body["multi_modal_data"] = mm
+ coro = client.completions.create(
model=model,
- messages=messages,
+ prompt=token_ids,
max_tokens=1,
- extra_body={"add_generation_prompt": False, "return_token_ids": True},
+ extra_body=extra_body,
timeout=timeout,
)
@@ -210,7 +237,7 @@ async def generate_hidden_states_async(
else:
res = await coro
- return extract_output(res, token_ids)
+ return extract_output(res, token_ids, trust_server_token_ids=is_mm_chat)
@with_retries
@@ -220,30 +247,39 @@ def generate_hidden_states(
client_item: ClientItem,
*,
timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
-) -> str:
- """
- Runs decode w/ max_tokens 1 to generate hidden states and returns path to
- hidden states file.
+ use_chat_completions: bool = False,
+) -> str | tuple[str, list[int]]:
+ """Generate hidden states via vLLM (synchronous version).
+
+ For multimodal samples, set ``use_chat_completions=True`` so vLLM runs
+ its vision encoder. Returns ``(path, server_token_ids)`` in that mode.
+ For text-only samples returns a plain ``str`` path.
"""
token_ids = client_item["input_ids"]
messages = client_item.get("messages")
+ is_mm_chat = use_chat_completions and messages is not None
+
res: Completion | ChatCompletion
- if messages is None:
- res = client.completions.create(
+ if is_mm_chat:
+ res = client.chat.completions.create(
model=model,
- prompt=token_ids,
+ messages=messages,
max_tokens=1,
- extra_body={"return_token_ids": True},
+ extra_body={"add_generation_prompt": False, "return_token_ids": True},
timeout=timeout,
)
else:
- res = client.chat.completions.create(
+ extra_body: dict[str, Any] = {"return_token_ids": True}
+ mm = client_item.get("multi_modal_data")
+ if mm:
+ extra_body["multi_modal_data"] = mm
+ res = client.completions.create(
model=model,
- messages=messages,
+ prompt=token_ids,
max_tokens=1,
- extra_body={"add_generation_prompt": False, "return_token_ids": True},
+ extra_body=extra_body,
timeout=timeout,
)
- return extract_output(res, token_ids)
+ return extract_output(res, token_ids, trust_server_token_ids=is_mm_chat)
diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py
index 00da65fb4..187354dbd 100644
--- a/src/speculators/models/eagle3/core.py
+++ b/src/speculators/models/eagle3/core.py
@@ -1,4 +1,5 @@
import copy
+import warnings
from typing import ClassVar
import torch
@@ -14,6 +15,7 @@
)
from speculators.models.eagle3.metrics import compute_metrics
from speculators.models.eagle3.model_definitions import model_classes
+from speculators.models.eagle3.rotary_partial import install_partial_neox_rotary
from speculators.models.metrics import kl_div_loss, resolve_loss_fn
from speculators.models.utils import resolve_target_layer_ids
from speculators.proposals.greedy import GreedyTokenProposalConfig
@@ -26,6 +28,128 @@ def conditional_torch_compile(func):
return func
+def _wrap_qwen_omni_rotary_with_hf_layout(rotary_cls: type) -> type:
+ """Adapt Qwen-Omni rotary to HF MRoPE layout.
+
+ Qwen-Omni expects ``position_ids`` as ``[3, batch, seq_len]`` while our
+ training path emits ``[batch, 3, seq_len]``. This wrapper only transposes
+ the HF case to avoid MRoPE channel mis-broadcast and projection dim errors.
+ """
+
+ class HFLayoutMRoPE(rotary_cls): # type: ignore[misc,valid-type]
+ def forward(self, x, position_ids): # type: ignore[override]
+ # Only transpose clear HF layout ``[B, 3, T]``.
+ # Other shapes pass through unchanged.
+ if position_ids.dim() == 3 and position_ids.shape[1] == 3:
+ position_ids = position_ids.transpose(0, 1).contiguous()
+ return super().forward(x, position_ids)
+
+ HFLayoutMRoPE.__name__ = f"{rotary_cls.__name__}HFLayout"
+ HFLayoutMRoPE.__qualname__ = HFLayoutMRoPE.__name__
+ return HFLayoutMRoPE
+
+
+def _select_rotary_emb_class(
+ tl_config: PretrainedConfig,
+ default_cls: type,
+) -> type:
+ """Select an MRoPE-aware rotary class for multimodal Qwen draft configs."""
+ rope_params = getattr(tl_config, "rope_parameters", None)
+ has_mrope = (
+ isinstance(rope_params, dict) and rope_params.get("mrope_section") is not None
+ )
+ if not has_mrope:
+ return default_cls
+
+ try:
+ from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( # noqa: PLC0415
+ Qwen3OmniMoeThinkerTextRotaryEmbedding,
+ )
+ except ImportError:
+ warnings.warn(
+ "Draft config carries rope_parameters.mrope_section but the installed "
+ "transformers does not expose Qwen3OmniMoeThinkerTextRotaryEmbedding. "
+ "Falling back to the architecture default rotary embedding, which "
+ "will ignore MRoPE and can cause a train/inference mismatch for "
+ "multimodal inputs.",
+ UserWarning,
+ stacklevel=2,
+ )
+ return default_cls
+
+ partial_rotary_factor = float(rope_params.get("partial_rotary_factor", 1.0))
+ if partial_rotary_factor >= 1.0:
+ return _wrap_qwen_omni_rotary_with_hf_layout(
+ Qwen3OmniMoeThinkerTextRotaryEmbedding
+ )
+
+ # Align HF partial rotation with vLLM partial-neox behavior.
+ # Keep native ``partial_rotary_factor``/``mrope_section`` unchanged.
+ install_partial_neox_rotary()
+ return _wrap_qwen_omni_rotary_with_hf_layout(
+ _make_partial_mrope_rotary_cls(
+ Qwen3OmniMoeThinkerTextRotaryEmbedding, partial_rotary_factor
+ )
+ )
+
+
+def _make_partial_mrope_rotary_cls(
+ base_cls: type, partial_rotary_factor: float
+) -> type:
+ """Return a partial-MRoPE rotary class for Qwen drafts.
+
+ It emits unpadded ``[*, rotary_dim]`` cos/sin. The patched
+ ``apply_rotary_pos_emb`` rotates only the leading ``rotary_dim`` channels
+ and keeps the tail unchanged, matching vLLM partial-neox semantics.
+ """
+
+ class PartialMRoPE(base_cls): # type: ignore[misc,valid-type]
+ _partial_rotary_factor: ClassVar[float] = partial_rotary_factor
+
+ @staticmethod
+ def compute_default_rope_parameters(config=None, device=None, seq_len=None):
+ base = config.rope_parameters["rope_theta"]
+ head_dim = (
+ getattr(config, "head_dim", None)
+ or config.hidden_size // config.num_attention_heads
+ )
+ rotary_dim = int(head_dim * PartialMRoPE._partial_rotary_factor)
+ rotary_dim = (rotary_dim // 2) * 2
+ attention_factor = 1.0
+ inv_freq = 1.0 / (
+ base
+ ** (
+ torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(
+ device=device, dtype=torch.float
+ )
+ / rotary_dim
+ )
+ )
+ return inv_freq, attention_factor
+
+ def __init__(self, config, device=None):
+ super().__init__(config=config, device=device)
+ head_dim = (
+ getattr(config, "head_dim", None)
+ or config.hidden_size // config.num_attention_heads
+ )
+ rotary_dim = int(head_dim * self._partial_rotary_factor)
+ rotary_dim = (rotary_dim // 2) * 2
+ self._head_dim = head_dim
+ self._rotary_dim = rotary_dim
+
+ def forward(self, x, position_ids):
+ # Keep cos/sin unpadded; patched apply_rotary_pos_emb handles
+ # partial-neox rotation on the leading rotary channels.
+ return super().forward(x, position_ids)
+
+ PartialMRoPE.__name__ = (
+ f"{base_cls.__name__}Partial{int(partial_rotary_factor * 100):03d}"
+ )
+ PartialMRoPE.__qualname__ = PartialMRoPE.__name__
+ return PartialMRoPE
+
+
@SpeculatorModel.register("eagle3")
class Eagle3DraftModel(DraftVocabMixin, SpeculatorModel):
config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc]
@@ -81,7 +205,10 @@ def __init__(self, config: Eagle3SpeculatorConfig):
# Create a modified config for the rotary embedding to use 2x the hidden size
modified_tl_config = copy.copy(config.transformer_layer_config)
modified_tl_config.hidden_size *= 2
- self.rotary_emb = self._model_definitions.rotary_emb_class(modified_tl_config)
+ rotary_cls = _select_rotary_emb_class(
+ modified_tl_config, self._model_definitions.rotary_emb_class
+ )
+ self.rotary_emb = rotary_cls(modified_tl_config)
# LAYER NORMS
norm_class = self._model_definitions.norm_class
@@ -108,7 +235,9 @@ def load_verifier_weights(self):
verifier_config = self.config.speculators_config.verifier
verifier_model_config = AutoConfig.from_pretrained(verifier_config.name_or_path) # type: ignore[arg-type]
- # For multimodal models (Qwen3VL, etc.), extract text_config
+ # For multimodal models (Qwen3VL/Omni/etc.), extract text_config
+ if hasattr(verifier_model_config, "thinker_config"):
+ verifier_model_config = verifier_model_config.thinker_config
if hasattr(verifier_model_config, "text_config"):
verifier_model_config = verifier_model_config.text_config
diff --git a/src/speculators/models/eagle3/data.py b/src/speculators/models/eagle3/data.py
index 2ac6af086..b56b6754c 100644
--- a/src/speculators/models/eagle3/data.py
+++ b/src/speculators/models/eagle3/data.py
@@ -28,7 +28,10 @@ def shift_batch(batch: BatchType):
verifier_last_hidden_states = verifier_last_hidden_states[1:]
loss_mask = loss_mask[1:]
lengths = lengths - 1
- position_ids = position_ids[1:] # Note: position_ids now start at 1
+ if position_ids.ndim == 2:
+ position_ids = position_ids[:, 1:]
+ else:
+ position_ids = position_ids[1:] # Note: position_ids now start at 1
return {
"input_ids": input_ids,
diff --git a/src/speculators/models/eagle3/metrics.py b/src/speculators/models/eagle3/metrics.py
index 91ab3bfe0..49699893f 100644
--- a/src/speculators/models/eagle3/metrics.py
+++ b/src/speculators/models/eagle3/metrics.py
@@ -1,17 +1,49 @@
"""Metrics and loss functions for Eagle3 draft model."""
+import os
+import warnings
from collections.abc import Callable
from functools import partial
import torch
from speculators.models.metrics import (
+ ce_loss,
compute_accuracy_single_step,
exp_loss_decay,
kl_div_loss,
loss_function,
)
+_EAGLE3_LOSS_CE_WEIGHT_RAW = os.getenv("EAGLE3_LOSS_CE_WEIGHT")
+if _EAGLE3_LOSS_CE_WEIGHT_RAW is None:
+ _EAGLE3_LOSS_CE_WEIGHT = 0.0
+else:
+ try:
+ _EAGLE3_LOSS_CE_WEIGHT = float(_EAGLE3_LOSS_CE_WEIGHT_RAW)
+ except ValueError:
+ warnings.warn(
+ f"EAGLE3_LOSS_CE_WEIGHT={_EAGLE3_LOSS_CE_WEIGHT_RAW!r} is not a "
+ "float; falling back to pure-KL legacy loss (weight=0.0).",
+ stacklevel=1,
+ )
+ _EAGLE3_LOSS_CE_WEIGHT = 0.0
+ if not 0.0 <= _EAGLE3_LOSS_CE_WEIGHT <= 1.0:
+ warnings.warn(
+ f"EAGLE3_LOSS_CE_WEIGHT={_EAGLE3_LOSS_CE_WEIGHT} clamped to [0, 1].",
+ stacklevel=1,
+ )
+ _EAGLE3_LOSS_CE_WEIGHT = max(0.0, min(1.0, _EAGLE3_LOSS_CE_WEIGHT))
+
+
+def eagle3_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Per-token Eagle3 loss with optional CE(verifier argmax) mixture."""
+ kl = kl_div_loss(logits, targets)
+ if _EAGLE3_LOSS_CE_WEIGHT <= 0.0:
+ return kl
+ ce = ce_loss(logits, targets)
+ return (1.0 - _EAGLE3_LOSS_CE_WEIGHT) * kl + _EAGLE3_LOSS_CE_WEIGHT * ce
+
def align_for_step(
logits: torch.Tensor, # shape: [1, total_seq_len, draft_vocab_size]
@@ -75,7 +107,7 @@ def compute_metrics(
Loss value and metrics dictionary.
"""
if loss_fn is None:
- loss_fn = kl_div_loss
+ loss_fn = eagle3_loss
s_logits, s_targets, s_loss_mask, s_prev_correct = align_for_step(
logits, targets, loss_mask, prev_correct, ttt_step
)
diff --git a/src/speculators/models/eagle3/rotary_partial.py b/src/speculators/models/eagle3/rotary_partial.py
new file mode 100644
index 000000000..bc085f512
--- /dev/null
+++ b/src/speculators/models/eagle3/rotary_partial.py
@@ -0,0 +1,109 @@
+"""Patch HF rotary helper to match vLLM partial-neox behavior.
+
+HF ``apply_rotary_pos_emb`` rotates by splitting at ``head_dim/2``.
+vLLM partial MRoPE rotates only the first ``rotary_dim`` channels and
+keeps the tail unchanged. This file aligns HF training with that runtime
+behavior, while keeping full-rotation paths unchanged.
+"""
+
+from __future__ import annotations
+
+import torch
+
+__all__ = [
+ "install_partial_neox_rotary",
+ "partial_neox_apply_rotary_pos_emb",
+]
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """HF/neox "rotate_half" — splits the last dim in half and swaps."""
+ half = x.shape[-1] // 2
+ x1 = x[..., :half]
+ x2 = x[..., half:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def partial_neox_apply_rotary_pos_emb(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ unsqueeze_dim: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """HF-compatible rotary helper with partial-neox fallback.
+
+ - If ``cos`` covers full head dim, behavior matches HF.
+ - If ``cos`` is shorter, rotate only the first ``rotary_dim`` channels
+ and keep the remaining channels unchanged.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ if cos.shape[-1] == q.shape[-1]:
+ # Full rotation — identical to HF apply_rotary_pos_emb.
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
+ return q_embed, k_embed
+
+ if cos.shape[-1] > q.shape[-1]:
+ raise ValueError(
+ f"cos last dim ({cos.shape[-1]}) exceeds q last dim "
+ f"({q.shape[-1]}); rotary tables larger than head_dim are "
+ "unsupported by this partial-neox replacement."
+ )
+
+ rotary_dim = cos.shape[-1]
+ # Rotate leading rotary channels; keep tail as pass-through.
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ q_rot = (q_rot * cos) + (_rotate_half(q_rot) * sin)
+ k_rot = (k_rot * cos) + (_rotate_half(k_rot) * sin)
+
+ q_embed = torch.cat([q_rot, q_pass], dim=-1)
+ k_embed = torch.cat([k_rot, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+_INSTALLED = False
+
+
+def install_partial_neox_rotary() -> None:
+ """Patch ``apply_rotary_pos_emb`` in HF ``llama`` and ``qwen3`` modules.
+
+ Idempotent. Full-rotation paths keep original behavior.
+ """
+ global _INSTALLED
+ if _INSTALLED:
+ return
+
+ # Local imports — keep transformers a soft dep at module import time.
+ from transformers.models.llama import modeling_llama # noqa: PLC0415
+ from transformers.models.qwen3 import modeling_qwen3 # noqa: PLC0415
+
+ for module in (modeling_llama, modeling_qwen3):
+ original = module.apply_rotary_pos_emb
+ # Cache original for tests / debugging — and to allow uninstall.
+ if not hasattr(module, "_speculators_original_apply_rotary_pos_emb"):
+ module._speculators_original_apply_rotary_pos_emb = original # type: ignore[attr-defined]
+ module.apply_rotary_pos_emb = partial_neox_apply_rotary_pos_emb
+
+ _INSTALLED = True
+
+
+def uninstall_partial_neox_rotary() -> None:
+ """Restore HF's original ``apply_rotary_pos_emb``. Test/debug helper."""
+ global _INSTALLED
+ if not _INSTALLED:
+ return
+ from transformers.models.llama import modeling_llama # noqa: PLC0415
+ from transformers.models.qwen3 import modeling_qwen3 # noqa: PLC0415
+
+ for module in (modeling_llama, modeling_qwen3):
+ original = getattr(
+ module, "_speculators_original_apply_rotary_pos_emb", None
+ )
+ if original is not None:
+ module.apply_rotary_pos_emb = original
+ _INSTALLED = False
diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py
index 63519024d..61cf1dded 100644
--- a/src/speculators/train/data.py
+++ b/src/speculators/train/data.py
@@ -7,6 +7,7 @@
from collections.abc import Callable
from os import PathLike
from pathlib import Path
+from types import MethodType, SimpleNamespace
from typing import Any, Literal, cast
import openai
@@ -15,6 +16,7 @@
from datasets import load_from_disk
from safetensors.torch import load_file
from torch.utils.data import Dataset
+from transformers import AutoConfig
from speculators.data_generation.vllm_client import (
DEFAULT_MAX_RETRIES,
@@ -26,6 +28,23 @@
from speculators.train.noise_transforms import TransformTensors
BatchType = dict[str, Any]
+MULTIMODAL_SIDECAR_KEYS = (
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ "second_per_grids",
+ "input_features",
+ "feature_attention_mask",
+ "audio_feature_lengths",
+)
+NON_TRAINING_KEYS = {
+ *MULTIMODAL_SIDECAR_KEYS,
+ "messages",
+ "messages_json",
+ "mm_file",
+ "use_audio_in_video",
+}
def list_files(path):
@@ -47,6 +66,16 @@ def slice_and_pad_to_length(tensor, length):
return F.pad(sliced_tensor, padding)
+def pad_last_dim_to_length(tensor: torch.Tensor, length: int) -> torch.Tensor:
+ sliced_tensor = tensor[..., :length]
+ pad_amount = length - sliced_tensor.shape[-1]
+ if pad_amount <= 0:
+ return sliced_tensor
+ padding = [0, pad_amount]
+ padding.extend([0, 0] * (sliced_tensor.dim() - 1))
+ return F.pad(sliced_tensor, padding)
+
+
def split_files(datapath: str, ratio: float = 0.9, seed: int = 0):
"""Given a datapath, split the files into a training and validation set
ratio is the proportion of files to put in the training set
@@ -111,16 +140,198 @@ def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]:
}
+def _collect_mm_payload_from_messages(
+ messages: list[dict],
+) -> dict[str, list[str]]:
+ """Extract image/video/audio URLs from vLLM-style messages.
+
+ Used to populate ``extra_body.multi_modal_data`` for token-id completions.
+ """
+ bucket: dict[str, list[str]] = {}
+ for turn in messages:
+ content = turn.get("content")
+ if not isinstance(content, list):
+ continue
+ for seg in content:
+ if not isinstance(seg, dict):
+ continue
+ seg_type = seg.get("type", "")
+ for modality in ("image", "video", "audio"):
+ if seg_type == f"{modality}_url":
+ url = (seg.get(f"{modality}_url") or {}).get("url")
+ if url:
+ bucket.setdefault(modality, []).append(url)
+ return bucket
+
+
def build_client_item(dataset_item: dict) -> ClientItem:
- out_dict = {}
- out_dict["input_ids"] = dataset_item["input_ids"].tolist()
+ """Build a vLLM client payload from one dataset row.
- if "messages" in dataset_item:
- out_dict["messages"] = dataset_item["messages"]
+ Default path sends ``input_ids`` (token-id completions). For multimodal
+ rows, URLs from ``messages_json`` are forwarded as ``multi_modal_data``.
+ ``messages`` is kept only for optional chat-completions fallback.
+ """
+ out_dict: dict[str, Any] = {}
+ input_ids = _maybe_tensor(dataset_item["input_ids"], torch.long)
+ out_dict["input_ids"] = input_ids.tolist() if input_ids is not None else []
+
+ messages_json = dataset_item.get("messages_json", "")
+ messages: list[dict] | None = None
+ if messages_json:
+ messages = json.loads(messages_json)
+ elif "messages" in dataset_item:
+ messages = dataset_item["messages"]
+
+ if messages:
+ # Keep messages around as a fallback path
+ out_dict["messages"] = messages
+ # Primary path: extract MM URLs so we can use Completions API
+ mm = _collect_mm_payload_from_messages(messages)
+ if mm:
+ out_dict["multi_modal_data"] = mm
return cast("ClientItem", out_dict)
+def _maybe_tensor(value: Any, dtype: torch.dtype | None = None) -> torch.Tensor | None:
+ if value is None:
+ return None
+ tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
+ if dtype is not None:
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+
+def _batchify_mm_tensor(value: torch.Tensor | None) -> torch.Tensor | None:
+ if value is None:
+ return None
+ if value.ndim == 0:
+ return value.view(1)
+ if value.ndim == 1:
+ return value.unsqueeze(0)
+ return value
+
+
+def _has_multimodal_payload(data: dict[str, Any]) -> bool:
+ return any(data.get(key) is not None for key in MULTIMODAL_SIDECAR_KEYS)
+
+
+def _make_rope_index_fn(verifier_name_or_path: str):
+ """Build a callable producing 3D MRoPE position ids for Qwen multimodal models."""
+ try:
+ verifier_root_config = AutoConfig.from_pretrained(
+ verifier_name_or_path, trust_remote_code=True
+ )
+ except Exception: # noqa: BLE001
+ return None
+
+ thinker_config = getattr(verifier_root_config, "thinker_config", None)
+ if thinker_config is not None:
+ return _make_rope_index_fn_qwen3_omni(thinker_config)
+
+ top_vision_config = getattr(verifier_root_config, "vision_config", None)
+ top_text_config = getattr(verifier_root_config, "text_config", None)
+ top_image_token_id = getattr(verifier_root_config, "image_token_id", None)
+ if (
+ top_vision_config is not None
+ and top_text_config is not None
+ and top_image_token_id is not None
+ ):
+ return _make_rope_index_fn_qwen3_5_moe(verifier_root_config)
+
+ return None
+
+
+def _make_rope_index_fn_qwen3_omni(thinker_config):
+ try:
+ from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( # noqa: PLC0415
+ Qwen3OmniMoeThinkerForConditionalGeneration,
+ )
+ except ImportError:
+ return None
+
+ vision_config = getattr(thinker_config, "vision_config", None)
+ spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1)
+ dummy = SimpleNamespace(
+ config=SimpleNamespace(
+ image_token_id=getattr(thinker_config, "image_token_id", None),
+ video_token_id=getattr(thinker_config, "video_token_id", None),
+ audio_token_id=getattr(thinker_config, "audio_token_id", None),
+ vision_start_token_id=getattr(
+ thinker_config, "vision_start_token_id", None
+ ),
+ audio_start_token_id=getattr(thinker_config, "audio_start_token_id", None),
+ position_id_per_seconds=getattr(
+ thinker_config, "position_id_per_seconds", 25
+ ),
+ ),
+ spatial_merge_size=spatial_merge_size,
+ )
+ dummy.get_llm_pos_ids_for_vision = MethodType(
+ Qwen3OmniMoeThinkerForConditionalGeneration.get_llm_pos_ids_for_vision,
+ dummy,
+ )
+ return MethodType(
+ Qwen3OmniMoeThinkerForConditionalGeneration.get_rope_index,
+ dummy,
+ )
+
+
+def _make_rope_index_fn_qwen3_5_moe(root_config):
+ try:
+ from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( # noqa: PLC0415
+ Qwen3_5MoeModel,
+ )
+ except ImportError:
+ return None
+
+ image_token_id = getattr(root_config, "image_token_id", None)
+ video_token_id = getattr(root_config, "video_token_id", None)
+ vision_config = getattr(root_config, "vision_config", None)
+ dummy = SimpleNamespace(
+ config=SimpleNamespace(
+ vision_config=vision_config,
+ image_token_id=image_token_id,
+ video_token_id=video_token_id,
+ ),
+ )
+ dummy.get_vision_position_ids = MethodType(
+ Qwen3_5MoeModel.get_vision_position_ids, dummy
+ )
+ raw_get_rope_index = MethodType(Qwen3_5MoeModel.get_rope_index, dummy)
+
+ def adapter(
+ input_ids,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ attention_mask=None,
+ **_ignored,
+ ):
+ ids = input_ids
+ image_mask = (
+ ids == image_token_id
+ if image_token_id is not None
+ else torch.zeros_like(ids, dtype=torch.bool)
+ )
+ video_mask = (
+ ids == video_token_id
+ if video_token_id is not None
+ else torch.zeros_like(ids, dtype=torch.bool)
+ )
+ mm_token_type_ids = torch.zeros_like(ids, dtype=torch.int32)
+ mm_token_type_ids[image_mask] = 1
+ mm_token_type_ids[video_mask] = 2
+ return raw_get_rope_index(
+ input_ids=ids,
+ mm_token_type_ids=mm_token_type_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ )
+
+ return adapter
+
+
class BaseDataset(Dataset):
def __init__(
self,
@@ -139,6 +350,32 @@ def _compute_approx_lengths(self):
def _get_raw_data(self, index):
raise NotImplementedError
+ def _build_position_ids(self, data: dict[str, Any], seq_len: int) -> torch.Tensor:
+ rope_index_fn = getattr(self, "_rope_index_fn", None)
+ if rope_index_fn is not None and _has_multimodal_payload(data):
+ audio_seqlens = _maybe_tensor(data.get("audio_feature_lengths"), torch.long)
+ if audio_seqlens is None and data.get("feature_attention_mask") is not None:
+ audio_seqlens = _maybe_tensor(
+ data["feature_attention_mask"], torch.long
+ ).sum(dim=-1)
+
+ position_ids, _ = rope_index_fn(
+ input_ids=data["input_ids"].unsqueeze(0),
+ image_grid_thw=_batchify_mm_tensor(
+ _maybe_tensor(data.get("image_grid_thw"), torch.long)
+ ),
+ video_grid_thw=_batchify_mm_tensor(
+ _maybe_tensor(data.get("video_grid_thw"), torch.long)
+ ),
+ use_audio_in_video=bool(data.get("use_audio_in_video", False)),
+ audio_seqlens=audio_seqlens,
+ second_per_grids=_maybe_tensor(data.get("second_per_grids")),
+ attention_mask=torch.ones(1, seq_len, dtype=torch.long),
+ )
+ return position_ids[:, 0].to(dtype=torch.long)
+
+ return torch.arange(seq_len, dtype=torch.long)
+
def __getitem__(self, index) -> BatchType | None:
data = self._get_raw_data(index)
@@ -163,8 +400,12 @@ def __getitem__(self, index) -> BatchType | None:
data["lengths"] = torch.tensor([seq_len], dtype=torch.long)
# shape: [1]
- data["position_ids"] = torch.arange(seq_len, dtype=torch.long)
- # shape: [seq_len]
+ data["position_ids"] = self._build_position_ids(data, seq_len)
+ # shape: [seq_len] or [3, seq_len] for MRoPE multimodal samples
+
+ # Keep multimodal payload only for rope index / HS generation
+ for key in NON_TRAINING_KEYS:
+ data.pop(key, None)
# data structure: {
# "hidden_states": [seq_len, 3 * hidden_size],
@@ -208,6 +449,7 @@ def __init__(
model: str | None = None,
request_timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
max_retries: int = DEFAULT_MAX_RETRIES,
+ verifier_name_or_path: str | None = None,
):
"""Initialize the ArrowDataset.
Args:
@@ -217,12 +459,14 @@ def __init__(
transform: The transform to apply to the data.
hidden_states_dtype: The dtype of the hidden states.
"""
- self.data = load_from_disk(datapath)
+ self.datapath = Path(datapath)
+ # Reset format after load so metadata columns remain visible
+ # (e.g. ``mm_file`` / ``messages_json``).
+ self.data = load_from_disk(datapath).with_format(None)
self.start_file_idx = 0
if split_ratio == 1.0:
pass
elif 1.0 > split_ratio > 0:
- self.start_file_idx = 0
split_idx = int(len(self.data) * split_ratio)
self.data = self.data.select(range(split_idx))
elif -1.0 < split_ratio < 0:
@@ -233,7 +477,7 @@ def __init__(
raise ValueError("split_ratio must be in range (-1.0, 1.0] excluding 0.0.")
self.hidden_states_path: Path = (
- Path(datapath) / "hidden_states"
+ self.datapath / "hidden_states"
if hidden_states_path is None
else Path(hidden_states_path)
)
@@ -242,8 +486,14 @@ def __init__(
self.on_generate = on_generate
self.client: openai.OpenAI | None = None
self.model = model
+ self.verifier_name_or_path = verifier_name_or_path or model
self.request_timeout = request_timeout
self.max_retries = max_retries
+ self._rope_index_fn = (
+ _make_rope_index_fn(self.verifier_name_or_path)
+ if self.verifier_name_or_path is not None
+ else None
+ )
# Delay super init so that `_compute_approx_lengths` has required data
super().__init__(max_len, transform, hidden_states_dtype)
@@ -273,28 +523,57 @@ def _compute_approx_lengths(self) -> list[int]:
"""Get lengths of the dataset samples."""
return list(self.data.with_format(None)["seq_len"])
- def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
+ def _maybe_generate_hs(
+ self, index: int, *, is_multimodal: bool = False
+ ) -> dict[str, torch.Tensor] | None:
+ """Generate hidden states from vLLM on demand.
+
+ For multimodal rows, Chat Completions is used and server token_ids are
+ treated as authoritative for HS alignment.
+ """
if not self.client:
self._setup_client()
dataset_item = self.data[index]
client_item = build_client_item(dataset_item)
+ # Use Chat Completions for multimodal rows to trigger vision encoding.
+ # If tokenization drifts, trust vLLM token_ids (HS-aligned).
+ use_chat = is_multimodal and client_item.get("messages") is not None
+
try:
- hs_filepath = generate_hidden_states(
+ result = generate_hidden_states(
self.client, # type:ignore[arg-type]
self.model, # type:ignore[arg-type]
client_item,
timeout=self.request_timeout,
max_retries=self.max_retries,
+ use_chat_completions=use_chat,
)
+ if use_chat:
+ hs_filepath, server_token_ids = result # type:ignore[misc]
+ else:
+ hs_filepath = result # type:ignore[assignment]
+ server_token_ids = None
+
loaded_hs = _maybe_load_hs_file(Path(hs_filepath))
+ # Overwrite token_ids with the server's authoritative version
+ # so downstream _get_raw_data uses the correct positionally-
+ # aligned sequence for loss/hidden-state pairing.
+ if loaded_hs is not None and server_token_ids is not None:
+ loaded_hs["token_ids"] = torch.tensor(
+ server_token_ids, dtype=torch.long
+ )
+
match self.on_generate:
case "cache":
file_idx = self._map_to_file_idx(index)
target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
+ # Ensure destination directory exists before moving files.
+ # Missing parent dir can cause all generated samples to fail.
+ target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(hs_filepath, target_path)
case "delete":
Path(hs_filepath).unlink()
@@ -308,14 +587,26 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
return loaded_hs
def _get_raw_data(self, index):
+ # Read one row once, then normalize tensors and metadata together.
+ row = self.data[index]
+ stored_input_ids = _maybe_tensor(row["input_ids"], torch.long)
+ stored_loss_mask = _maybe_tensor(row["loss_mask"], torch.long)
+ # Fast-fail: skip corrupted/incomplete rows before expensive HS I/O.
+ if stored_input_ids is None or stored_loss_mask is None:
+ return None
+
+ is_multimodal = bool(row.get("messages_json", ""))
+
file_idx = self._map_to_file_idx(index)
- candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
- loaded_hs = _maybe_load_hs_file(candidate_path)
+ hs_file = self.hidden_states_path / f"hs_{file_idx}.safetensors"
+ loaded_hs = _maybe_load_hs_file(hs_file)
if loaded_hs is None:
match self.on_missing:
case "generate":
- loaded_hs = self._maybe_generate_hs(index)
+ loaded_hs = self._maybe_generate_hs(
+ index, is_multimodal=is_multimodal
+ )
case "skip":
return None
case "warn":
@@ -337,25 +628,61 @@ def _get_raw_data(self, index):
# "token_ids": [seq_len]
# }
- if not torch.equal(loaded_hs["token_ids"], self.data[index]["input_ids"]):
+ # For multimodal samples generated via Chat Completions, vLLM's
+ # server-side tokenization is authoritative (it ran the vision
+ # encoder on that exact token sequence). The stored input_ids from
+ # prepare_data.py may differ by ±1 token at vision-placeholder
+ # boundaries, so we trust loaded_hs["token_ids"] unconditionally.
+ # For text-only samples (or pre-generated offline HS), we still
+ # enforce strict equality to catch data corruption early.
+ authoritative_input_ids = loaded_hs["token_ids"]
+ if not is_multimodal and not torch.equal(authoritative_input_ids, stored_input_ids):
warnings.warn(
- f"Loaded token ids {loaded_hs['token_ids']} for index {index} don't"
- f"match input ids {self.data[index]['input_ids']}",
+ f"Token IDs mismatch for text sample {index}: "
+ f"hs has {len(authoritative_input_ids)} tokens, "
+ f"dataset has {len(stored_input_ids)} tokens. Skipping.",
stacklevel=1,
)
return None
- return {
+ # Use authoritative_input_ids (from HS file) as ground truth.
+ # For multimodal, this is vLLM's tokenization (aligned with HS);
+ # for text-only, it equals stored_input_ids (verified above).
+ # loss_mask length must match; truncate/pad if server tokenized
+ # slightly differently (common for multimodal ±1 token drift).
+ seq_len_hs = len(authoritative_input_ids)
+ if len(stored_loss_mask) != seq_len_hs:
+ if len(stored_loss_mask) > seq_len_hs:
+ stored_loss_mask = stored_loss_mask[:seq_len_hs]
+ else:
+ stored_loss_mask = torch.cat([
+ stored_loss_mask,
+ torch.zeros(seq_len_hs - len(stored_loss_mask), dtype=torch.long),
+ ])
+
+ data = {
"hidden_states": loaded_hs["hidden_states"][:, :-1].flatten(
1
), # [seq_len, 3 * hidden_size]
- "input_ids": loaded_hs["token_ids"], # [seq_len]
+ "input_ids": authoritative_input_ids, # [seq_len]
"verifier_last_hidden_states": loaded_hs["hidden_states"][
:, -1
], # [seq_len, hidden_size]
- "loss_mask": self.data[index]["loss_mask"], # [seq_len]
+ "loss_mask": stored_loss_mask, # [seq_len]
+ "messages_json": row.get("messages_json", ""),
+ "mm_file": row.get("mm_file", ""),
+ "use_audio_in_video": bool(row.get("use_audio_in_video", 0)),
}
+ mm_path = data["mm_file"]
+ if mm_path:
+ mm = load_file(self.datapath / mm_path)
+ for key in MULTIMODAL_SIDECAR_KEYS:
+ if key in mm:
+ data[key] = mm[key]
+
+ return data
+
class SampleFileDataset(BaseDataset):
def __init__(
@@ -473,7 +800,13 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType:
batch = [create_empty_sample(hidden_size, dtype=dtype)]
collated_data = {}
+ has_mrope = any(
+ b["position_ids"].ndim == 2 for b in batch # type: ignore[index]
+ )
for key in batch[0]: # type: ignore[union-attr]
+ if key == "position_ids":
+ continue
+
# Concatenate the tensors along the seq (0th) dimension
collated_data[key] = torch.cat([b[key] for b in batch], dim=0) # type: ignore[index]
# shape: [total_seq_len, ...]
@@ -485,6 +818,40 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType:
).unsqueeze(0)
# shape: [1, max_len, ...]
+ if has_mrope:
+ # MRoPE samples carry position_ids of shape [3, seq_len] (T/H/W
+ # channels). After shift_batch they remain [3, seq_len-1].
+ # We concatenate along the seq dimension, pad to max_len, then
+ # add a batch dim. Qwen3-Omni Thinker rotary (used by EAGLE3 when
+ # the verifier carries rope_parameters.mrope_section) expects
+ # ``position_ids`` shaped ``[3, batch, seq_len]`` — channels
+ # first — NOT the HF Llama4 / Gemma3 ``[batch, 3, seq_len]``
+ # convention. Picking the wrong layout silently broadcasts the
+ # 3 channels through the attention reshape and blows up
+ # ``o_proj`` with a feature dim of ``3 * num_heads * head_dim``
+ # (e.g. 12288 instead of 4096 for Qwen3.6 draft heads).
+ position_ids = []
+ for sample in batch: # type: ignore[assignment]
+ pos = sample["position_ids"]
+ if pos.ndim == 1:
+ # Text-only sample in a mixed batch: broadcast to 3 channels
+ pos = pos.unsqueeze(0).expand(3, -1)
+ position_ids.append(pos)
+ # shape of each: [3, sample_seq_len]; cat along seq dim → [3, total_seq_len]
+ collated_positions = torch.cat(position_ids, dim=-1)
+ # Pad seq dim to max_len → [3, max_len], then add batch dim at
+ # axis 1 → [3, 1, max_len] for Qwen-Omni rotary.
+ collated_data["position_ids"] = pad_last_dim_to_length(
+ collated_positions, max_len
+ ).unsqueeze(1)
+ else:
+ # Standard 1D position_ids: [seq_len] per sample.
+ # Cat along seq dim → [total_seq_len], pad → [max_len], batch → [1, max_len]
+ collated_positions = torch.cat([b["position_ids"] for b in batch], dim=0) # type: ignore[index]
+ collated_data["position_ids"] = slice_and_pad_to_length(
+ collated_positions, max_len
+ ).unsqueeze(0)
+
# Include lengths until while they fit in max_len
# The last included length is (if necessary) truncated
# Any additional lengths are discarded
diff --git a/src/speculators/train/vocab_mapping.py b/src/speculators/train/vocab_mapping.py
index 90351c53a..1d66f35be 100644
--- a/src/speculators/train/vocab_mapping.py
+++ b/src/speculators/train/vocab_mapping.py
@@ -111,7 +111,9 @@ def get_target_vocab_size(target_vocab_size, target_model_path):
config = AutoConfig.from_pretrained(target_model_path)
- # For multimodal models (Qwen3VL, etc.), extract text_config
+ # Multimodal verifiers may nest the text backbone.
+ if hasattr(config, "thinker_config"):
+ config = config.thinker_config
if hasattr(config, "text_config"):
config = config.text_config
diff --git a/tests/unit/models/test_eagle3_rotary_partial.py b/tests/unit/models/test_eagle3_rotary_partial.py
new file mode 100644
index 000000000..d51c0e21b
--- /dev/null
+++ b/tests/unit/models/test_eagle3_rotary_partial.py
@@ -0,0 +1,197 @@
+"""Bit-equivalence test for the partial-neox rotary monkey-patch.
+
+This test pins the contract that motivates the patch: when
+``cos.shape[-1] == rotary_dim < head_dim``, the patched
+``apply_rotary_pos_emb`` rotates the **same channel pairs** that vLLM's
+``MRotaryEmbedding`` rotates at inference time — i.e.
+
+ rotated channels: [0, rotary_dim/2) paired with [rotary_dim/2, rotary_dim)
+ pass-through: [rotary_dim, head_dim)
+
+We verify two properties on a hand-rolled vLLM-equivalent reference:
+
+1. **Full-rotation parity** (``rotary_dim == head_dim``): the patched
+ helper is byte-equivalent to HF's original ``apply_rotary_pos_emb``,
+ so DFlash / plain Llama drafters are unaffected.
+
+2. **Partial-rotation parity** (``rotary_dim < head_dim``, e.g.
+ Qwen3.6's ``head_dim=256``, ``partial_rotary_factor=0.25``): the
+ patched helper matches a hand-coded vLLM neox-partial reference to
+ well below fp32 round-off (``< 1e-5`` max-abs-diff), and the
+ pass-through tail is preserved exactly (``allclose`` with
+ ``atol=0``).
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+
+try:
+ import pytest
+except ImportError: # pragma: no cover - environments without pytest
+ pytest = None # type: ignore[assignment]
+
+
+def _vllm_neox_partial_reference(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ unsqueeze_dim: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Hand-coded vLLM neox-partial RoPE — matches MRotaryEmbedding._forward.
+
+ Independent of the implementation under test; written from the spec
+ (see ``vllm/model_executor/layers/rotary_embedding/__init__.py``).
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ rotary_dim = cos.shape[-1]
+
+ def _rot(x: torch.Tensor) -> torch.Tensor:
+ x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+ half = rotary_dim // 2
+ x1, x2 = x_rot[..., :half], x_rot[..., half:]
+ # neox rotate_half on the rotated slice
+ rotated = torch.cat((-x2, x1), dim=-1)
+ out_rot = (x_rot * cos) + (rotated * sin)
+ return torch.cat([out_rot, x_pass], dim=-1)
+
+ return _rot(q), _rot(k)
+
+
+def test_full_rotation_is_byte_equivalent_to_hf():
+ """When cos covers the full head_dim, the patched fn must match HF exactly."""
+ from speculators.models.eagle3.rotary_partial import (
+ partial_neox_apply_rotary_pos_emb,
+ )
+ from transformers.models.llama.modeling_llama import (
+ apply_rotary_pos_emb as hf_apply,
+ )
+
+ torch.manual_seed(0)
+ batch, heads, seq, head_dim = 2, 4, 7, 64
+ q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ # cos/sin over full head_dim
+ cos = torch.cos(torch.randn(batch, seq, head_dim))
+ sin = torch.sin(torch.randn(batch, seq, head_dim))
+
+ q_hf, k_hf = hf_apply(q, k, cos, sin, unsqueeze_dim=1)
+ q_p, k_p = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+
+ # Strict equality — same arithmetic, same order of operations.
+ assert torch.equal(q_hf, q_p), "patched fn drifted from HF on full rotation"
+ assert torch.equal(k_hf, k_p), "patched fn drifted from HF on full rotation"
+
+
+def test_partial_rotation_matches_vllm_reference():
+ """rotary_dim < head_dim must match vLLM's neox-partial channel layout."""
+ from speculators.models.eagle3.rotary_partial import (
+ partial_neox_apply_rotary_pos_emb,
+ )
+
+ torch.manual_seed(1)
+ # Realistic Qwen3.6 shape: head_dim=256, partial_rotary_factor=0.25
+ batch, heads, seq, head_dim = 2, 4, 5, 256
+ rotary_dim = 64 # int(256 * 0.25)
+
+ q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ cos = torch.cos(torch.randn(batch, seq, rotary_dim))
+ sin = torch.sin(torch.randn(batch, seq, rotary_dim))
+
+ q_ref, k_ref = _vllm_neox_partial_reference(q, k, cos, sin, unsqueeze_dim=1)
+ q_got, k_got = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+
+ # Same arithmetic up to algebraic associativity; allow tiny fp32 jitter.
+ assert torch.allclose(q_got, q_ref, atol=1e-6, rtol=0)
+ assert torch.allclose(k_got, k_ref, atol=1e-6, rtol=0)
+
+
+def test_partial_rotation_preserves_passthrough_channels():
+ """Channels >= rotary_dim must be untouched (atol=0)."""
+ from speculators.models.eagle3.rotary_partial import (
+ partial_neox_apply_rotary_pos_emb,
+ )
+
+ torch.manual_seed(2)
+ batch, heads, seq, head_dim = 1, 2, 3, 64
+ rotary_dim = 16
+
+ q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32)
+ cos = torch.cos(torch.randn(batch, seq, rotary_dim))
+ sin = torch.sin(torch.randn(batch, seq, rotary_dim))
+
+ q_out, k_out = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+
+ # Pass-through tail must be bit-exact.
+ assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:])
+ assert torch.equal(k_out[..., rotary_dim:], k[..., rotary_dim:])
+ # And the rotated head must NOT be equal to the input (sanity).
+ assert not torch.equal(q_out[..., :rotary_dim], q[..., :rotary_dim])
+
+
+def test_install_is_idempotent_and_byte_safe_on_full_rotation():
+ """Installing the patch must not perturb full-rotation HF callers."""
+ from speculators.models.eagle3.rotary_partial import (
+ install_partial_neox_rotary,
+ uninstall_partial_neox_rotary,
+ )
+ from transformers.models.llama import modeling_llama
+ from transformers.models.llama.modeling_llama import (
+ apply_rotary_pos_emb as hf_apply_pre,
+ )
+
+ # Snapshot HF behaviour BEFORE installing.
+ torch.manual_seed(3)
+ head_dim = 32
+ q = torch.randn(1, 2, 4, head_dim)
+ k = torch.randn(1, 2, 4, head_dim)
+ cos = torch.cos(torch.randn(1, 4, head_dim))
+ sin = torch.sin(torch.randn(1, 4, head_dim))
+ q_pre, k_pre = hf_apply_pre(q, k, cos, sin, unsqueeze_dim=1)
+
+ install_partial_neox_rotary()
+ install_partial_neox_rotary() # idempotent
+
+ # After install, the module-level symbol must be the patched one.
+ q_post, k_post = modeling_llama.apply_rotary_pos_emb(
+ q, k, cos, sin, unsqueeze_dim=1
+ )
+ assert torch.equal(q_pre, q_post)
+ assert torch.equal(k_pre, k_post)
+
+ uninstall_partial_neox_rotary()
+ uninstall_partial_neox_rotary() # idempotent
+
+ # After uninstall, the symbol must be HF's original (object identity).
+ assert (
+ modeling_llama.apply_rotary_pos_emb is hf_apply_pre
+ or modeling_llama.apply_rotary_pos_emb.__wrapped__ is hf_apply_pre
+ )
+
+
+def test_rejects_cos_longer_than_head_dim():
+ """Defensive — cos can't be larger than q's last dim."""
+ from speculators.models.eagle3.rotary_partial import (
+ partial_neox_apply_rotary_pos_emb,
+ )
+
+ q = torch.randn(1, 1, 1, 8)
+ k = torch.randn(1, 1, 1, 8)
+ cos = torch.randn(1, 1, 16) # > head_dim=8
+ sin = torch.randn(1, 1, 16)
+ if pytest is not None:
+ with pytest.raises(ValueError, match="exceeds q last dim"):
+ partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+ return
+ try:
+ partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+ except ValueError as e:
+ assert "exceeds q last dim" in str(e)
+ else: # pragma: no cover
+ raise AssertionError("expected ValueError on oversize cos")