diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 4cd927d82..fb837a3b3 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -416,7 +416,20 @@ def main(): logger.info("EAGLE Offline Data Generation") - dataset = load_from_disk(args.preprocessed_data) + # `prepare_data.py` persists dataset formatting metadata. If we load it as-is, + # only torch-formatted training columns may be exposed, hiding multimodal + # metadata columns like `messages_json`/`mm_file`. That would prevent + # `build_client_item` from forwarding `multi_modal_data` to vLLM and can + # produce degenerate image placeholder hidden states. + dataset = load_from_disk(args.preprocessed_data).with_format(None) + + columns = dataset.column_names + logger.info("Loaded dataset columns: %s", columns) + if "messages_json" not in columns and "messages" not in columns: + logger.warning( + "Dataset has no `messages_json`/`messages` column. " + "If this dataset is multimodal, vLLM requests may miss multi_modal_data." + ) try: asyncio.run(generate_and_save_hidden_states(args, dataset)) diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py index 0d310c240..0f298c308 100644 --- a/scripts/launch_vllm.py +++ b/scripts/launch_vllm.py @@ -5,6 +5,113 @@ import warnings +def unwrap_verifier_configs(config): + """Return multimodal/container config and the text backbone config.""" + multimodal_config = getattr(config, "thinker_config", config) + text_config = multimodal_config + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + return multimodal_config, text_config + + +def get_deepstack_visual_indexes(multimodal_config) -> list[int]: + """Get DeepStack layer indexes when present on the verifier.""" + vision_config = getattr(multimodal_config, "vision_config", None) + deepstack_layers = getattr(vision_config, "deepstack_visual_indexes", None) + if deepstack_layers is None: + deepstack_layers = getattr(multimodal_config, "deepstack_visual_indexes", None) + return list(deepstack_layers or []) + + +def deduplicate_layer_ids(layer_ids: list[int]) -> list[int]: + """Deduplicate layer ids while preserving user-specified order.""" + seen: set[int] = set() + result: list[int] = [] + for layer_id in layer_ids: + if layer_id not in seen: + seen.add(layer_id) + result.append(layer_id) + return result + + +def validate_layer_ids( + layer_ids: list[int], num_hidden_layers: int, option_name: str +) -> list[int]: + """Validate vLLM post-layer ids and preserve the effective order.""" + validated = deduplicate_layer_ids(layer_ids) + invalid = [ + layer_id + for layer_id in validated + if layer_id < 0 or layer_id > num_hidden_layers + ] + if invalid: + raise ValueError( + f"{option_name} contains invalid layer ids {invalid}. " + f"Expected ids in [0, {num_hidden_layers}], where 0 is the " + "embedding output and num_hidden_layers is the final decoder output." + ) + return validated + + +def get_default_target_layer_ids(multimodal_config, num_hidden_layers: int) -> list[int]: + """Return default auxiliary layer ids used by training.""" + deepstack_layers = set(get_deepstack_visual_indexes(multimodal_config)) + candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3] + return [ + layer_id - 1 if layer_id in deepstack_layers else layer_id + for layer_id in candidate_layer_ids + ] + + +def resolve_layer_ids(args, multimodal_config, num_hidden_layers: int): + """Resolve training layer ids and exact vLLM extraction layer ids. + + vLLM's extract_hidden_states path stores exactly the configured + eagle_aux_hidden_state_layer_ids. DFlash data loading consumes all but the + last stored slice as training auxiliary hidden states and treats the last + slice as the verifier/final reference state. Therefore, when + --include-last-layer is enabled, the final layer is forced to the end of the + extraction list but is not reported as a training target layer id. + """ + if args.target_layer_ids: + cli_layer_ids = validate_layer_ids( + list(args.target_layer_ids), num_hidden_layers, "--target-layer-ids" + ) + # If users pass the final layer explicitly, keep DFlash layout correct by + # moving it to the last extracted slice. This avoids a costly runtime + # tensor reorder and preserves the loader's [:, :-1] / [:, -1] split. + target_layer_ids = [ + layer_id for layer_id in cli_layer_ids if layer_id != num_hidden_layers + ] + if not target_layer_ids and args.include_last_layer: + raise ValueError( + "--target-layer-ids must contain at least one non-final auxiliary " + "layer when --include-last-layer is enabled." + ) + source = "custom" + else: + target_layer_ids = validate_layer_ids( + get_default_target_layer_ids(multimodal_config, num_hidden_layers), + num_hidden_layers, + "default target layer ids", + ) + source = "default" + + extraction_layer_ids = list(target_layer_ids) + if args.include_last_layer: + extraction_layer_ids.append(num_hidden_layers) + extraction_layer_ids = validate_layer_ids( + extraction_layer_ids, num_hidden_layers, "resolved extraction layer ids" + ) + if not extraction_layer_ids: + raise ValueError( + "At least one vLLM extraction layer id must be selected. Pass one or " + "more --target-layer-ids values or keep --include-last-layer enabled." + ) + + return target_layer_ids, extraction_layer_ids, source + + def parse_args(): parser = argparse.ArgumentParser( description="Launch vLLM for hidden states extraction", @@ -24,12 +131,17 @@ def parse_args(): ) parser.add_argument( "--target-layer-ids", + "--layer-ids", + dest="target_layer_ids", type=int, nargs="+", help=( - "(Optional) A (space separated) list of integer layer ids. Defaults to " - "[2, num_hidden_layers // 2, num_hidden_layers - 3]. " - "Note: if set, you must also pass the same value into the training process" + "Auxiliary post-layer ids to extract for training. Alias: --layer-ids. " + "vLLM layer ids are in [0, num_hidden_layers], where 0 is the " + "embedding output and num_hidden_layers is the final decoder output. " + "When --include-last-layer is enabled, num_hidden_layers is appended " + "to the vLLM extraction ids and should not be passed to training. " + "Defaults to [2, num_hidden_layers // 2, num_hidden_layers - 3]." ), ) parser.add_argument( @@ -37,8 +149,9 @@ def parse_args(): action=argparse.BooleanOptionalAction, default=True, help=( - "Append the last layer (num_hidden_layers) to " - "target_layer_ids for verifier hidden states extraction. Default: True" + "For DFlash models, append the last layer (num_hidden_layers) to the " + "vLLM extraction ids as the final verifier/reference slice. " + "Default: True" ), ) parser.add_argument( @@ -56,34 +169,58 @@ def main(): from transformers import AutoConfig # noqa: PLC0415 - config = AutoConfig.from_pretrained(args.model) - if hasattr(config, "text_config"): - config = config.text_config + raw_config = AutoConfig.from_pretrained(args.model) + multimodal_config, config = unwrap_verifier_configs(raw_config) num_hidden_layers = config.num_hidden_layers if args.target_layer_ids: - target_layer_ids = args.target_layer_ids - if args.include_last_layer and num_hidden_layers not in target_layer_ids: - target_layer_ids.append(num_hidden_layers) + training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids( + args, multimodal_config, num_hidden_layers + ) warnings.warn( - f"Using custom target layer ids {target_layer_ids}. These " - "must also be explicitly passed into the training script.", + "Using custom target layer ids. Pass " + f"{training_target_layer_ids} to the training script; vLLM will " + f"extract {extraction_layer_ids}.", stacklevel=2, ) else: - target_layer_ids = [ - 2, - num_hidden_layers // 2, - num_hidden_layers - 3, - num_hidden_layers, - ] + training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids( + args, multimodal_config, num_hidden_layers + ) + + print( + "Layer ids: " + f"source={layer_id_source}, training_target_layer_ids=" + f"{training_target_layer_ids}, extraction_layer_ids={extraction_layer_ids}" + ) + + # Build overrides for ExtractHiddenStatesConfig. + # For nested multimodal configs, promote text-backbone fields to top level + # so vLLM can resolve a valid text config. + hf_config_overrides: dict = {"eagle_aux_hidden_state_layer_ids": extraction_layer_ids} + if config is not raw_config: + # Promote nested text-backbone fields; drop conflicting wrapper fields. + _text_cfg_dict = config.to_dict() + for _k in ("architectures", "model_type", "auto_map", "torch_dtype"): + _text_cfg_dict.pop(_k, None) + + # Clear nested selector attrs so HF falls back to promoted top-level + # text fields instead of stale dict payloads. + for _nested_text_attr in ( + "text_config", + "text_encoder", + "decoder", + "generator", + ): + _text_cfg_dict[_nested_text_attr] = None + + # kwargs override model_dict, exposing text attrs at top level. + hf_config_overrides = {**_text_cfg_dict, **hf_config_overrides} speculative_config = { "method": "extract_hidden_states", "num_speculative_tokens": 1, - "draft_model_config": { - "hf_config": {"eagle_aux_hidden_state_layer_ids": target_layer_ids} - }, + "draft_model_config": {"hf_config": hf_config_overrides}, } kv_transfer_config = { "kv_connector": "ExampleHiddenStatesConnector", diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 5c4624103..6eb433cba 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -190,6 +190,7 @@ def main(): turn_dropout=args.turn_dropout, minimum_valid_tokens=args.minimum_valid_tokens, trust_remote_code=args.trust_remote_code, + multimodal_output_dir=output, ) log.info("Done preparing data") diff --git a/scripts/train.py b/scripts/train.py index a1e4e6de5..573e7a162 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,4 +1,5 @@ import argparse +import json import logging import random import warnings @@ -65,6 +66,36 @@ def set_seed(seed: int, deterministic: bool = False): torch.backends.cudnn.benchmark = False +def unwrap_verifier_text_config(verifier_config: PretrainedConfig) -> PretrainedConfig: + """Unwrap multimodal verifier configs to their text backbone config.""" + if hasattr(verifier_config, "thinker_config"): + verifier_config = verifier_config.thinker_config + if hasattr(verifier_config, "text_config"): + verifier_config = verifier_config.text_config + return verifier_config + + +def get_default_draft_intermediate_size(verifier_config: PretrainedConfig) -> int: + """Infer a dense drafter FFN width from dense or MoE verifier configs.""" + moe_ffn = getattr(verifier_config, "moe_intermediate_size", None) + if moe_ffn is not None: + experts_per_tok = int(getattr(verifier_config, "num_experts_per_tok", 1)) + active = int(moe_ffn) * max(experts_per_tok, 1) + shared_ffn = getattr(verifier_config, "shared_expert_intermediate_size", None) + if shared_ffn is not None: + active += int(shared_ffn) + return active + + dense_ffn = getattr(verifier_config, "intermediate_size", None) + if dense_ffn is not None: + return int(dense_ffn) + + raise AttributeError( + f"Cannot infer draft intermediate_size from {type(verifier_config).__name__}; " + "pass --draft-intermediate-size explicitly." + ) + + def setup_dataloader( dataset: BaseDataset, world_size: int, @@ -112,8 +143,18 @@ def create_transformer_layer_config( # noqa: C901 num_layers: int, draft_arch: str, hidden_act: str | None, + # --- DFlash Hybrid Attention sliding window support --- sliding_window: int, sliding_window_indices: list[int], + # --- Qwen3.6 Hybrid Attention support --- + draft_intermediate_size: int | None = None, + draft_num_attention_heads: int | None = None, + draft_num_key_value_heads: int | None = None, + draft_head_dim: int | None = None, + draft_rope_scaling: dict | None = None, + draft_rope_theta: float | None = None, + draft_max_position_embeddings: int | None = None, + mrope_full_head_hack: bool = True, ) -> PretrainedConfig: if draft_arch not in DRAFT_ARCH_CONFIGS: raise ValueError( @@ -130,11 +171,9 @@ def create_transformer_layer_config( # noqa: C901 ) config_class = DRAFT_ARCH_CONFIGS[draft_arch] - verifier_config = AutoConfig.from_pretrained(verifier_name_or_path) - - # For multimodal models (Qwen3VL, etc.), extract text_config - if hasattr(verifier_config, "text_config"): - verifier_config = verifier_config.text_config + verifier_config = unwrap_verifier_text_config( + AutoConfig.from_pretrained(verifier_name_or_path, trust_remote_code=True) + ) hidden_act = ( hidden_act @@ -147,56 +186,142 @@ def create_transformer_layer_config( # noqa: C901 "nor 'hidden_activation'" ) - head_dim = getattr(verifier_config, "head_dim", None) - num_attention_heads = verifier_config.num_attention_heads - num_key_value_heads = verifier_config.num_key_value_heads + if draft_intermediate_size is None: + draft_intermediate_size = get_default_draft_intermediate_size(verifier_config) - if ( - head_dim - and verifier_config.hidden_size % num_attention_heads != 0 - and verifier_config.hidden_size % head_dim == 0 - ): - num_attention_heads = verifier_config.hidden_size // head_dim - if num_attention_heads % num_key_value_heads != 0: - num_key_value_heads = num_attention_heads + n_heads = draft_num_attention_heads or verifier_config.num_attention_heads + n_kv = draft_num_key_value_heads or verifier_config.num_key_value_heads + hd = draft_head_dim or getattr(verifier_config, "head_dim", None) + hidden_size = verifier_config.hidden_size + resolved_head_dim = hd or hidden_size // n_heads + if n_heads % n_kv != 0: + raise ValueError( + f"Invalid GQA ratio: num_attention_heads({n_heads}) must be divisible " + f"by num_key_value_heads({n_kv})." + ) + if resolved_head_dim <= 0: + raise ValueError(f"Invalid head_dim({resolved_head_dim}); must be positive.") + + rope_kwargs: dict = {} + verifier_rope_theta = getattr(verifier_config, "rope_theta", None) + if verifier_rope_theta is None: + for src_name in ("rope_parameters", "rope_scaling"): + src = getattr(verifier_config, src_name, None) + if isinstance(src, dict) and src.get("rope_theta") is not None: + verifier_rope_theta = src["rope_theta"] + logger.info( + "Drafter rope_theta recovered from verifier " + f"{src_name}.rope_theta = {verifier_rope_theta}." + ) + break + if verifier_rope_theta is not None: + rope_kwargs["rope_theta"] = verifier_rope_theta + + verifier_rope_scaling = getattr(verifier_config, "rope_scaling", None) + verifier_rope_parameters = getattr(verifier_config, "rope_parameters", None) + if verifier_rope_scaling is not None: + rope_kwargs["rope_scaling"] = dict(verifier_rope_scaling) + elif isinstance(verifier_rope_parameters, dict): + rope_kwargs["rope_scaling"] = dict(verifier_rope_parameters) + + if draft_rope_scaling is not None: + cli_rope_scaling = dict(draft_rope_scaling) + cli_rope_theta = cli_rope_scaling.pop("rope_theta", None) + if cli_rope_theta is not None: + rope_kwargs["rope_theta"] = cli_rope_theta + rope_kwargs["rope_scaling"] = cli_rope_scaling + logger.info(f"Drafter rope_scaling overridden via CLI: {cli_rope_scaling}") + + if draft_rope_theta is not None: + rope_kwargs["rope_theta"] = float(draft_rope_theta) + logger.info( + "Drafter rope_theta overridden via --draft-rope-theta: " + f"{rope_kwargs['rope_theta']}" + ) + + rope_scaling_dict = rope_kwargs.get("rope_scaling") + if isinstance(rope_scaling_dict, dict) and "mrope_section" in rope_scaling_dict: + if draft_rope_theta is not None: + rope_scaling_dict["rope_theta"] = float(draft_rope_theta) + elif "rope_theta" in rope_kwargs and "rope_theta" not in rope_scaling_dict: + rope_scaling_dict["rope_theta"] = rope_kwargs["rope_theta"] + + inherited_partial = float(rope_scaling_dict.get("partial_rotary_factor", 1.0)) + if mrope_full_head_hack and inherited_partial < 1.0: + old_section = list(rope_scaling_dict["mrope_section"]) + inv = 1.0 / inherited_partial + if abs(inv - round(inv)) > 1e-6: + raise ValueError( + "mrope_full_head_hack cannot rescale mrope_section because " + f"1/partial_rotary_factor={inv} is not an integer." + ) + scale = int(round(inv)) + new_section = [int(x) * scale for x in old_section] + if 2 * sum(new_section) != resolved_head_dim: + raise ValueError( + "mrope_full_head_hack rescaling produced inconsistent " + f"mrope_section {new_section}: 2*sum={2 * sum(new_section)} " + f"but head_dim={resolved_head_dim}." + ) + rope_scaling_dict["mrope_section"] = new_section + rope_scaling_dict["partial_rotary_factor"] = 1.0 + logger.warning( + "MRoPE full-head hack applied: partial_rotary_factor " + f"{inherited_partial} -> 1.0, mrope_section {old_section} -> " + f"{new_section}." + ) + elif not mrope_full_head_hack and inherited_partial < 1.0: + logger.warning( + "mrope_full_head_hack=False with partial_rotary_factor=" + f"{inherited_partial} < 1.0 can cause HF trainer / vLLM " + "partial-rotation mismatch." + ) + + max_pos = ( + int(draft_max_position_embeddings) + if draft_max_position_embeddings is not None + else verifier_config.max_position_embeddings + ) + # --- DFlash sliding window layer_types --- if sliding_window_indices and ( min(sliding_window_indices) < 0 or max(sliding_window_indices) >= num_layers ): raise ValueError( - "Sliding window indices must be validate draft layer ids " + "Sliding window indices must be valid draft layer ids " "in range [0, num_layers)." ) layer_types = [ - "sliding_attention" if i in sliding_window_indices else "full_attention" + "sliding_attention" if i in (sliding_window_indices or []) else "full_attention" for i in range(num_layers) ] config = config_class( vocab_size=verifier_config.vocab_size, - hidden_size=verifier_config.hidden_size, - intermediate_size=verifier_config.intermediate_size, + hidden_size=hidden_size, + intermediate_size=draft_intermediate_size, num_hidden_layers=num_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, + num_attention_heads=n_heads, + num_key_value_heads=n_kv, hidden_act=hidden_act, - max_position_embeddings=verifier_config.max_position_embeddings, + max_position_embeddings=max_pos, initializer_range=verifier_config.initializer_range, rms_norm_eps=verifier_config.rms_norm_eps, - head_dim=head_dim, + head_dim=resolved_head_dim, tie_word_embeddings=False, sliding_window=sliding_window, layer_types=layer_types, ) - # New rope parameters definition introduced in transformers 5.0 - if version.parse(transformers.__version__) >= version.parse("5.0.0"): - if hasattr(verifier_config, "rope_parameters"): - config.rope_parameters = deepcopy(verifier_config.rope_parameters) - else: - if hasattr(verifier_config, "rope_scaling"): - config.rope_scaling = deepcopy(verifier_config.rope_scaling) - config.rope_theta = getattr(verifier_config, "rope_theta", 10000.0) + # RoPE: use transformers >= 5.0 rope_parameters path + if hasattr(verifier_config, "rope_parameters"): + config.rope_parameters = deepcopy(verifier_config.rope_parameters) + # Apply CLI overrides into rope_parameters + if rope_kwargs.get("rope_scaling"): + config.rope_parameters = deepcopy(rope_kwargs["rope_scaling"]) + if rope_kwargs.get("rope_theta") is not None: + if isinstance(config.rope_parameters, dict): + config.rope_parameters["rope_theta"] = rope_kwargs["rope_theta"] return config @@ -267,9 +392,9 @@ def parse_vocab_mappings(args: argparse.Namespace): "None. Using full verifier vocab" ) # When vocab mapping is not provided, use the full verifier vocab - verifier_config = AutoConfig.from_pretrained(args.verifier_name_or_path) - if hasattr(verifier_config, "text_config"): - verifier_config = verifier_config.text_config + verifier_config = unwrap_verifier_text_config( + AutoConfig.from_pretrained(args.verifier_name_or_path, trust_remote_code=True) + ) return None, None, verifier_config.vocab_size @@ -305,6 +430,14 @@ def main(args: argparse.Namespace): num_layers=args.num_layers, draft_arch=args.draft_arch, hidden_act=args.draft_hidden_act, + draft_intermediate_size=args.draft_intermediate_size, + draft_num_attention_heads=args.draft_num_attention_heads, + draft_num_key_value_heads=args.draft_num_key_value_heads, + draft_head_dim=args.draft_head_dim, + draft_rope_scaling=args.draft_rope_scaling, + draft_rope_theta=args.draft_rope_theta, + draft_max_position_embeddings=args.draft_max_position_embeddings, + mrope_full_head_hack=args.draft_mrope_full_head_hack, sliding_window=args.sliding_window, sliding_window_indices=args.sliding_window_indices, ) @@ -371,6 +504,7 @@ def main(args: argparse.Namespace): transform=noise_transform, split_ratio=0.9, model=args.verifier_name_or_path, + verifier_name_or_path=args.verifier_name_or_path, hidden_states_dtype=hidden_states_dtype, request_timeout=args.request_timeout, max_retries=args.max_retries, @@ -384,6 +518,7 @@ def main(args: argparse.Namespace): on_generate=args.on_generate, split_ratio=-0.1, model=args.verifier_name_or_path, + verifier_name_or_path=args.verifier_name_or_path, hidden_states_dtype=hidden_states_dtype, request_timeout=args.request_timeout, max_retries=args.max_retries, @@ -591,6 +726,58 @@ def parse_args(): "vLLM deployment. If another function is desired, set as a string or leave " "as None to automatically fall back to the verifier's activation function.", ) + parser.add_argument( + "--draft-intermediate-size", + type=int, + default=None, + help="Override draft FFN intermediate size; useful for MoE verifiers.", + ) + parser.add_argument( + "--draft-num-attention-heads", + type=int, + default=None, + help="Override the drafter's num_attention_heads.", + ) + parser.add_argument( + "--draft-num-key-value-heads", + type=int, + default=None, + help="Override the drafter's num_key_value_heads.", + ) + parser.add_argument( + "--draft-head-dim", + type=int, + default=None, + help="Override the drafter's per-head hidden dimension.", + ) + parser.add_argument( + "--draft-rope-scaling", + type=lambda s: json.loads(s) if s else None, + default=None, + help="JSON RoPE scaling dict to apply to the drafter during training.", + ) + parser.add_argument( + "--draft-rope-theta", + type=float, + default=None, + help="Override the drafter's RoPE frequency base.", + ) + parser.add_argument( + "--draft-max-position-embeddings", + type=int, + default=None, + help="Override the drafter's max_position_embeddings.", + ) + parser.add_argument( + "--draft-mrope-full-head-hack", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For MRoPE configs with partial_rotary_factor < 1, rescale " + "mrope_section and set partial_rotary_factor=1.0 so HF training " + "and vLLM inference use equivalent full-head rotary semantics." + ), + ) parser.add_argument( "--target-layer-ids", type=int, diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 770724cbe..ba51259a3 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -1,17 +1,21 @@ import bisect +import inspect +import json import random import re from collections.abc import Callable from contextlib import nullcontext from pathlib import Path from re import Pattern -from typing import cast +from typing import Any, cast import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset from packaging.version import Version +from safetensors.torch import save_file from transformers import ( + AutoConfig, AutoProcessor, BatchEncoding, BatchFeature, @@ -35,6 +39,18 @@ ProcessorLike = PreTrainedTokenizerBase | ProcessorMixin +MULTIMODAL_SIDECAR_DIR = "multimodal" +MULTIMODAL_MEDIA_TYPES = {"image", "video", "audio"} +MULTIMODAL_ENCODER_KEYS = ( + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grids", + "input_features", + "feature_attention_mask", + "audio_feature_lengths", +) def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: int = 0): @@ -73,6 +89,196 @@ def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: in log.info(highlighted) +def _normalize_media_value(value: Any) -> Any: + if isinstance(value, Path): + return str(value) + filename = getattr(value, "filename", None) + if filename: + return str(filename) + return value + + +def _to_json_compatible(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_json_compatible(v) for k, v in value.items()} + if isinstance(value, list): + return [_to_json_compatible(v) for v in value] + if isinstance(value, tuple): + return [_to_json_compatible(v) for v in value] + if isinstance(value, Path): + return str(value) + filename = getattr(value, "filename", None) + if filename: + return str(filename) + return value + + +def _normalize_content_segment(segment: Any) -> dict[str, Any]: + if isinstance(segment, str): + return {"type": "text", "text": segment} + if not isinstance(segment, dict): + return {"type": "text", "text": str(segment)} + + seg_type = str(segment.get("type", "text")) + normalized = {"type": seg_type} + if seg_type == "text": + normalized["text"] = ( + segment.get("text") + or segment.get("value") + or segment.get("content") + or "" + ) + elif seg_type in MULTIMODAL_MEDIA_TYPES: + normalized[seg_type] = _normalize_media_value( + segment.get(seg_type) + or segment.get("value") + or segment.get("url") + or segment.get("path") + or segment.get("source") + ) + else: + normalized["text"] = segment.get("text") or segment.get("value") or "" + + for key, value in segment.items(): + if key in normalized or key in {"content", "value", "url", "path", "source"}: + continue + normalized[key] = _to_json_compatible(value) + return normalized + + +def _normalize_turn_content(content: Any) -> str | list[dict[str, Any]]: + if isinstance(content, list): + return [_normalize_content_segment(seg) for seg in content] + return content if isinstance(content, str) else str(content or "") + + +def _has_multimodal_segments(content: Any) -> bool: + return isinstance(content, list) and any( + isinstance(seg, dict) and seg.get("type") in MULTIMODAL_MEDIA_TYPES + for seg in content + ) + + +def _is_multimodal_conversation(conv: list[dict[str, Any]]) -> bool: + return any(_has_multimodal_segments(turn.get("content")) for turn in conv) + + +def _serialize_messages(messages: list[dict[str, Any]]) -> str: + return json.dumps(_to_json_compatible(messages), ensure_ascii=False) + + +def _sanitize_sidecar_prefix(prefix: str) -> str: + return re.sub(r"[^a-zA-Z0-9_.-]+", "_", prefix).strip("_") or "dataset" + + +def _build_sidecar_path( + multimodal_output_dir: str | Path, + sample_idx: int, + sidecar_prefix: str, +) -> Path: + base_dir = Path(multimodal_output_dir) / MULTIMODAL_SIDECAR_DIR + base_dir.mkdir(parents=True, exist_ok=True) + safe_prefix = _sanitize_sidecar_prefix(sidecar_prefix) + return base_dir / f"{safe_prefix}_{sample_idx}.safetensors" + + +def _maybe_strip_batch_dim(value: Any) -> torch.Tensor: + tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value) + tensor = tensor.detach().cpu().contiguous() + if tensor.ndim > 0 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + return tensor + + +def _save_multimodal_sidecar( + encoded: dict[str, Any], + multimodal_output_dir: str | Path, + sample_idx: int, + sidecar_prefix: str, +) -> str: + sidecar_path = _build_sidecar_path(multimodal_output_dir, sample_idx, sidecar_prefix) + payload = { + key: _maybe_strip_batch_dim(encoded[key]) + for key in MULTIMODAL_ENCODER_KEYS + if key in encoded + } + if not payload: + raise ValueError("Multimodal sample did not produce sidecar tensor fields") + save_file(payload, sidecar_path) + return str(sidecar_path.relative_to(Path(multimodal_output_dir))) + + +def _build_multimodal_loss_mask( + input_ids: torch.Tensor, + base_loss_mask: torch.Tensor, + placeholder_token_ids: tuple[int, ...], +) -> torch.Tensor: + loss_mask = base_loss_mask.to(dtype=torch.long).clone() + valid_ids = [tid for tid in placeholder_token_ids if tid >= 0] + if not valid_ids: + return loss_mask + placeholder_tensor = torch.as_tensor( + valid_ids, dtype=input_ids.dtype, device=input_ids.device + ) + loss_mask.masked_fill_(torch.isin(input_ids, placeholder_tensor), 0) + return loss_mask + + +def _mask_has_positive(mask: Any) -> bool: + mask_tensor = _maybe_strip_batch_dim(mask) + return bool(mask_tensor.numel() > 0 and torch.count_nonzero(mask_tensor).item() > 0) + + +_PROCESSOR_KW_CACHE: dict[int, set[str]] = {} + + +def _processor_kwargs(processor: Any) -> set[str]: + key = id(processor) + cached = _PROCESSOR_KW_CACHE.get(key) + if cached is not None: + return cached + try: + sig = inspect.signature(processor.apply_chat_template) + names = { + name + for name, param in sig.parameters.items() + if param.kind is not inspect.Parameter.VAR_KEYWORD + } + except (TypeError, ValueError): + names = set() + _PROCESSOR_KW_CACHE[key] = names + return names + + +def _conversation_use_audio_in_video(conv: list[dict[str, Any]]) -> bool: + for turn in conv: + content = turn.get("content") + if not isinstance(content, list): + continue + for seg in content: + if isinstance(seg, dict) and seg.get("use_audio_in_video"): + return True + return False + + +def _as_processor_content_blocks(content: Any) -> list[dict[str, Any]]: + if isinstance(content, list): + return [ + seg if isinstance(seg, dict) else {"type": "text", "text": str(seg)} + for seg in content + ] + return [ + {"type": "text", "text": content if isinstance(content, str) else str(content or "")} + ] + + +def _conversation_for_processor(conv: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + {**turn, "content": _as_processor_content_blocks(turn.get("content", ""))} + for turn in conv + ] + + def _normalize_conversation( conv: list[dict], turn_dropout: bool = False, @@ -92,7 +298,7 @@ def _normalize_conversation( normalized = [] for i, turn in enumerate(conv): role = turn.get("from", turn.get("role", "")) - content = turn.get("value") or turn.get("content") or "" + content = _normalize_turn_content(turn.get("value") or turn.get("content") or "") # Map various role names to standard user/assistant if role in ("human", "user"): @@ -165,10 +371,12 @@ def _adapt_part_for_vllm(part: str | dict): for modality in ("image", "video", "audio"): if part_type == modality: - if local_path := part.get("path"): - file_url = f"file://{Path(local_path).absolute()}" - return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} - if url := part.get("url"): + media_value = part.get(modality) or part.get("path") or part.get("url") + if isinstance(media_value, str): + if media_value.startswith(("http://", "https://", "file://")): + url = media_value + else: + url = f"file://{Path(media_value).absolute()}" return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} if part.get("base64"): @@ -178,13 +386,6 @@ def _adapt_part_for_vllm(part: str | dict): f"the {modality} when saving the preprocessed dataset, " f"please express {modality} inputs using file paths or URLs." ) - if part.get(modality): - expr = {"type": modality, modality: "..."} - raise ValueError( - f"Content part {expr} is not supported. To avoid copying " - f"the {modality} when saving the preprocessed dataset, " - f"please express {modality} inputs using file paths or URLs." - ) expr = {"type": modality} | {k: "..." for k in part if k != "type"} raise NotImplementedError(f"Unknown content part: {expr}") @@ -234,7 +435,7 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: return False # Verify the mask is not all zeros - return any(m == 1 for m in mask) + return _mask_has_positive(mask) except (TypeError, ValueError, KeyError, AttributeError) as e: log.warning(f"An error occurred when trying to return assistant mask: {e}") return False @@ -289,11 +490,8 @@ def _detect_assistant_pattern(processor: ProcessorLike) -> str: else: role_marker = prefix - # Strip ... blocks from the role marker. Thinking model - # templates wrap assistant content in these tags, but the test messages - # can produce empty blocks (e.g. "\n\n\n") with reasoning models, - # which then get baked into the regex as literals. Removing them ensures - # that reasoning stays within the assistant content group. + # Remove optional ... wrappers from role markers so + # regex capture focuses on assistant content boundaries. role_marker = re.sub(r".*?\s*", "", role_marker, flags=re.DOTALL) # Determine the stable TURN-LEVEL suffix @@ -382,6 +580,104 @@ def _create_loss_mask_from_offsets( return loss_mask +def _content_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + return "".join( + str(seg.get("text", "")) if isinstance(seg, dict) else str(seg) + for seg in content + ) + return str(content or "") + + +def _find_token_subsequence( + haystack: torch.Tensor, + needle: torch.Tensor, + start: int, +) -> int | None: + if needle.numel() == 0 or haystack.numel() < needle.numel(): + return None + for idx in range(start, int(haystack.numel() - needle.numel()) + 1): + if torch.equal(haystack[idx : idx + needle.numel()], needle): + return idx + return None + + +def _loss_mask_from_assistant_token_spans( + input_ids: torch.Tensor, + normalized_conv: list[dict[str, Any]], + tokenizer: PreTrainedTokenizerBase, +) -> torch.Tensor | None: + loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + cursor = 0 + matches_found = 0 + for turn in normalized_conv: + if turn.get("role") != "assistant": + continue + text = _content_text(turn.get("content", "")) + if not text: + continue + tokenized = tokenizer(text, add_special_tokens=False) + token_ids = tokenized.get("input_ids", []) + if not token_ids: + continue + needle = torch.as_tensor(token_ids, dtype=torch.long, device=input_ids.device) + span_start = _find_token_subsequence(input_ids, needle, cursor) + if span_start is None: + log.warning("Could not align assistant content tokens in processor input_ids") + continue + span_end = span_start + int(needle.numel()) + loss_mask[span_start:span_end] = 1 + cursor = span_end + matches_found += 1 + return loss_mask if matches_found else None + + +def _loss_mask_from_ids_fallback( + input_ids: torch.Tensor, + normalized_conv: list[dict[str, Any]], + tokenizer: PreTrainedTokenizerBase, + assistant_pattern: str | Pattern[str], + placeholder_token_ids: tuple[int, ...] = (), +) -> torch.Tensor: + formatted_raw_any = tokenizer.apply_chat_template( + normalized_conv, + tokenize=False, + add_generation_prompt=False, + ) + formatted_raw = formatted_raw_any if isinstance(formatted_raw_any, str) else "" + encoding = tokenizer( + formatted_raw, + return_offsets_mapping=True, + add_special_tokens=False, + ) + mask_text = _create_loss_mask_from_offsets( + formatted_raw, encoding["offset_mapping"], assistant_pattern + ).to(torch.long) + + target_len = int(input_ids.shape[0]) + if placeholder_token_ids: + placeholder_tensor = torch.as_tensor( + placeholder_token_ids, dtype=input_ids.dtype, device=input_ids.device + ) + is_placeholder = torch.isin(input_ids, placeholder_tensor) + if bool(is_placeholder.any()): + aligned = torch.zeros(target_len, dtype=torch.long) + text_positions = (~is_placeholder).nonzero(as_tuple=True)[0] + copy_len = min(int(text_positions.shape[0]), int(mask_text.shape[0])) + if copy_len > 0: + aligned.index_copy_(0, text_positions[:copy_len], mask_text[:copy_len]) + return aligned + + if mask_text.shape[0] == target_len: + return mask_text + if mask_text.shape[0] > target_len: + return mask_text[:target_len] + pad = torch.zeros(target_len - mask_text.shape[0], dtype=torch.long) + return torch.cat([mask_text, pad], dim=0) + + def _get_input_ids_loss_mask( normalized_conv: list[dict], processor: ProcessorLike, @@ -485,63 +781,170 @@ def _preprocess_batch( assistant_pattern: str | Pattern[str] | None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + *, + indices: list[int] | None = None, + placeholder_token_ids: tuple[int, ...] = (), + multimodal_output_dir: str | Path | None = None, + sidecar_prefix: str = "dataset", ) -> dict[str, list]: """Process a batch of conversations into tokenized format with loss masks.""" - results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} - conversations: list[dict] = examples.get("conversations", []) - - # MM inputs must use Chat Completions API - if isinstance(processor, ProcessorMixin): + results: dict[str, list] = { + "input_ids": [], + "loss_mask": [], + "seq_len": [], + "messages_json": [], + "mm_file": [], + "use_audio_in_video": [], + } + include_messages = isinstance(processor, ProcessorMixin) + if include_messages: results["messages"] = [] + conversations: list[dict] = examples.get("conversations", []) if not conversations: log.warning(f"No conversations key found. Keys: {list(examples.keys())}") return results for idx, conv in enumerate(conversations): + sample_idx = indices[idx] if indices is not None else idx if not conv or not isinstance(conv, list): + log.warning( + f"[DROP sample_idx={sample_idx}] reason=empty_or_non_list_conversation " + f"type={type(conv).__name__}" + ) continue - # Normalize to standard format with optional turn dropout normalized_conv = _normalize_conversation(conv, turn_dropout) if not normalized_conv: + log.warning( + f"[DROP sample_idx={sample_idx}] reason=normalized_conversation_empty " + f"raw_turns={len(conv)}" + ) continue + is_multimodal = include_messages and _is_multimodal_conversation(normalized_conv) + messages = _adapt_conv_for_vllm(normalized_conv) if include_messages else [] + messages_json = _serialize_messages(messages) if is_multimodal else "" + mm_file = "" + use_audio_in_video = int(_conversation_use_audio_in_video(normalized_conv)) + try: - input_ids, loss_mask = _get_input_ids_loss_mask( - normalized_conv, - processor, - max_length=max_length, - assistant_pattern=assistant_pattern, - conv_idx=idx, + if is_multimodal: + allowed = _processor_kwargs(processor) + call_kwargs: dict[str, Any] = dict( + tokenize=True, + add_generation_prompt=False, + return_dict=True, + return_tensors="pt", + processor_kwargs={}, + ) + for key in ( + "load_audio", + "load_image", + "load_video", + "load_audios", + "load_images", + "load_videos", + ): + if key in allowed: + call_kwargs[key] = True + supports_mask = ( + assistant_pattern is None + and "return_assistant_tokens_mask" in allowed + ) + if supports_mask: + call_kwargs["return_assistant_tokens_mask"] = True + + encoded_any = processor.apply_chat_template( + _conversation_for_processor(normalized_conv), **call_kwargs + ) + encoded = cast("dict[str, Any]", encoded_any) + input_ids = _maybe_strip_batch_dim(encoded["input_ids"]).to(torch.long) + + mask_key = None + if supports_mask: + if "assistant_masks" in encoded: + mask_key = "assistant_masks" + elif "assistant_mask" in encoded: + mask_key = "assistant_mask" + if mask_key is not None: + candidate_loss_mask = _maybe_strip_batch_dim(encoded[mask_key]).to( + torch.long + ) + base_loss_mask = candidate_loss_mask if _mask_has_positive(candidate_loss_mask) else None + else: + base_loss_mask = None + + if base_loss_mask is None: + if assistant_pattern is None: + assistant_pattern = _detect_assistant_pattern(processor) + base_loss_mask = _loss_mask_from_assistant_token_spans( + input_ids, normalized_conv, get_tokenizer(processor) + ) + if base_loss_mask is None: + base_loss_mask = _loss_mask_from_ids_fallback( + input_ids, + normalized_conv, + get_tokenizer(processor), + assistant_pattern, + placeholder_token_ids, + ) + + loss_mask = _build_multimodal_loss_mask( + input_ids, base_loss_mask, placeholder_token_ids + ) + if len(input_ids) > max_length: + log.warning( + f"[DROP sample_idx={sample_idx}] reason=overlength_multimodal " + f"len(input_ids)={len(input_ids)} max_length={max_length}" + ) + continue + if multimodal_output_dir is not None: + mm_file = _save_multimodal_sidecar( + encoded, multimodal_output_dir, sample_idx, sidecar_prefix + ) + else: + input_ids, loss_mask = _get_input_ids_loss_mask( + normalized_conv, + processor, + max_length=max_length, + assistant_pattern=assistant_pattern, + conv_idx=idx, + ) + input_ids = torch.tensor(input_ids, dtype=torch.long) + + assert len(input_ids) == len(loss_mask), ( + f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}" ) + + if minimum_valid_tokens is not None: + num_valid_tokens = int(loss_mask.sum().item()) + if num_valid_tokens < minimum_valid_tokens: + log.warning( + f"[DROP sample_idx={sample_idx}] reason=too_few_valid_tokens " + f"num_valid_tokens={num_valid_tokens} " + f"minimum_valid_tokens={minimum_valid_tokens} " + f"len(input_ids)={len(input_ids)} is_multimodal={is_multimodal}" + ) + continue + + results["input_ids"].append(input_ids) + results["loss_mask"].append(loss_mask.to(torch.long)) + results["seq_len"].append(len(input_ids)) + if include_messages: + results["messages"].append(messages) + results["messages_json"].append(messages_json) + results["mm_file"].append(mm_file) + results["use_audio_in_video"].append(use_audio_in_video) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: log.error( - f"Failed to process conversation {idx} " + f"[DROP sample_idx={sample_idx}] reason=exception " + f"exc_type={type(e).__name__} " f"(assistant_pattern={assistant_pattern is not None}): {e}" ) continue - # Assert shapes match - assert len(input_ids) == len(loss_mask), ( - f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}" - ) - - # Filtering samples out with too few valid tokens - if minimum_valid_tokens is not None: - num_valid_tokens = int(loss_mask.sum().item()) - if num_valid_tokens < minimum_valid_tokens: - continue - - # Append to results - results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) - results["loss_mask"].append(loss_mask) - results["seq_len"].append(len(input_ids)) - - if "messages" in results: - results["messages"].append(_adapt_conv_for_vllm(normalized_conv)) - return results @@ -553,6 +956,9 @@ def build_eagle3_dataset( assistant_pattern: str | Pattern[str] | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + placeholder_token_ids: tuple[int, ...] = (), + multimodal_output_dir: str | Path | None = None, + sidecar_prefix: str = "dataset", ) -> HFDataset: """Build EAGLE3 dataset by tokenizing conversations and creating loss masks. @@ -580,6 +986,11 @@ def build_eagle3_dataset( assistant_pattern = _detect_assistant_pattern(processor) log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...") + if multimodal_output_dir is not None: + (Path(multimodal_output_dir) / MULTIMODAL_SIDECAR_DIR).mkdir( + parents=True, exist_ok=True + ) + original_cols = dataset.column_names # Avoid CPU contention for MM processing: @@ -590,22 +1001,27 @@ def build_eagle3_dataset( else nullcontext() ): dataset = dataset.map( - lambda examples: _preprocess_batch( + lambda examples, indices: _preprocess_batch( examples, processor, max_length, assistant_pattern, turn_dropout, minimum_valid_tokens, + indices=indices, + placeholder_token_ids=placeholder_token_ids, + multimodal_output_dir=multimodal_output_dir, + sidecar_prefix=sidecar_prefix, ), batched=True, + with_indices=True, num_proc=num_proc, - batch_size=1000, + batch_size=400, remove_columns=original_cols, keep_in_memory=True, # skip caching ) - dataset.set_format(type="torch") + dataset.set_format(type="torch", columns=["input_ids", "loss_mask", "seq_len"]) return dataset @@ -667,6 +1083,7 @@ def load_and_preprocess_dataset( turn_dropout: bool = False, minimum_valid_tokens: int | None = None, trust_remote_code: bool = False, + multimodal_output_dir: Path | str | None = None, ) -> tuple[HFDataset, ProcessorLike]: """Load, tokenize, and preprocess a dataset for EAGLE3 training. @@ -710,6 +1127,20 @@ def load_and_preprocess_dataset( "Please use a model with a pre-configured chat template." ) + placeholder_token_ids: tuple[int, ...] = () + verifier_config = AutoConfig.from_pretrained( + target_model_path, trust_remote_code=trust_remote_code + ) + multimodal_config = getattr(verifier_config, "thinker_config", verifier_config) + if isinstance(processor, ProcessorMixin): + placeholder_token_ids = tuple( + int(token_id) + for attr in ("image_token_id", "video_token_id", "audio_token_id") + if (token_id := getattr(multimodal_config, attr, None)) is not None + ) + if multimodal_output_dir is None: + multimodal_output_dir = Path(token_freq_path).parent + processed_datasets = [] for train_data_path in train_data_paths: log.subsection(f"Processing {train_data_path}") @@ -742,6 +1173,9 @@ def load_and_preprocess_dataset( assistant_pattern=assistant_pattern, turn_dropout=turn_dropout, minimum_valid_tokens=minimum_valid_tokens, + placeholder_token_ids=placeholder_token_ids, + multimodal_output_dir=multimodal_output_dir, + sidecar_prefix=train_data_path, ) if minimum_valid_tokens is not None: log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 6549444a9..c97d2b625 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -94,7 +94,15 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): def extract_output( response: Completion | ChatCompletion, token_ids: list[int], -) -> str: + *, + trust_server_token_ids: bool = False, +) -> str | tuple[str, list[int]]: + """Extract hidden-states path (and optionally server token ids). + + By default, token ids must match exactly. Set + ``trust_server_token_ids=True`` to accept server tokenization and return + ``(path, prompt_token_ids)``. + """ if isinstance(response, Completion): prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None) else: @@ -103,25 +111,36 @@ def extract_output( if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") - if prompt_token_ids != token_ids: + if not trust_server_token_ids and prompt_token_ids != token_ids: raise InvalidResponseError( - f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" + f"Prompt token IDs mismatch: expected {len(token_ids)} tokens, " + f"got {len(prompt_token_ids)} tokens" ) kv_transfer_params = getattr(response, "kv_transfer_params", None) if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") - return kv_transfer_params.get("hidden_states_path") + path = kv_transfer_params.get("hidden_states_path") + + if trust_server_token_ids: + return path, prompt_token_ids + return path class ClientItem(TypedDict): input_ids: list[int] """The input token IDs.""" + multi_modal_data: NotRequired[dict[str, list]] + """Per-modality URL/path lists (image/video/audio) forwarded to vLLM via + ``extra_body.multi_modal_data`` so the server can attach media features + without re-rendering the chat template.""" + messages: NotRequired[list[ChatCompletionMessageParam]] - """If provided, pass `messages` to Chat Completions API - instead of passing `token_ids` to Completions API.""" + """Optional fallback. Only consumed when ``use_chat_completions=True`` is + passed to :func:`generate_hidden_states`. The default token-id path + ignores this field.""" async def _poll_lock_async(fd, poll_interval): @@ -172,35 +191,43 @@ async def generate_hidden_states_async( client_item: ClientItem, *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, -) -> str: - """ - Runs decode w/ max_tokens 1 to generate hidden states and returns path to - hidden states file. - - Args: - client: The async OpenAI client. - model: The model ID. - client_item: Inputs to send via the client. - timeout: Timeout in seconds for each request attempt. None for no timeout. + use_chat_completions: bool = False, +) -> str | tuple[str, list[int]]: + """Generate hidden states asynchronously and return the safetensors path. + + For multimodal samples, set ``use_chat_completions=True`` so vLLM runs + its vision encoder. In that mode the return value is a tuple + ``(path, server_token_ids)`` — the caller must use ``server_token_ids`` + as the authoritative input_ids (they are positionally aligned with the + hidden states). + + For text-only samples the default Completions path is used (strict + token-id parity check) and a plain ``str`` path is returned. """ token_ids = client_item["input_ids"] messages = client_item.get("messages") + is_mm_chat = use_chat_completions and messages is not None + coro: Coroutine[Any, Any, Completion | ChatCompletion] - if messages is None: - coro = client.completions.create( + if is_mm_chat: + coro = client.chat.completions.create( model=model, - prompt=token_ids, + messages=messages, max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, timeout=timeout, ) else: - coro = client.chat.completions.create( + extra_body: dict[str, Any] = {"return_token_ids": True} + mm = client_item.get("multi_modal_data") + if mm: + extra_body["multi_modal_data"] = mm + coro = client.completions.create( model=model, - messages=messages, + prompt=token_ids, max_tokens=1, - extra_body={"add_generation_prompt": False, "return_token_ids": True}, + extra_body=extra_body, timeout=timeout, ) @@ -210,7 +237,7 @@ async def generate_hidden_states_async( else: res = await coro - return extract_output(res, token_ids) + return extract_output(res, token_ids, trust_server_token_ids=is_mm_chat) @with_retries @@ -220,30 +247,39 @@ def generate_hidden_states( client_item: ClientItem, *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, -) -> str: - """ - Runs decode w/ max_tokens 1 to generate hidden states and returns path to - hidden states file. + use_chat_completions: bool = False, +) -> str | tuple[str, list[int]]: + """Generate hidden states via vLLM (synchronous version). + + For multimodal samples, set ``use_chat_completions=True`` so vLLM runs + its vision encoder. Returns ``(path, server_token_ids)`` in that mode. + For text-only samples returns a plain ``str`` path. """ token_ids = client_item["input_ids"] messages = client_item.get("messages") + is_mm_chat = use_chat_completions and messages is not None + res: Completion | ChatCompletion - if messages is None: - res = client.completions.create( + if is_mm_chat: + res = client.chat.completions.create( model=model, - prompt=token_ids, + messages=messages, max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, timeout=timeout, ) else: - res = client.chat.completions.create( + extra_body: dict[str, Any] = {"return_token_ids": True} + mm = client_item.get("multi_modal_data") + if mm: + extra_body["multi_modal_data"] = mm + res = client.completions.create( model=model, - messages=messages, + prompt=token_ids, max_tokens=1, - extra_body={"add_generation_prompt": False, "return_token_ids": True}, + extra_body=extra_body, timeout=timeout, ) - return extract_output(res, token_ids) + return extract_output(res, token_ids, trust_server_token_ids=is_mm_chat) diff --git a/src/speculators/models/eagle3/core.py b/src/speculators/models/eagle3/core.py index 00da65fb4..187354dbd 100644 --- a/src/speculators/models/eagle3/core.py +++ b/src/speculators/models/eagle3/core.py @@ -1,4 +1,5 @@ import copy +import warnings from typing import ClassVar import torch @@ -14,6 +15,7 @@ ) from speculators.models.eagle3.metrics import compute_metrics from speculators.models.eagle3.model_definitions import model_classes +from speculators.models.eagle3.rotary_partial import install_partial_neox_rotary from speculators.models.metrics import kl_div_loss, resolve_loss_fn from speculators.models.utils import resolve_target_layer_ids from speculators.proposals.greedy import GreedyTokenProposalConfig @@ -26,6 +28,128 @@ def conditional_torch_compile(func): return func +def _wrap_qwen_omni_rotary_with_hf_layout(rotary_cls: type) -> type: + """Adapt Qwen-Omni rotary to HF MRoPE layout. + + Qwen-Omni expects ``position_ids`` as ``[3, batch, seq_len]`` while our + training path emits ``[batch, 3, seq_len]``. This wrapper only transposes + the HF case to avoid MRoPE channel mis-broadcast and projection dim errors. + """ + + class HFLayoutMRoPE(rotary_cls): # type: ignore[misc,valid-type] + def forward(self, x, position_ids): # type: ignore[override] + # Only transpose clear HF layout ``[B, 3, T]``. + # Other shapes pass through unchanged. + if position_ids.dim() == 3 and position_ids.shape[1] == 3: + position_ids = position_ids.transpose(0, 1).contiguous() + return super().forward(x, position_ids) + + HFLayoutMRoPE.__name__ = f"{rotary_cls.__name__}HFLayout" + HFLayoutMRoPE.__qualname__ = HFLayoutMRoPE.__name__ + return HFLayoutMRoPE + + +def _select_rotary_emb_class( + tl_config: PretrainedConfig, + default_cls: type, +) -> type: + """Select an MRoPE-aware rotary class for multimodal Qwen draft configs.""" + rope_params = getattr(tl_config, "rope_parameters", None) + has_mrope = ( + isinstance(rope_params, dict) and rope_params.get("mrope_section") is not None + ) + if not has_mrope: + return default_cls + + try: + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( # noqa: PLC0415 + Qwen3OmniMoeThinkerTextRotaryEmbedding, + ) + except ImportError: + warnings.warn( + "Draft config carries rope_parameters.mrope_section but the installed " + "transformers does not expose Qwen3OmniMoeThinkerTextRotaryEmbedding. " + "Falling back to the architecture default rotary embedding, which " + "will ignore MRoPE and can cause a train/inference mismatch for " + "multimodal inputs.", + UserWarning, + stacklevel=2, + ) + return default_cls + + partial_rotary_factor = float(rope_params.get("partial_rotary_factor", 1.0)) + if partial_rotary_factor >= 1.0: + return _wrap_qwen_omni_rotary_with_hf_layout( + Qwen3OmniMoeThinkerTextRotaryEmbedding + ) + + # Align HF partial rotation with vLLM partial-neox behavior. + # Keep native ``partial_rotary_factor``/``mrope_section`` unchanged. + install_partial_neox_rotary() + return _wrap_qwen_omni_rotary_with_hf_layout( + _make_partial_mrope_rotary_cls( + Qwen3OmniMoeThinkerTextRotaryEmbedding, partial_rotary_factor + ) + ) + + +def _make_partial_mrope_rotary_cls( + base_cls: type, partial_rotary_factor: float +) -> type: + """Return a partial-MRoPE rotary class for Qwen drafts. + + It emits unpadded ``[*, rotary_dim]`` cos/sin. The patched + ``apply_rotary_pos_emb`` rotates only the leading ``rotary_dim`` channels + and keeps the tail unchanged, matching vLLM partial-neox semantics. + """ + + class PartialMRoPE(base_cls): # type: ignore[misc,valid-type] + _partial_rotary_factor: ClassVar[float] = partial_rotary_factor + + @staticmethod + def compute_default_rope_parameters(config=None, device=None, seq_len=None): + base = config.rope_parameters["rope_theta"] + head_dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + rotary_dim = int(head_dim * PartialMRoPE._partial_rotary_factor) + rotary_dim = (rotary_dim // 2) * 2 + attention_factor = 1.0 + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, rotary_dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / rotary_dim + ) + ) + return inv_freq, attention_factor + + def __init__(self, config, device=None): + super().__init__(config=config, device=device) + head_dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + rotary_dim = int(head_dim * self._partial_rotary_factor) + rotary_dim = (rotary_dim // 2) * 2 + self._head_dim = head_dim + self._rotary_dim = rotary_dim + + def forward(self, x, position_ids): + # Keep cos/sin unpadded; patched apply_rotary_pos_emb handles + # partial-neox rotation on the leading rotary channels. + return super().forward(x, position_ids) + + PartialMRoPE.__name__ = ( + f"{base_cls.__name__}Partial{int(partial_rotary_factor * 100):03d}" + ) + PartialMRoPE.__qualname__ = PartialMRoPE.__name__ + return PartialMRoPE + + @SpeculatorModel.register("eagle3") class Eagle3DraftModel(DraftVocabMixin, SpeculatorModel): config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc] @@ -81,7 +205,10 @@ def __init__(self, config: Eagle3SpeculatorConfig): # Create a modified config for the rotary embedding to use 2x the hidden size modified_tl_config = copy.copy(config.transformer_layer_config) modified_tl_config.hidden_size *= 2 - self.rotary_emb = self._model_definitions.rotary_emb_class(modified_tl_config) + rotary_cls = _select_rotary_emb_class( + modified_tl_config, self._model_definitions.rotary_emb_class + ) + self.rotary_emb = rotary_cls(modified_tl_config) # LAYER NORMS norm_class = self._model_definitions.norm_class @@ -108,7 +235,9 @@ def load_verifier_weights(self): verifier_config = self.config.speculators_config.verifier verifier_model_config = AutoConfig.from_pretrained(verifier_config.name_or_path) # type: ignore[arg-type] - # For multimodal models (Qwen3VL, etc.), extract text_config + # For multimodal models (Qwen3VL/Omni/etc.), extract text_config + if hasattr(verifier_model_config, "thinker_config"): + verifier_model_config = verifier_model_config.thinker_config if hasattr(verifier_model_config, "text_config"): verifier_model_config = verifier_model_config.text_config diff --git a/src/speculators/models/eagle3/data.py b/src/speculators/models/eagle3/data.py index 2ac6af086..b56b6754c 100644 --- a/src/speculators/models/eagle3/data.py +++ b/src/speculators/models/eagle3/data.py @@ -28,7 +28,10 @@ def shift_batch(batch: BatchType): verifier_last_hidden_states = verifier_last_hidden_states[1:] loss_mask = loss_mask[1:] lengths = lengths - 1 - position_ids = position_ids[1:] # Note: position_ids now start at 1 + if position_ids.ndim == 2: + position_ids = position_ids[:, 1:] + else: + position_ids = position_ids[1:] # Note: position_ids now start at 1 return { "input_ids": input_ids, diff --git a/src/speculators/models/eagle3/metrics.py b/src/speculators/models/eagle3/metrics.py index 91ab3bfe0..49699893f 100644 --- a/src/speculators/models/eagle3/metrics.py +++ b/src/speculators/models/eagle3/metrics.py @@ -1,17 +1,49 @@ """Metrics and loss functions for Eagle3 draft model.""" +import os +import warnings from collections.abc import Callable from functools import partial import torch from speculators.models.metrics import ( + ce_loss, compute_accuracy_single_step, exp_loss_decay, kl_div_loss, loss_function, ) +_EAGLE3_LOSS_CE_WEIGHT_RAW = os.getenv("EAGLE3_LOSS_CE_WEIGHT") +if _EAGLE3_LOSS_CE_WEIGHT_RAW is None: + _EAGLE3_LOSS_CE_WEIGHT = 0.0 +else: + try: + _EAGLE3_LOSS_CE_WEIGHT = float(_EAGLE3_LOSS_CE_WEIGHT_RAW) + except ValueError: + warnings.warn( + f"EAGLE3_LOSS_CE_WEIGHT={_EAGLE3_LOSS_CE_WEIGHT_RAW!r} is not a " + "float; falling back to pure-KL legacy loss (weight=0.0).", + stacklevel=1, + ) + _EAGLE3_LOSS_CE_WEIGHT = 0.0 + if not 0.0 <= _EAGLE3_LOSS_CE_WEIGHT <= 1.0: + warnings.warn( + f"EAGLE3_LOSS_CE_WEIGHT={_EAGLE3_LOSS_CE_WEIGHT} clamped to [0, 1].", + stacklevel=1, + ) + _EAGLE3_LOSS_CE_WEIGHT = max(0.0, min(1.0, _EAGLE3_LOSS_CE_WEIGHT)) + + +def eagle3_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Per-token Eagle3 loss with optional CE(verifier argmax) mixture.""" + kl = kl_div_loss(logits, targets) + if _EAGLE3_LOSS_CE_WEIGHT <= 0.0: + return kl + ce = ce_loss(logits, targets) + return (1.0 - _EAGLE3_LOSS_CE_WEIGHT) * kl + _EAGLE3_LOSS_CE_WEIGHT * ce + def align_for_step( logits: torch.Tensor, # shape: [1, total_seq_len, draft_vocab_size] @@ -75,7 +107,7 @@ def compute_metrics( Loss value and metrics dictionary. """ if loss_fn is None: - loss_fn = kl_div_loss + loss_fn = eagle3_loss s_logits, s_targets, s_loss_mask, s_prev_correct = align_for_step( logits, targets, loss_mask, prev_correct, ttt_step ) diff --git a/src/speculators/models/eagle3/rotary_partial.py b/src/speculators/models/eagle3/rotary_partial.py new file mode 100644 index 000000000..bc085f512 --- /dev/null +++ b/src/speculators/models/eagle3/rotary_partial.py @@ -0,0 +1,109 @@ +"""Patch HF rotary helper to match vLLM partial-neox behavior. + +HF ``apply_rotary_pos_emb`` rotates by splitting at ``head_dim/2``. +vLLM partial MRoPE rotates only the first ``rotary_dim`` channels and +keeps the tail unchanged. This file aligns HF training with that runtime +behavior, while keeping full-rotation paths unchanged. +""" + +from __future__ import annotations + +import torch + +__all__ = [ + "install_partial_neox_rotary", + "partial_neox_apply_rotary_pos_emb", +] + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """HF/neox "rotate_half" — splits the last dim in half and swaps.""" + half = x.shape[-1] // 2 + x1 = x[..., :half] + x2 = x[..., half:] + return torch.cat((-x2, x1), dim=-1) + + +def partial_neox_apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """HF-compatible rotary helper with partial-neox fallback. + + - If ``cos`` covers full head dim, behavior matches HF. + - If ``cos`` is shorter, rotate only the first ``rotary_dim`` channels + and keep the remaining channels unchanged. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + if cos.shape[-1] == q.shape[-1]: + # Full rotation — identical to HF apply_rotary_pos_emb. + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + if cos.shape[-1] > q.shape[-1]: + raise ValueError( + f"cos last dim ({cos.shape[-1]}) exceeds q last dim " + f"({q.shape[-1]}); rotary tables larger than head_dim are " + "unsupported by this partial-neox replacement." + ) + + rotary_dim = cos.shape[-1] + # Rotate leading rotary channels; keep tail as pass-through. + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_rot = (q_rot * cos) + (_rotate_half(q_rot) * sin) + k_rot = (k_rot * cos) + (_rotate_half(k_rot) * sin) + + q_embed = torch.cat([q_rot, q_pass], dim=-1) + k_embed = torch.cat([k_rot, k_pass], dim=-1) + return q_embed, k_embed + + +_INSTALLED = False + + +def install_partial_neox_rotary() -> None: + """Patch ``apply_rotary_pos_emb`` in HF ``llama`` and ``qwen3`` modules. + + Idempotent. Full-rotation paths keep original behavior. + """ + global _INSTALLED + if _INSTALLED: + return + + # Local imports — keep transformers a soft dep at module import time. + from transformers.models.llama import modeling_llama # noqa: PLC0415 + from transformers.models.qwen3 import modeling_qwen3 # noqa: PLC0415 + + for module in (modeling_llama, modeling_qwen3): + original = module.apply_rotary_pos_emb + # Cache original for tests / debugging — and to allow uninstall. + if not hasattr(module, "_speculators_original_apply_rotary_pos_emb"): + module._speculators_original_apply_rotary_pos_emb = original # type: ignore[attr-defined] + module.apply_rotary_pos_emb = partial_neox_apply_rotary_pos_emb + + _INSTALLED = True + + +def uninstall_partial_neox_rotary() -> None: + """Restore HF's original ``apply_rotary_pos_emb``. Test/debug helper.""" + global _INSTALLED + if not _INSTALLED: + return + from transformers.models.llama import modeling_llama # noqa: PLC0415 + from transformers.models.qwen3 import modeling_qwen3 # noqa: PLC0415 + + for module in (modeling_llama, modeling_qwen3): + original = getattr( + module, "_speculators_original_apply_rotary_pos_emb", None + ) + if original is not None: + module.apply_rotary_pos_emb = original + _INSTALLED = False diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 63519024d..61cf1dded 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -7,6 +7,7 @@ from collections.abc import Callable from os import PathLike from pathlib import Path +from types import MethodType, SimpleNamespace from typing import Any, Literal, cast import openai @@ -15,6 +16,7 @@ from datasets import load_from_disk from safetensors.torch import load_file from torch.utils.data import Dataset +from transformers import AutoConfig from speculators.data_generation.vllm_client import ( DEFAULT_MAX_RETRIES, @@ -26,6 +28,23 @@ from speculators.train.noise_transforms import TransformTensors BatchType = dict[str, Any] +MULTIMODAL_SIDECAR_KEYS = ( + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grids", + "input_features", + "feature_attention_mask", + "audio_feature_lengths", +) +NON_TRAINING_KEYS = { + *MULTIMODAL_SIDECAR_KEYS, + "messages", + "messages_json", + "mm_file", + "use_audio_in_video", +} def list_files(path): @@ -47,6 +66,16 @@ def slice_and_pad_to_length(tensor, length): return F.pad(sliced_tensor, padding) +def pad_last_dim_to_length(tensor: torch.Tensor, length: int) -> torch.Tensor: + sliced_tensor = tensor[..., :length] + pad_amount = length - sliced_tensor.shape[-1] + if pad_amount <= 0: + return sliced_tensor + padding = [0, pad_amount] + padding.extend([0, 0] * (sliced_tensor.dim() - 1)) + return F.pad(sliced_tensor, padding) + + def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): """Given a datapath, split the files into a training and validation set ratio is the proportion of files to put in the training set @@ -111,16 +140,198 @@ def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]: } +def _collect_mm_payload_from_messages( + messages: list[dict], +) -> dict[str, list[str]]: + """Extract image/video/audio URLs from vLLM-style messages. + + Used to populate ``extra_body.multi_modal_data`` for token-id completions. + """ + bucket: dict[str, list[str]] = {} + for turn in messages: + content = turn.get("content") + if not isinstance(content, list): + continue + for seg in content: + if not isinstance(seg, dict): + continue + seg_type = seg.get("type", "") + for modality in ("image", "video", "audio"): + if seg_type == f"{modality}_url": + url = (seg.get(f"{modality}_url") or {}).get("url") + if url: + bucket.setdefault(modality, []).append(url) + return bucket + + def build_client_item(dataset_item: dict) -> ClientItem: - out_dict = {} - out_dict["input_ids"] = dataset_item["input_ids"].tolist() + """Build a vLLM client payload from one dataset row. - if "messages" in dataset_item: - out_dict["messages"] = dataset_item["messages"] + Default path sends ``input_ids`` (token-id completions). For multimodal + rows, URLs from ``messages_json`` are forwarded as ``multi_modal_data``. + ``messages`` is kept only for optional chat-completions fallback. + """ + out_dict: dict[str, Any] = {} + input_ids = _maybe_tensor(dataset_item["input_ids"], torch.long) + out_dict["input_ids"] = input_ids.tolist() if input_ids is not None else [] + + messages_json = dataset_item.get("messages_json", "") + messages: list[dict] | None = None + if messages_json: + messages = json.loads(messages_json) + elif "messages" in dataset_item: + messages = dataset_item["messages"] + + if messages: + # Keep messages around as a fallback path + out_dict["messages"] = messages + # Primary path: extract MM URLs so we can use Completions API + mm = _collect_mm_payload_from_messages(messages) + if mm: + out_dict["multi_modal_data"] = mm return cast("ClientItem", out_dict) +def _maybe_tensor(value: Any, dtype: torch.dtype | None = None) -> torch.Tensor | None: + if value is None: + return None + tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value) + if dtype is not None: + tensor = tensor.to(dtype=dtype) + return tensor + + +def _batchify_mm_tensor(value: torch.Tensor | None) -> torch.Tensor | None: + if value is None: + return None + if value.ndim == 0: + return value.view(1) + if value.ndim == 1: + return value.unsqueeze(0) + return value + + +def _has_multimodal_payload(data: dict[str, Any]) -> bool: + return any(data.get(key) is not None for key in MULTIMODAL_SIDECAR_KEYS) + + +def _make_rope_index_fn(verifier_name_or_path: str): + """Build a callable producing 3D MRoPE position ids for Qwen multimodal models.""" + try: + verifier_root_config = AutoConfig.from_pretrained( + verifier_name_or_path, trust_remote_code=True + ) + except Exception: # noqa: BLE001 + return None + + thinker_config = getattr(verifier_root_config, "thinker_config", None) + if thinker_config is not None: + return _make_rope_index_fn_qwen3_omni(thinker_config) + + top_vision_config = getattr(verifier_root_config, "vision_config", None) + top_text_config = getattr(verifier_root_config, "text_config", None) + top_image_token_id = getattr(verifier_root_config, "image_token_id", None) + if ( + top_vision_config is not None + and top_text_config is not None + and top_image_token_id is not None + ): + return _make_rope_index_fn_qwen3_5_moe(verifier_root_config) + + return None + + +def _make_rope_index_fn_qwen3_omni(thinker_config): + try: + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( # noqa: PLC0415 + Qwen3OmniMoeThinkerForConditionalGeneration, + ) + except ImportError: + return None + + vision_config = getattr(thinker_config, "vision_config", None) + spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) + dummy = SimpleNamespace( + config=SimpleNamespace( + image_token_id=getattr(thinker_config, "image_token_id", None), + video_token_id=getattr(thinker_config, "video_token_id", None), + audio_token_id=getattr(thinker_config, "audio_token_id", None), + vision_start_token_id=getattr( + thinker_config, "vision_start_token_id", None + ), + audio_start_token_id=getattr(thinker_config, "audio_start_token_id", None), + position_id_per_seconds=getattr( + thinker_config, "position_id_per_seconds", 25 + ), + ), + spatial_merge_size=spatial_merge_size, + ) + dummy.get_llm_pos_ids_for_vision = MethodType( + Qwen3OmniMoeThinkerForConditionalGeneration.get_llm_pos_ids_for_vision, + dummy, + ) + return MethodType( + Qwen3OmniMoeThinkerForConditionalGeneration.get_rope_index, + dummy, + ) + + +def _make_rope_index_fn_qwen3_5_moe(root_config): + try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( # noqa: PLC0415 + Qwen3_5MoeModel, + ) + except ImportError: + return None + + image_token_id = getattr(root_config, "image_token_id", None) + video_token_id = getattr(root_config, "video_token_id", None) + vision_config = getattr(root_config, "vision_config", None) + dummy = SimpleNamespace( + config=SimpleNamespace( + vision_config=vision_config, + image_token_id=image_token_id, + video_token_id=video_token_id, + ), + ) + dummy.get_vision_position_ids = MethodType( + Qwen3_5MoeModel.get_vision_position_ids, dummy + ) + raw_get_rope_index = MethodType(Qwen3_5MoeModel.get_rope_index, dummy) + + def adapter( + input_ids, + image_grid_thw=None, + video_grid_thw=None, + attention_mask=None, + **_ignored, + ): + ids = input_ids + image_mask = ( + ids == image_token_id + if image_token_id is not None + else torch.zeros_like(ids, dtype=torch.bool) + ) + video_mask = ( + ids == video_token_id + if video_token_id is not None + else torch.zeros_like(ids, dtype=torch.bool) + ) + mm_token_type_ids = torch.zeros_like(ids, dtype=torch.int32) + mm_token_type_ids[image_mask] = 1 + mm_token_type_ids[video_mask] = 2 + return raw_get_rope_index( + input_ids=ids, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + + return adapter + + class BaseDataset(Dataset): def __init__( self, @@ -139,6 +350,32 @@ def _compute_approx_lengths(self): def _get_raw_data(self, index): raise NotImplementedError + def _build_position_ids(self, data: dict[str, Any], seq_len: int) -> torch.Tensor: + rope_index_fn = getattr(self, "_rope_index_fn", None) + if rope_index_fn is not None and _has_multimodal_payload(data): + audio_seqlens = _maybe_tensor(data.get("audio_feature_lengths"), torch.long) + if audio_seqlens is None and data.get("feature_attention_mask") is not None: + audio_seqlens = _maybe_tensor( + data["feature_attention_mask"], torch.long + ).sum(dim=-1) + + position_ids, _ = rope_index_fn( + input_ids=data["input_ids"].unsqueeze(0), + image_grid_thw=_batchify_mm_tensor( + _maybe_tensor(data.get("image_grid_thw"), torch.long) + ), + video_grid_thw=_batchify_mm_tensor( + _maybe_tensor(data.get("video_grid_thw"), torch.long) + ), + use_audio_in_video=bool(data.get("use_audio_in_video", False)), + audio_seqlens=audio_seqlens, + second_per_grids=_maybe_tensor(data.get("second_per_grids")), + attention_mask=torch.ones(1, seq_len, dtype=torch.long), + ) + return position_ids[:, 0].to(dtype=torch.long) + + return torch.arange(seq_len, dtype=torch.long) + def __getitem__(self, index) -> BatchType | None: data = self._get_raw_data(index) @@ -163,8 +400,12 @@ def __getitem__(self, index) -> BatchType | None: data["lengths"] = torch.tensor([seq_len], dtype=torch.long) # shape: [1] - data["position_ids"] = torch.arange(seq_len, dtype=torch.long) - # shape: [seq_len] + data["position_ids"] = self._build_position_ids(data, seq_len) + # shape: [seq_len] or [3, seq_len] for MRoPE multimodal samples + + # Keep multimodal payload only for rope index / HS generation + for key in NON_TRAINING_KEYS: + data.pop(key, None) # data structure: { # "hidden_states": [seq_len, 3 * hidden_size], @@ -208,6 +449,7 @@ def __init__( model: str | None = None, request_timeout: float | None = DEFAULT_REQUEST_TIMEOUT, max_retries: int = DEFAULT_MAX_RETRIES, + verifier_name_or_path: str | None = None, ): """Initialize the ArrowDataset. Args: @@ -217,12 +459,14 @@ def __init__( transform: The transform to apply to the data. hidden_states_dtype: The dtype of the hidden states. """ - self.data = load_from_disk(datapath) + self.datapath = Path(datapath) + # Reset format after load so metadata columns remain visible + # (e.g. ``mm_file`` / ``messages_json``). + self.data = load_from_disk(datapath).with_format(None) self.start_file_idx = 0 if split_ratio == 1.0: pass elif 1.0 > split_ratio > 0: - self.start_file_idx = 0 split_idx = int(len(self.data) * split_ratio) self.data = self.data.select(range(split_idx)) elif -1.0 < split_ratio < 0: @@ -233,7 +477,7 @@ def __init__( raise ValueError("split_ratio must be in range (-1.0, 1.0] excluding 0.0.") self.hidden_states_path: Path = ( - Path(datapath) / "hidden_states" + self.datapath / "hidden_states" if hidden_states_path is None else Path(hidden_states_path) ) @@ -242,8 +486,14 @@ def __init__( self.on_generate = on_generate self.client: openai.OpenAI | None = None self.model = model + self.verifier_name_or_path = verifier_name_or_path or model self.request_timeout = request_timeout self.max_retries = max_retries + self._rope_index_fn = ( + _make_rope_index_fn(self.verifier_name_or_path) + if self.verifier_name_or_path is not None + else None + ) # Delay super init so that `_compute_approx_lengths` has required data super().__init__(max_len, transform, hidden_states_dtype) @@ -273,28 +523,57 @@ def _compute_approx_lengths(self) -> list[int]: """Get lengths of the dataset samples.""" return list(self.data.with_format(None)["seq_len"]) - def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: + def _maybe_generate_hs( + self, index: int, *, is_multimodal: bool = False + ) -> dict[str, torch.Tensor] | None: + """Generate hidden states from vLLM on demand. + + For multimodal rows, Chat Completions is used and server token_ids are + treated as authoritative for HS alignment. + """ if not self.client: self._setup_client() dataset_item = self.data[index] client_item = build_client_item(dataset_item) + # Use Chat Completions for multimodal rows to trigger vision encoding. + # If tokenization drifts, trust vLLM token_ids (HS-aligned). + use_chat = is_multimodal and client_item.get("messages") is not None + try: - hs_filepath = generate_hidden_states( + result = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] client_item, timeout=self.request_timeout, max_retries=self.max_retries, + use_chat_completions=use_chat, ) + if use_chat: + hs_filepath, server_token_ids = result # type:ignore[misc] + else: + hs_filepath = result # type:ignore[assignment] + server_token_ids = None + loaded_hs = _maybe_load_hs_file(Path(hs_filepath)) + # Overwrite token_ids with the server's authoritative version + # so downstream _get_raw_data uses the correct positionally- + # aligned sequence for loss/hidden-state pairing. + if loaded_hs is not None and server_token_ids is not None: + loaded_hs["token_ids"] = torch.tensor( + server_token_ids, dtype=torch.long + ) + match self.on_generate: case "cache": file_idx = self._map_to_file_idx(index) target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" + # Ensure destination directory exists before moving files. + # Missing parent dir can cause all generated samples to fail. + target_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(hs_filepath, target_path) case "delete": Path(hs_filepath).unlink() @@ -308,14 +587,26 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: return loaded_hs def _get_raw_data(self, index): + # Read one row once, then normalize tensors and metadata together. + row = self.data[index] + stored_input_ids = _maybe_tensor(row["input_ids"], torch.long) + stored_loss_mask = _maybe_tensor(row["loss_mask"], torch.long) + # Fast-fail: skip corrupted/incomplete rows before expensive HS I/O. + if stored_input_ids is None or stored_loss_mask is None: + return None + + is_multimodal = bool(row.get("messages_json", "")) + file_idx = self._map_to_file_idx(index) - candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" - loaded_hs = _maybe_load_hs_file(candidate_path) + hs_file = self.hidden_states_path / f"hs_{file_idx}.safetensors" + loaded_hs = _maybe_load_hs_file(hs_file) if loaded_hs is None: match self.on_missing: case "generate": - loaded_hs = self._maybe_generate_hs(index) + loaded_hs = self._maybe_generate_hs( + index, is_multimodal=is_multimodal + ) case "skip": return None case "warn": @@ -337,25 +628,61 @@ def _get_raw_data(self, index): # "token_ids": [seq_len] # } - if not torch.equal(loaded_hs["token_ids"], self.data[index]["input_ids"]): + # For multimodal samples generated via Chat Completions, vLLM's + # server-side tokenization is authoritative (it ran the vision + # encoder on that exact token sequence). The stored input_ids from + # prepare_data.py may differ by ±1 token at vision-placeholder + # boundaries, so we trust loaded_hs["token_ids"] unconditionally. + # For text-only samples (or pre-generated offline HS), we still + # enforce strict equality to catch data corruption early. + authoritative_input_ids = loaded_hs["token_ids"] + if not is_multimodal and not torch.equal(authoritative_input_ids, stored_input_ids): warnings.warn( - f"Loaded token ids {loaded_hs['token_ids']} for index {index} don't" - f"match input ids {self.data[index]['input_ids']}", + f"Token IDs mismatch for text sample {index}: " + f"hs has {len(authoritative_input_ids)} tokens, " + f"dataset has {len(stored_input_ids)} tokens. Skipping.", stacklevel=1, ) return None - return { + # Use authoritative_input_ids (from HS file) as ground truth. + # For multimodal, this is vLLM's tokenization (aligned with HS); + # for text-only, it equals stored_input_ids (verified above). + # loss_mask length must match; truncate/pad if server tokenized + # slightly differently (common for multimodal ±1 token drift). + seq_len_hs = len(authoritative_input_ids) + if len(stored_loss_mask) != seq_len_hs: + if len(stored_loss_mask) > seq_len_hs: + stored_loss_mask = stored_loss_mask[:seq_len_hs] + else: + stored_loss_mask = torch.cat([ + stored_loss_mask, + torch.zeros(seq_len_hs - len(stored_loss_mask), dtype=torch.long), + ]) + + data = { "hidden_states": loaded_hs["hidden_states"][:, :-1].flatten( 1 ), # [seq_len, 3 * hidden_size] - "input_ids": loaded_hs["token_ids"], # [seq_len] + "input_ids": authoritative_input_ids, # [seq_len] "verifier_last_hidden_states": loaded_hs["hidden_states"][ :, -1 ], # [seq_len, hidden_size] - "loss_mask": self.data[index]["loss_mask"], # [seq_len] + "loss_mask": stored_loss_mask, # [seq_len] + "messages_json": row.get("messages_json", ""), + "mm_file": row.get("mm_file", ""), + "use_audio_in_video": bool(row.get("use_audio_in_video", 0)), } + mm_path = data["mm_file"] + if mm_path: + mm = load_file(self.datapath / mm_path) + for key in MULTIMODAL_SIDECAR_KEYS: + if key in mm: + data[key] = mm[key] + + return data + class SampleFileDataset(BaseDataset): def __init__( @@ -473,7 +800,13 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: batch = [create_empty_sample(hidden_size, dtype=dtype)] collated_data = {} + has_mrope = any( + b["position_ids"].ndim == 2 for b in batch # type: ignore[index] + ) for key in batch[0]: # type: ignore[union-attr] + if key == "position_ids": + continue + # Concatenate the tensors along the seq (0th) dimension collated_data[key] = torch.cat([b[key] for b in batch], dim=0) # type: ignore[index] # shape: [total_seq_len, ...] @@ -485,6 +818,40 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: ).unsqueeze(0) # shape: [1, max_len, ...] + if has_mrope: + # MRoPE samples carry position_ids of shape [3, seq_len] (T/H/W + # channels). After shift_batch they remain [3, seq_len-1]. + # We concatenate along the seq dimension, pad to max_len, then + # add a batch dim. Qwen3-Omni Thinker rotary (used by EAGLE3 when + # the verifier carries rope_parameters.mrope_section) expects + # ``position_ids`` shaped ``[3, batch, seq_len]`` — channels + # first — NOT the HF Llama4 / Gemma3 ``[batch, 3, seq_len]`` + # convention. Picking the wrong layout silently broadcasts the + # 3 channels through the attention reshape and blows up + # ``o_proj`` with a feature dim of ``3 * num_heads * head_dim`` + # (e.g. 12288 instead of 4096 for Qwen3.6 draft heads). + position_ids = [] + for sample in batch: # type: ignore[assignment] + pos = sample["position_ids"] + if pos.ndim == 1: + # Text-only sample in a mixed batch: broadcast to 3 channels + pos = pos.unsqueeze(0).expand(3, -1) + position_ids.append(pos) + # shape of each: [3, sample_seq_len]; cat along seq dim → [3, total_seq_len] + collated_positions = torch.cat(position_ids, dim=-1) + # Pad seq dim to max_len → [3, max_len], then add batch dim at + # axis 1 → [3, 1, max_len] for Qwen-Omni rotary. + collated_data["position_ids"] = pad_last_dim_to_length( + collated_positions, max_len + ).unsqueeze(1) + else: + # Standard 1D position_ids: [seq_len] per sample. + # Cat along seq dim → [total_seq_len], pad → [max_len], batch → [1, max_len] + collated_positions = torch.cat([b["position_ids"] for b in batch], dim=0) # type: ignore[index] + collated_data["position_ids"] = slice_and_pad_to_length( + collated_positions, max_len + ).unsqueeze(0) + # Include lengths until while they fit in max_len # The last included length is (if necessary) truncated # Any additional lengths are discarded diff --git a/src/speculators/train/vocab_mapping.py b/src/speculators/train/vocab_mapping.py index 90351c53a..1d66f35be 100644 --- a/src/speculators/train/vocab_mapping.py +++ b/src/speculators/train/vocab_mapping.py @@ -111,7 +111,9 @@ def get_target_vocab_size(target_vocab_size, target_model_path): config = AutoConfig.from_pretrained(target_model_path) - # For multimodal models (Qwen3VL, etc.), extract text_config + # Multimodal verifiers may nest the text backbone. + if hasattr(config, "thinker_config"): + config = config.thinker_config if hasattr(config, "text_config"): config = config.text_config diff --git a/tests/unit/models/test_eagle3_rotary_partial.py b/tests/unit/models/test_eagle3_rotary_partial.py new file mode 100644 index 000000000..d51c0e21b --- /dev/null +++ b/tests/unit/models/test_eagle3_rotary_partial.py @@ -0,0 +1,197 @@ +"""Bit-equivalence test for the partial-neox rotary monkey-patch. + +This test pins the contract that motivates the patch: when +``cos.shape[-1] == rotary_dim < head_dim``, the patched +``apply_rotary_pos_emb`` rotates the **same channel pairs** that vLLM's +``MRotaryEmbedding`` rotates at inference time — i.e. + + rotated channels: [0, rotary_dim/2) paired with [rotary_dim/2, rotary_dim) + pass-through: [rotary_dim, head_dim) + +We verify two properties on a hand-rolled vLLM-equivalent reference: + +1. **Full-rotation parity** (``rotary_dim == head_dim``): the patched + helper is byte-equivalent to HF's original ``apply_rotary_pos_emb``, + so DFlash / plain Llama drafters are unaffected. + +2. **Partial-rotation parity** (``rotary_dim < head_dim``, e.g. + Qwen3.6's ``head_dim=256``, ``partial_rotary_factor=0.25``): the + patched helper matches a hand-coded vLLM neox-partial reference to + well below fp32 round-off (``< 1e-5`` max-abs-diff), and the + pass-through tail is preserved exactly (``allclose`` with + ``atol=0``). +""" + +from __future__ import annotations + +import math + +import torch + +try: + import pytest +except ImportError: # pragma: no cover - environments without pytest + pytest = None # type: ignore[assignment] + + +def _vllm_neox_partial_reference( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Hand-coded vLLM neox-partial RoPE — matches MRotaryEmbedding._forward. + + Independent of the implementation under test; written from the spec + (see ``vllm/model_executor/layers/rotary_embedding/__init__.py``). + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + rotary_dim = cos.shape[-1] + + def _rot(x: torch.Tensor) -> torch.Tensor: + x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] + half = rotary_dim // 2 + x1, x2 = x_rot[..., :half], x_rot[..., half:] + # neox rotate_half on the rotated slice + rotated = torch.cat((-x2, x1), dim=-1) + out_rot = (x_rot * cos) + (rotated * sin) + return torch.cat([out_rot, x_pass], dim=-1) + + return _rot(q), _rot(k) + + +def test_full_rotation_is_byte_equivalent_to_hf(): + """When cos covers the full head_dim, the patched fn must match HF exactly.""" + from speculators.models.eagle3.rotary_partial import ( + partial_neox_apply_rotary_pos_emb, + ) + from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb as hf_apply, + ) + + torch.manual_seed(0) + batch, heads, seq, head_dim = 2, 4, 7, 64 + q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + # cos/sin over full head_dim + cos = torch.cos(torch.randn(batch, seq, head_dim)) + sin = torch.sin(torch.randn(batch, seq, head_dim)) + + q_hf, k_hf = hf_apply(q, k, cos, sin, unsqueeze_dim=1) + q_p, k_p = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + + # Strict equality — same arithmetic, same order of operations. + assert torch.equal(q_hf, q_p), "patched fn drifted from HF on full rotation" + assert torch.equal(k_hf, k_p), "patched fn drifted from HF on full rotation" + + +def test_partial_rotation_matches_vllm_reference(): + """rotary_dim < head_dim must match vLLM's neox-partial channel layout.""" + from speculators.models.eagle3.rotary_partial import ( + partial_neox_apply_rotary_pos_emb, + ) + + torch.manual_seed(1) + # Realistic Qwen3.6 shape: head_dim=256, partial_rotary_factor=0.25 + batch, heads, seq, head_dim = 2, 4, 5, 256 + rotary_dim = 64 # int(256 * 0.25) + + q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + cos = torch.cos(torch.randn(batch, seq, rotary_dim)) + sin = torch.sin(torch.randn(batch, seq, rotary_dim)) + + q_ref, k_ref = _vllm_neox_partial_reference(q, k, cos, sin, unsqueeze_dim=1) + q_got, k_got = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + + # Same arithmetic up to algebraic associativity; allow tiny fp32 jitter. + assert torch.allclose(q_got, q_ref, atol=1e-6, rtol=0) + assert torch.allclose(k_got, k_ref, atol=1e-6, rtol=0) + + +def test_partial_rotation_preserves_passthrough_channels(): + """Channels >= rotary_dim must be untouched (atol=0).""" + from speculators.models.eagle3.rotary_partial import ( + partial_neox_apply_rotary_pos_emb, + ) + + torch.manual_seed(2) + batch, heads, seq, head_dim = 1, 2, 3, 64 + rotary_dim = 16 + + q = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + k = torch.randn(batch, heads, seq, head_dim, dtype=torch.float32) + cos = torch.cos(torch.randn(batch, seq, rotary_dim)) + sin = torch.sin(torch.randn(batch, seq, rotary_dim)) + + q_out, k_out = partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + + # Pass-through tail must be bit-exact. + assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:]) + assert torch.equal(k_out[..., rotary_dim:], k[..., rotary_dim:]) + # And the rotated head must NOT be equal to the input (sanity). + assert not torch.equal(q_out[..., :rotary_dim], q[..., :rotary_dim]) + + +def test_install_is_idempotent_and_byte_safe_on_full_rotation(): + """Installing the patch must not perturb full-rotation HF callers.""" + from speculators.models.eagle3.rotary_partial import ( + install_partial_neox_rotary, + uninstall_partial_neox_rotary, + ) + from transformers.models.llama import modeling_llama + from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb as hf_apply_pre, + ) + + # Snapshot HF behaviour BEFORE installing. + torch.manual_seed(3) + head_dim = 32 + q = torch.randn(1, 2, 4, head_dim) + k = torch.randn(1, 2, 4, head_dim) + cos = torch.cos(torch.randn(1, 4, head_dim)) + sin = torch.sin(torch.randn(1, 4, head_dim)) + q_pre, k_pre = hf_apply_pre(q, k, cos, sin, unsqueeze_dim=1) + + install_partial_neox_rotary() + install_partial_neox_rotary() # idempotent + + # After install, the module-level symbol must be the patched one. + q_post, k_post = modeling_llama.apply_rotary_pos_emb( + q, k, cos, sin, unsqueeze_dim=1 + ) + assert torch.equal(q_pre, q_post) + assert torch.equal(k_pre, k_post) + + uninstall_partial_neox_rotary() + uninstall_partial_neox_rotary() # idempotent + + # After uninstall, the symbol must be HF's original (object identity). + assert ( + modeling_llama.apply_rotary_pos_emb is hf_apply_pre + or modeling_llama.apply_rotary_pos_emb.__wrapped__ is hf_apply_pre + ) + + +def test_rejects_cos_longer_than_head_dim(): + """Defensive — cos can't be larger than q's last dim.""" + from speculators.models.eagle3.rotary_partial import ( + partial_neox_apply_rotary_pos_emb, + ) + + q = torch.randn(1, 1, 1, 8) + k = torch.randn(1, 1, 1, 8) + cos = torch.randn(1, 1, 16) # > head_dim=8 + sin = torch.randn(1, 1, 16) + if pytest is not None: + with pytest.raises(ValueError, match="exceeds q last dim"): + partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + return + try: + partial_neox_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + except ValueError as e: + assert "exceeds q last dim" in str(e) + else: # pragma: no cover + raise AssertionError("expected ValueError on oversize cos")