diff --git a/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh new file mode 100644 index 000000000..80fbf3b45 --- /dev/null +++ b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh @@ -0,0 +1,226 @@ +#!/bin/bash +# Online Eagle3 Training Script for Qwen3-VL-4B on hao05/llava-cot-5k-reannotated +# +# Runs the full online training pipeline: +# 1. Download the multimodal Parquet dataset snapshot from Hugging Face +# 2. Materialize local image files and an absolute-path JSONL +# 3. Prepare arrow data with multimodal preprocessing +# 4. Launch a hidden-state extraction vLLM server +# 5. Train Eagle3 with on-the-fly hidden-state generation +# +# Usage: +# bash examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh +# +# Notes: +# - `prepare_data.py` currently accepts local json/jsonl files or built-in dataset +# aliases. This example snapshots the public HF Parquet dataset locally first. +# - The uploaded dataset stores image bytes in Parquet and preserves original +# relative paths in `image_path`. This script materializes image files and a +# JSONL with absolute image paths so vLLM can load images reliably during +# online training. +# - For more data and a longer training run that can improve accuracy, replace +# `hao05/llava-cot-5k-reannotated` with `hao05/llava-cot-48k-reannotated` +# and adjust `MAX_SAMPLES` / `EPOCHS` as needed. +# +# ### Example E2E run for Qwen3-VL-4B on 5k samples from LLaVA-CoT ### +# +# Note: This 5k setup is primarily a pipeline sanity check. It is enough to +# verify that multimodal online training, hidden-state generation, and +# checkpointing all work end-to-end, but it is not intended to represent final +# model quality. +# +# Timing from an observed run on 4x NVIDIA GeForce RTX 5090 32GB GPUs +# (vLLM on GPUs 0,1 and training on GPUs 2,3): +# Data preprocessing: 460 seconds (7 mins 40 secs) +# vLLM server startup: 45 seconds +# Training (5 epochs): 1110 seconds (18 mins 30 secs) +# Total (prepare_data start to checkpoint save): 1615 seconds (26 mins 55 secs) +# +# Final validation metrics from that run: +# val/loss_epoch: 8.676 +# val/full_acc_0_epoch: 57.7% +# val/full_acc_1_epoch: 31.9% +# val/full_acc_2_epoch: 17.9% + +set -euo pipefail + +# ============ Configuration ============ +MODEL="${MODEL:-Qwen/Qwen3-VL-4B-Instruct}" +DATASET_REPO="${DATASET_REPO:-hao05/llava-cot-5k-reannotated}" +DATASET_DIR="${DATASET_DIR:-./data/llava-cot-5k-reannotated}" +DATASET_JSONL="$DATASET_DIR/train.absolute_paths.jsonl" +OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_vl_4b_llava_cot_online}" +HIDDEN_STATES_DIR="$OUTPUT_DIR/hidden_states_online" +CHECKPOINT_DIR="$OUTPUT_DIR/checkpoints" +VLLM_PORT="${VLLM_PORT:-8000}" +MAX_SAMPLES="${MAX_SAMPLES:-5000}" +SEQ_LENGTH="${SEQ_LENGTH:-4096}" +VLLM_MAX_MODEL_LEN="${VLLM_MAX_MODEL_LEN:-5120}" +VLLM_TP="${VLLM_TP:-2}" +EPOCHS="${EPOCHS:-5}" +LR="${LR:-1e-4}" +VLLM_EXTRA_ARGS="${VLLM_EXTRA_ARGS:-}" +VLLM_LOG_FILE="${VLLM_LOG_FILE:-./vllm_server.log}" + +# GPU assignments +VLLM_GPUS="${VLLM_GPUS:-0,1}" +TRAIN_GPUS="${TRAIN_GPUS:-2,3}" +NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-2}" +# ======================================= + +# Optional mirror for environments without direct access to huggingface.co +# export HF_ENDPOINT=https://hf-mirror.com + +mkdir -p "$DATASET_DIR" "$OUTPUT_DIR" +read -r -a VLLM_EXTRA_ARR <<< "$VLLM_EXTRA_ARGS" + +echo "=== Step 1: Downloading dataset snapshot ===" +hf download "$DATASET_REPO" \ + --repo-type dataset \ + --local-dir "$DATASET_DIR" \ + --include "README.md" \ + --include "data/*.parquet" + +echo "=== Step 2: Materializing Parquet dataset to absolute-path JSONL ===" +python - "$DATASET_DIR" "$MAX_SAMPLES" <<'PY' +import json +import sys +from pathlib import Path + +from datasets import Image, load_dataset + +dataset_dir = Path(sys.argv[1]).resolve() +max_samples_arg = sys.argv[2] +max_samples = None +if max_samples_arg and max_samples_arg.lower() not in {"0", "all", "none"}: + max_samples = int(max_samples_arg) +dst = dataset_dir / "train.absolute_paths.jsonl" +parquet_files = sorted((dataset_dir / "data").glob("train-*.parquet")) +if not parquet_files: + raise FileNotFoundError(f"No Parquet shards found under {dataset_dir / 'data'}") + + +def absolutize_image_ref(image_ref: object) -> object: + if not isinstance(image_ref, str): + return image_ref + if image_ref.startswith(("http://", "https://", "/")): + return image_ref + return str((dataset_dir / image_ref).resolve()) + + +def safe_relative_path(image_path: object, row_idx: int) -> Path: + path_text = str(image_path) if isinstance(image_path, str) else f"images/{row_idx:08d}.jpg" + path = Path(path_text) + if path.is_absolute() or ".." in path.parts: + path = Path("images") / path.name + return path + + +def materialize_image(sample: dict, row_idx: int) -> str: + image = sample.get("image") + image_path = sample.get("image_path") + image_bytes = None + if isinstance(image, dict): + image_path = image_path or image.get("path") + image_bytes = image.get("bytes") + rel_path = safe_relative_path(image_path, row_idx) + image_file = dataset_dir / rel_path + if image_bytes is not None: + image_file.parent.mkdir(parents=True, exist_ok=True) + image_file.write_bytes(image_bytes) + elif not image_file.exists(): + raise FileNotFoundError(f"Missing image bytes and file for row {row_idx}: {image_path}") + return str(image_file.resolve()) + + +ds = load_dataset( + "parquet", + data_files={"train": [str(path) for path in parquet_files]}, + split="train", +).cast_column("image", Image(decode=False)) + +count = 0 +with dst.open("w", encoding="utf-8") as fout: + for row_idx, sample in enumerate(ds): + if max_samples is not None and count >= max_samples: + break + + sample["image"] = materialize_image(sample, row_idx) + sample.pop("image_path", None) + + for turn in sample.get("conversations", []): + content = turn.get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") in {"image", "image_url"}: + if "image" in item: + item["image"] = absolutize_image_ref(item["image"]) + elif isinstance(item.get("image_url"), dict): + url = item["image_url"].get("url") + item["image_url"]["url"] = absolutize_image_ref(url) + + fout.write(json.dumps(sample, ensure_ascii=False) + "\n") + count += 1 + +print(f"Wrote {count} rows to {dst}") +PY + +echo "=== Step 3: Preparing multimodal data ===" +python scripts/prepare_data.py \ + --model "$MODEL" \ + --data "$DATASET_JSONL" \ + --output "$OUTPUT_DIR" \ + --max-samples "$MAX_SAMPLES" \ + --seq-length "$SEQ_LENGTH" \ + --multimodal + +echo "=== Step 4: Launching vLLM server ===" +echo "vLLM logs will be written to: $VLLM_LOG_FILE" + +CUDA_VISIBLE_DEVICES="$VLLM_GPUS" python scripts/launch_vllm.py "$MODEL" \ + --hidden-states-path "$HIDDEN_STATES_DIR" \ + -- \ + --port "$VLLM_PORT" \ + --tensor-parallel-size "$VLLM_TP" \ + --max-model-len "$VLLM_MAX_MODEL_LEN" \ + --limit-mm-per-prompt '{"image":1}' \ + "${VLLM_EXTRA_ARR[@]}" \ + > "$VLLM_LOG_FILE" 2>&1 & +VLLM_PID=$! + +cleanup() { + echo "Stopping vLLM server..." + kill "$VLLM_PID" 2>/dev/null || true + wait "$VLLM_PID" 2>/dev/null || true +} +trap cleanup EXIT + +echo "Waiting for vLLM server to be ready..." +until curl -sf "http://localhost:${VLLM_PORT}/health" > /dev/null 2>&1; do + sleep 2 +done +echo "vLLM server ready." + +echo "=== Step 5: Online training ===" +CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" torchrun \ + --standalone --nproc_per_node "$NUM_TRAIN_GPUS" \ + scripts/train.py \ + --verifier-name-or-path "$MODEL" \ + --data-path "$OUTPUT_DIR" \ + --hidden-states-path "$HIDDEN_STATES_DIR" \ + --vllm-endpoint "http://localhost:${VLLM_PORT}/v1" \ + --save-path "$CHECKPOINT_DIR" \ + --epochs "$EPOCHS" \ + --lr "$LR" \ + --total-seq-len "$SEQ_LENGTH" \ + --num-layers 1 \ + --ttt-steps 3 \ + --ttt-step-loss-decay 1.0 \ + --on-missing generate \ + --on-generate cache \ + --run-name eagle3_qwen3_vl_4b_llava_cot_5k_online + +echo "Done. Checkpoints saved to $CHECKPOINT_DIR" diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 04882c09b..d041eccd7 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -246,6 +246,24 @@ def check_safetensors_file(path: Path, tokens: list[int]): ) +def _to_token_id_list(value: Any) -> list[int]: + if hasattr(value, "tolist"): + return value.tolist() + return list(value) + + +def _build_queue_item(idx: int, item: dict[str, Any]) -> dict[str, Any]: + queue_item: dict[str, Any] = { + "idx": idx, + "input_ids": _to_token_id_list(item["input_ids"]), + } + + if "messages" in item: + queue_item["messages"] = item["messages"] + + return queue_item + + async def worker( client, model: str, @@ -276,7 +294,8 @@ async def worker( queue.task_done() continue - input_ids = item["input_ids"].tolist() + input_ids = item["input_ids"] + messages = item.get("messages") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" @@ -286,6 +305,7 @@ async def worker( client, model, input_ids, + messages=messages, timeout=request_timeout, max_retries=max_retries, ) @@ -330,7 +350,7 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): # deadlocking when all workers have died. while not cancel_event.is_set(): try: - queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) + queue.put_nowait(_build_queue_item(i, item)) break except asyncio.QueueFull: await asyncio.sleep(0.1) @@ -368,6 +388,8 @@ async def generate_and_save_hidden_states(args, dataset): existing_file_indices = get_existing_hidden_state_indices(hidden_states_dir) num_samples = len(dataset) + if "messages" in dataset.column_names: + logger.info("Detected multimodal preprocessed dataset") to_process = get_indices_to_process( num_samples, args.max_samples, existing_file_indices @@ -428,7 +450,7 @@ async def generate_and_save_hidden_states(args, dataset): await _shutdown_workers(workers, queue, cancel_event) num_saved = len(to_process) - len(skipped_indices) - logger.info(f"Saved {num_saved} new data points to {args.output}") + logger.info(f"Saved {num_saved} new data points to {hidden_states_dir}") if skipped_indices: logger.warning( f"Skipped {len(skipped_indices)} samples due to errors: {skipped_indices}" diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py index a9bbb8573..13ac3ae4e 100644 --- a/scripts/launch_vllm.py +++ b/scripts/launch_vllm.py @@ -49,6 +49,42 @@ def parse_args(): return parser.parse_known_args() +def _is_multimodal_config(config) -> bool: + if any( + hasattr(config, field_name) + for field_name in ( + "vision_config", + "visual_config", + "image_token_id", + "video_token_id", + ) + ): + return True + + model_type = str(getattr(config, "model_type", "")).lower() + if any( + marker in model_type + for marker in ("vl", "vision", "llava", "mllama", "paligemma", "pixtral") + ): + return True + + architectures = getattr(config, "architectures", []) or [] + return any( + any( + marker in str(architecture).lower() + for marker in ( + "vl", + "vision", + "llava", + "mllama", + "paligemma", + "pixtral", + ) + ) + for architecture in architectures + ) + + def main(): args, vllm_args = parse_args() if "--" in vllm_args: @@ -57,9 +93,18 @@ def main(): from transformers import AutoConfig # noqa: PLC0415 config = AutoConfig.from_pretrained(args.model) - if hasattr(config, "text_config"): - config = config.text_config - num_hidden_layers = config.num_hidden_layers + text_config = config.text_config if hasattr(config, "text_config") else config + num_hidden_layers = text_config.num_hidden_layers + + if _is_multimodal_config(config) and "--enforce-eager" not in vllm_args: + # vLLM 0.20 multimodal hidden-state extraction can hit CUDA graph + # shape mismatches after image/video token expansion. Eager mode avoids + # that runtime crash. + vllm_args.append("--enforce-eager") + warnings.warn( + "Adding --enforce-eager for multimodal hidden-state extraction.", + stacklevel=2, + ) if args.target_layer_ids: target_layer_ids = args.target_layer_ids @@ -67,7 +112,9 @@ def main(): target_layer_ids.append(num_hidden_layers) warnings.warn( f"Using custom target layer ids {target_layer_ids}. These " - "must also be explicitly passed into the training script.", + "must also be explicitly aligned in the training script. " + "If the final verifier layer is included here, pass only the " + "auxiliary layers to training.", stacklevel=2, ) else: @@ -78,12 +125,29 @@ def main(): num_hidden_layers, ] + draft_hf_config = {"eagle_aux_hidden_state_layer_ids": target_layer_ids} + if text_config is not config: + # vLLM's ExtractHiddenStatesConfig flattens the draft config and does not + # preserve nested text_config for multimodal verifiers. Clear the nested + # text_config and copy the text-only shape fields onto the draft config so + # hidden-state extraction can derive the draft model shape from the + # flattened config. + draft_hf_config["text_config"] = None + for field_name in ( + "num_attention_heads", + "num_hidden_layers", + "hidden_size", + "num_key_value_heads", + "head_dim", + ): + field_value = getattr(text_config, field_name, None) + if field_value is not None: + draft_hf_config[field_name] = field_value + speculative_config = { "method": "extract_hidden_states", "num_speculative_tokens": 1, - "draft_model_config": { - "hf_config": {"eagle_aux_hidden_state_layer_ids": target_layer_ids} - }, + "draft_model_config": {"hf_config": draft_hf_config}, } kv_transfer_config = { "kv_connector": "ExampleHiddenStatesConnector", diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 3390d9eef..24fc2bd1a 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -134,6 +134,14 @@ def parse_args(): "trainable tokens." ), ) + parser.add_argument( + "--multimodal", + action="store_true", + help=( + "Enable multimodal preprocessing with AutoProcessor and preserve " + "messages needed for vLLM chat hidden-state generation." + ), + ) return parser.parse_args() @@ -177,6 +185,7 @@ def main(): assistant_pattern=args.assistant_pattern, turn_dropout=args.turn_dropout, minimum_valid_tokens=args.minimum_valid_tokens, + is_multimodal=args.multimodal, ) log.info("Done preparing data") diff --git a/scripts/train.py b/scripts/train.py index cbcfabcae..508a73958 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -66,6 +66,7 @@ def setup_dataloader( world_size: int, local_rank: int, hidden_size: int, + hidden_states_dtype: torch.dtype, num_workers: int = 12, prefetch_factor: int = 4, preprocess=None, @@ -96,7 +97,12 @@ def setup_dataloader( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=create_collate_fn(args.total_seq_len, hidden_size, preprocess), + collate_fn=create_collate_fn( + args.total_seq_len, + hidden_size, + hidden_states_dtype=hidden_states_dtype, + preprocess=preprocess, + ), persistent_workers=True, ) @@ -339,6 +345,7 @@ def main(args: argparse.Namespace): world_size, local_rank, transformer_layer_config.hidden_size, + hidden_states_dtype, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor, preprocess=preprocess, @@ -348,6 +355,7 @@ def main(args: argparse.Namespace): world_size, local_rank, transformer_layer_config.hidden_size, + hidden_states_dtype, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor, preprocess=preprocess, diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 3adf0c18b..33fdb59dc 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -1,14 +1,17 @@ import bisect import random import re +from functools import partial from pathlib import Path from re import Pattern from typing import Any, cast +from urllib.parse import unquote, urlparse import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase +from transformers.image_utils import load_image from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger @@ -58,6 +61,215 @@ def _visualize_sample(preprocessed, tokenizer, idx: int = 0): log.info(highlighted) +def _normalize_content(content: Any) -> Any: + """Normalize multimodal content into a processor-friendly representation.""" + + def _strip_image_tokens(text: str) -> str: + return re.sub(r"", "", text, flags=re.IGNORECASE).strip() + + if isinstance(content, list): + normalized_items = [] + for item in content: + if isinstance(item, str): + normalized_items.append( + {"type": "text", "text": _strip_image_tokens(item)} + ) + continue + if isinstance(item, dict): + image_ref = _get_image_ref(item) + if image_ref is not None: + normalized_items.append({"type": "image", "image": image_ref}) + continue + text_val = item.get("text") + if text_val is not None: + normalized_items.append( + {"type": "text", "text": _strip_image_tokens(text_val)} + ) + continue + normalized_items.append(item) + return normalized_items + return content + + +def _get_conversations_from_examples(examples: dict) -> list: + """Extract conversations/messages from a batched dataset example.""" + if "conversations" in examples: + return examples.get("conversations", []) + if "messages" in examples: + return examples.get("messages", []) + return [] + + +def _flatten_singleton_batch(value: Any, *, field_name: str) -> Any: + """Collapse a singleton batch dimension from HF outputs.""" + if isinstance(value, torch.Tensor): + value = value.tolist() + + if isinstance(value, list) and value and isinstance(value[0], list): + if len(value) != 1: + raise ValueError( + f"{field_name} returned non-singleton batch: batch_size={len(value)}" + ) + return value[0] + + return value + + +def _get_tokenizer_from_processor(processor: Any) -> PreTrainedTokenizerBase: + """Extract the tokenizer from a processor.""" + tokenizer = getattr(processor, "tokenizer", None) + if tokenizer is None: + raise ValueError("Processor does not provide a tokenizer attribute.") + return cast("PreTrainedTokenizerBase", tokenizer) + + +def _load_image_for_processor(image_ref: Any) -> Any: + """Load a local/data image reference for HF multimodal processors.""" + if isinstance(image_ref, str): + parsed = urlparse(image_ref) + if parsed.scheme == "file": + image_ref = unquote(parsed.path) + return load_image(image_ref) + + +def _get_image_ref(item: dict) -> Any | None: + """Extract a serializable image reference from a multimodal content item.""" + if item.get("type") not in ("image", "image_url", "input_image"): + return None + + image_ref = item.get("image") + if image_ref is not None: + return image_ref + + image_url = item.get("image_url") + if isinstance(image_url, dict): + return image_url.get("url") + return image_url + + +def _extract_processor_images_from_conversation(conv: list[dict]) -> list[Any]: + """Extract image inputs in the same order as the chat template placeholders.""" + images = [] + for turn in conv: + content = turn.get("content") + if not isinstance(content, list): + continue + + for item in content: + if not isinstance(item, dict): + continue + image_ref = _get_image_ref(item) + if image_ref is not None: + images.append(_load_image_for_processor(image_ref)) + + return images + + +def _get_image_token_ids(tokenizer: PreTrainedTokenizerBase) -> set[int]: + """Return known image placeholder token IDs for VL token expansion.""" + image_token_ids = set() + convert_tokens_to_ids = getattr(tokenizer, "convert_tokens_to_ids", None) + if not callable(convert_tokens_to_ids): + return image_token_ids + + unk_token_id = getattr(tokenizer, "unk_token_id", None) + for token in ("<|image_pad|>", "", ""): + token_id = convert_tokens_to_ids(token) + if isinstance(token_id, int) and token_id >= 0 and token_id != unk_token_id: + image_token_ids.add(token_id) + + return image_token_ids + + +def _expand_loss_mask_for_multimodal_tokens( + input_ids: list[int], + loss_mask: torch.Tensor, + expanded_input_ids: list[int], + image_token_ids: set[int], +) -> torch.Tensor: + """Expand loss masks when a processor expands image placeholders.""" + if input_ids == expanded_input_ids: + return loss_mask + + if not image_token_ids: + raise ValueError("Cannot align expanded multimodal input IDs without image IDs") + + mask_values = loss_mask.tolist() + expanded_mask: list[int] = [] + src_idx = 0 + dst_idx = 0 + + while src_idx < len(input_ids): + token_id = input_ids[src_idx] + if dst_idx >= len(expanded_input_ids): + break + + if token_id in image_token_ids: + start_idx = dst_idx + while ( + dst_idx < len(expanded_input_ids) + and expanded_input_ids[dst_idx] == token_id + ): + dst_idx += 1 + if dst_idx == start_idx: + raise ValueError( + f"Unable to align image token {token_id} at position {src_idx}" + ) + expanded_mask.extend([0] * (dst_idx - start_idx)) + src_idx += 1 + continue + + if expanded_input_ids[dst_idx] != token_id: + raise ValueError( + "Unable to align expanded multimodal input IDs at " + f"source={src_idx}, expanded={dst_idx}: " + f"{token_id} != {expanded_input_ids[dst_idx]}" + ) + expanded_mask.append(int(mask_values[src_idx])) + src_idx += 1 + dst_idx += 1 + + if dst_idx != len(expanded_input_ids): + raise ValueError( + "Expanded multimodal input IDs contain trailing tokens after alignment" + ) + + return torch.tensor(expanded_mask, dtype=loss_mask.dtype) + + +def _expand_multimodal_inputs_with_images( + processor: Any, + tokenizer: PreTrainedTokenizerBase, + formatted_text: str, + normalized_conv: list[dict], + input_ids: list[int], + loss_mask: torch.Tensor, + max_length: int, +) -> tuple[list[int], torch.Tensor]: + """Use actual images so HF preprocessing matches vLLM VL token expansion.""" + images = _extract_processor_images_from_conversation(normalized_conv) + if not images: + return input_ids, loss_mask + + encoded = processor( + text=[formatted_text], + images=images, + max_length=max_length, + truncation=True, + ) + expanded_input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Multimodal processor input_ids with images", + ) + expanded_loss_mask = _expand_loss_mask_for_multimodal_tokens( + input_ids, + loss_mask, + expanded_input_ids, + _get_image_token_ids(tokenizer), + ) + return expanded_input_ids, expanded_loss_mask + + def _normalize_conversation( conv: list[dict], turn_dropout: bool = False, @@ -78,6 +290,7 @@ def _normalize_conversation( for i, turn in enumerate(conv): role = turn.get("from", turn.get("role", "")) content = turn.get("value") or turn.get("content") or "" + content = _normalize_content(content) # Map various role names to standard user/assistant if role in ("human", "user"): @@ -107,13 +320,13 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: - """Check if tokenizer truly supports HF assistant token mask. +def _supports_assistant_mask(caller: Any) -> bool: + """Check if tokenizer/processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. """ try: - res_any = tokenizer.apply_chat_template( + res_any = caller.apply_chat_template( [{"role": "assistant", "content": "test"}], tokenize=True, return_assistant_tokens_mask=True, @@ -126,13 +339,14 @@ def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: return False # Verify the mask is not all zeros + mask = _flatten_singleton_batch(mask, field_name="assistant mask") return any(m == 1 for m in mask) except (TypeError, ValueError, KeyError, AttributeError): return False -def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: - """Auto-detect the assistant message pattern from the tokenizer's chat template. +def _detect_assistant_pattern(caller: Any) -> str: + """Auto-detect the assistant message pattern from a chat template caller. Uses multi-turn conversation but extracts pattern from the LAST assistant message only. @@ -144,7 +358,7 @@ def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: {"role": "assistant", "content": "ASSISTANT_MSG_2"}, ] - formatted = tokenizer.apply_chat_template( + formatted = caller.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) assert isinstance(formatted, str), "Expected string from apply_chat_template" @@ -266,7 +480,7 @@ def _preprocess_batch( """Process a batch of conversations into tokenized format with loss masks.""" results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} - conversations = examples.get("conversations", []) + conversations = _get_conversations_from_examples(examples) if not conversations: log.warning(f"No conversations key found. Keys: {list(examples.keys())}") @@ -294,14 +508,23 @@ def _preprocess_batch( encoded = cast("dict[str, Any]", encoded_any) # input IDs and loss mask - input_ids = encoded["input_ids"] + input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Text apply_chat_template input_ids", + ) # HF uses 'assistant_masks' in recent versions mask_key = ( "assistant_masks" if "assistant_masks" in encoded else "assistant_mask" ) - loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) + loss_mask = torch.tensor( + _flatten_singleton_batch( + encoded[mask_key], + field_name="Text apply_chat_template assistant mask", + ), + dtype=torch.long, + ) else: # Fallback: regex-based detection @@ -325,8 +548,13 @@ def _preprocess_batch( ) # input IDs and loss mask - input_ids = encoding["input_ids"] - offsets = encoding["offset_mapping"] + input_ids = _flatten_singleton_batch( + encoding["input_ids"], field_name="Text tokenizer input_ids" + ) + offsets = _flatten_singleton_batch( + encoding["offset_mapping"], + field_name="Text tokenizer offset_mapping", + ) loss_mask = _create_loss_mask_from_offsets( formatted_raw, offsets, assistant_pattern @@ -359,6 +587,134 @@ def _preprocess_batch( return results +def _preprocess_batch_multimodal( # noqa: PLR0912, PLR0915 + examples: dict, + processor: Any, + max_length: int, + assistant_pattern: str | Pattern[str] | None, + turn_dropout: bool = False, + minimum_valid_tokens: int | None = None, +) -> dict[str, list]: + """Process a batch of multimodal conversations into token IDs and loss masks.""" + + results: dict[str, list] = { + "input_ids": [], + "loss_mask": [], + "messages": [], + "seq_len": [], + } + conversations = _get_conversations_from_examples(examples) + + if not conversations: + log.warning(f"No conversations key found. Keys: {list(examples.keys())}") + return results + + tokenizer = _get_tokenizer_from_processor(processor) + + for idx, conv in enumerate(conversations): + if not conv or not isinstance(conv, list): + continue + + normalized_conv = _normalize_conversation(conv, turn_dropout) + if not normalized_conv: + continue + + try: + formatted_raw = processor.apply_chat_template( + normalized_conv, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(formatted_raw, str) + + if assistant_pattern is None: + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + return_assistant_tokens_mask=True, + return_dict=True, + ) + encoded = cast("dict[str, Any]", encoded_any) + + input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Multimodal apply_chat_template input_ids", + ) + mask_key = ( + "assistant_masks" + if "assistant_masks" in encoded + else "assistant_mask" + ) + loss_mask = torch.tensor( + _flatten_singleton_batch( + encoded[mask_key], + field_name="Multimodal apply_chat_template assistant mask", + ), + dtype=torch.long, + ) + else: + encoding = tokenizer( + formatted_raw, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + add_special_tokens=False, + ) + + input_ids = _flatten_singleton_batch( + encoding["input_ids"], field_name="Multimodal tokenizer input_ids" + ) + offsets = _flatten_singleton_batch( + encoding["offset_mapping"], + field_name="Multimodal tokenizer offset_mapping", + ) + loss_mask = _create_loss_mask_from_offsets( + formatted_raw, offsets, assistant_pattern + ) + + assert len(input_ids) == len(loss_mask), ( + f"Shape mismatch: input_ids={len(input_ids)}, " + f"loss_mask={len(loss_mask)}" + ) + + input_ids, loss_mask = _expand_multimodal_inputs_with_images( + processor, + tokenizer, + formatted_raw, + normalized_conv, + input_ids, + loss_mask, + max_length, + ) + + if minimum_valid_tokens is not None: + num_valid_tokens = int(loss_mask.sum().item()) + if num_valid_tokens < minimum_valid_tokens: + continue + + results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) + results["loss_mask"].append(loss_mask) + results["messages"].append(normalized_conv) + results["seq_len"].append(len(input_ids)) + + except ( + TypeError, + ValueError, + KeyError, + AttributeError, + RuntimeError, + OSError, + ) as e: + log.error( + f"Failed to process conversation {idx} " + f"(assistant_pattern={assistant_pattern is not None}): {e}" + ) + continue + + return results + + def build_eagle3_dataset( dataset: HFDataset, tokenizer: PreTrainedTokenizerBase, @@ -367,6 +723,7 @@ def build_eagle3_dataset( assistant_pattern: str | Pattern[str] | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + processor: Any | None = None, ) -> HFDataset: """Build EAGLE3 dataset by tokenizing conversations and creating loss masks. @@ -384,35 +741,60 @@ def build_eagle3_dataset( conversation minimum_valid_tokens: Number of tokens to consider for a valid sample """ + is_multimodal = processor is not None + caller = processor if is_multimodal else tokenizer + # Detect and use provided assistant message pattern if assistant_pattern is not None: log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...") - elif _supports_assistant_mask(tokenizer): + elif _supports_assistant_mask(caller): assistant_pattern = None # Signal to use HF mask in _preprocess_batch - log.info("Using HF assistant token mask for loss masking") + suffix = " (multimodal)" if is_multimodal else "" + log.info(f"Using HF assistant token mask for loss masking{suffix}") else: - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(caller) log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...") original_cols = dataset.column_names + if is_multimodal: + map_fn = partial( + _preprocess_batch_multimodal, + processor=processor, + max_length=max_length, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + else: + map_fn = partial( + _preprocess_batch, + tokenizer=tokenizer, + max_length=max_length, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + dataset = dataset.map( - lambda examples: _preprocess_batch( - examples, - tokenizer, - max_length, - assistant_pattern, - turn_dropout, - minimum_valid_tokens, - ), + map_fn, batched=True, num_proc=num_proc, - batch_size=1000, + # Multimodal preprocessing loads each image and asks the processor to + # expand image placeholders into the real visual-token length. Large + # batches can look stuck because a worker must finish the whole batch + # before datasets updates progress. + batch_size=32 if is_multimodal else 1000, remove_columns=original_cols, keep_in_memory=True, # skip caching ) - dataset.set_format(type="torch") + if is_multimodal: + dataset.set_format( + type="torch", columns=["input_ids", "loss_mask"], output_all_columns=True + ) + else: + dataset.set_format(type="torch") return dataset @@ -436,7 +818,7 @@ def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: return raw_dataset -def load_and_preprocess_dataset( +def load_and_preprocess_dataset( # noqa: PLR0912 target_model_path: str, train_data_paths: list[str], seq_length: int, @@ -447,6 +829,7 @@ def load_and_preprocess_dataset( assistant_pattern: str | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + is_multimodal: bool | None = None, ) -> tuple[HFDataset, PreTrainedTokenizerBase]: """Load, tokenize, and preprocess a dataset for EAGLE3 training. @@ -480,12 +863,34 @@ def load_and_preprocess_dataset( f"Filtering samples with fewer than {minimum_valid_tokens} valid tokens" ) + if is_multimodal is None: + is_multimodal = False + log.info(f"Using multimodal mode: {is_multimodal}") + log.subsection("Loading tokenizer") - tokenizer = AutoTokenizer.from_pretrained(target_model_path, trust_remote_code=True) + if is_multimodal: + processor = AutoProcessor.from_pretrained( + target_model_path, trust_remote_code=True + ) + tokenizer = _get_tokenizer_from_processor(processor) + else: + processor = None + tokenizer = AutoTokenizer.from_pretrained( + target_model_path, trust_remote_code=True + ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: + if is_multimodal: + if not hasattr(processor, "apply_chat_template"): + raise ValueError( + f"Processor for {target_model_path} does not support " + "apply_chat_template. Please use a model with a pre-configured " + "chat template." + ) + elif ( + not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None + ): raise ValueError( f"Tokenizer for {target_model_path} does not support chat templates. " "Please use a model with a pre-configured chat template." @@ -516,13 +921,14 @@ def load_and_preprocess_dataset( assistant_pattern=assistant_pattern, turn_dropout=turn_dropout, minimum_valid_tokens=minimum_valid_tokens, + processor=processor, ) if minimum_valid_tokens is not None: log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") processed_datasets.append(preprocessed_dataset) combined_dataset = concatenate_datasets(processed_datasets) - combined_dataset.shuffle(seed=seed) + combined_dataset = combined_dataset.shuffle(seed=seed) if max_samples is not None and len(combined_dataset) > max_samples: combined_dataset = combined_dataset.select(range(max_samples)) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index da766a656..9d41e392a 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -1,9 +1,15 @@ import asyncio +import base64 import functools import logging +import mimetypes import time +from pathlib import Path +from typing import Any, cast +from urllib.parse import urlparse import openai +from safetensors.torch import load_file, save_file logger = logging.getLogger(__name__) @@ -16,6 +22,132 @@ class InvalidResponseError(Exception): pass +def _get_field(obj: Any, key: str) -> Any: + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def _to_token_id_list(token_ids: Any) -> list[int]: + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + return list(token_ids) + + +def _image_ref_to_chat_url(image_ref: Any) -> str: + """Convert a dataset image reference to an OpenAI-compatible image URL.""" + ref = str(image_ref) + parsed = urlparse(ref) + if parsed.scheme in {"http", "https", "data", "file"}: + return ref + + path = Path(ref).expanduser() + if path.exists() and path.is_file(): + mime_type = mimetypes.guess_type(path.name)[0] or "image/jpeg" + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:{mime_type};base64,{encoded}" + + if path.is_absolute(): + return path.as_uri() + + return ref + + +def _get_image_ref(part: dict[str, Any]) -> Any | None: + if part.get("type") not in ("image", "image_url", "input_image"): + return None + + image_ref = part.get("image") + if image_ref is not None: + return image_ref + + image_url = part.get("image_url") + if isinstance(image_url, dict): + return image_url.get("url") + return image_url + + +def _prepare_chat_message_content(content: Any) -> Any: + if not isinstance(content, list): + return content + + prepared = [] + for part in content: + if isinstance(part, str): + prepared.append({"type": "text", "text": part}) + continue + + if not isinstance(part, dict): + prepared.append(part) + continue + + image_ref = _get_image_ref(part) + if image_ref is not None: + prepared.append( + { + "type": "image_url", + "image_url": {"url": _image_ref_to_chat_url(image_ref)}, + } + ) + continue + + text = part.get("text") + if text is not None: + prepared.append({"type": "text", "text": str(text)}) + continue + + prepared.append(part) + + return prepared + + +def _prepare_chat_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert processor-style multimodal messages to vLLM chat messages.""" + prepared = [] + for message in messages: + prepared_message: dict[str, Any] = { + "role": message.get("role"), + "content": _prepare_chat_message_content(message.get("content", "")), + } + if "name" in message: + prepared_message["name"] = message["name"] + prepared.append(prepared_message) + return prepared + + +def _truncate_hidden_states_file(hidden_states_path: str, token_ids: list[int]) -> str: + """Trim vLLM's full multimodal prompt output to the preprocessed prefix.""" + tensors = load_file(hidden_states_path) + + try: + file_token_ids = _to_token_id_list(tensors["token_ids"]) + hidden_states = tensors["hidden_states"] + except KeyError as exc: + raise InvalidResponseError( + f"Hidden states file missing {exc.args[0]}: {hidden_states_path}" + ) from exc + + expected_len = len(token_ids) + if ( + file_token_ids[:expected_len] != token_ids + or hidden_states.shape[0] < expected_len + ): + raise InvalidResponseError( + "Hidden states file does not match preprocessed prompt prefix: " + f"expected {token_ids}, got {file_token_ids}" + ) + + # Safe for causal models: prefix hidden states cannot attend to future tokens + # that only exist in vLLM's full chat-rendered prompt. + tensors["token_ids"] = tensors["token_ids"][:expected_len].contiguous() + tensors["hidden_states"] = hidden_states[:expected_len].contiguous() + save_file( + {key: value.contiguous() for key, value in tensors.items()}, + hidden_states_path, + ) + return hidden_states_path + + def _handle_retry_error( error: Exception, attempt: int, total_attempts: int ) -> float | None: @@ -82,21 +214,46 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): return sync_wrapper -def extract_output(completion, token_ids) -> str: - prompt_token_ids = getattr(completion.choices[0], "prompt_token_ids", None) +def extract_output( + completion, + token_ids, + *, + allow_prefix_truncation: bool = False, +) -> str: + token_ids = _to_token_id_list(token_ids) + prompt_token_ids = _get_field(completion, "prompt_token_ids") + if prompt_token_ids is None: + choices = _get_field(completion, "choices") + if choices: + prompt_token_ids = _get_field(choices[0], "prompt_token_ids") if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") - if prompt_token_ids != token_ids: - raise InvalidResponseError( - f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" - ) - - if not hasattr(completion, "kv_transfer_params"): + kv_transfer_params = _get_field(completion, "kv_transfer_params") + if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") - return completion.kv_transfer_params.get("hidden_states_path") + hidden_states_path = _get_field(kv_transfer_params, "hidden_states_path") + if hidden_states_path is None: + raise InvalidResponseError("Response missing hidden_states_path") + + prompt_token_ids = _to_token_id_list(prompt_token_ids) + if prompt_token_ids == token_ids: + return hidden_states_path + + if allow_prefix_truncation and prompt_token_ids[: len(token_ids)] == token_ids: + logger.debug( + "vLLM returned %d prompt tokens for a %d-token preprocessed prompt; " + "truncating hidden states to match prepare_data output.", + len(prompt_token_ids), + len(token_ids), + ) + return _truncate_hidden_states_file(hidden_states_path, token_ids) + + raise InvalidResponseError( + f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" + ) @with_retries @@ -104,6 +261,7 @@ async def generate_hidden_states_async( client: openai.AsyncClient, model: str, token_ids: list[int], + messages: list[dict[str, Any]] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -114,21 +272,36 @@ async def generate_hidden_states_async( client: The async OpenAI client. model: The model ID. token_ids: The input token IDs. + messages: Optional chat messages for vLLM multimodal requests. timeout: Timeout in seconds for each request attempt. None for no timeout. """ - coro = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, - ) + if messages is not None: + chat_messages = _prepare_chat_messages(messages) + coro = client.chat.completions.create( + model=model, + messages=cast("Any", chat_messages), + max_tokens=1, + extra_body={"return_token_ids": True, "add_generation_prompt": False}, + timeout=timeout, + ) + else: + coro = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) if timeout is not None: completion = await asyncio.wait_for(coro, timeout=timeout) else: completion = await coro - return extract_output(completion, token_ids) + return extract_output( + completion, + token_ids, + allow_prefix_truncation=messages is not None, + ) @with_retries @@ -136,17 +309,32 @@ def generate_hidden_states( client: openai.Client, model: str, token_ids: list[int], + messages: list[dict[str, Any]] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ - completion = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, + if messages is not None: + chat_messages = _prepare_chat_messages(messages) + completion = client.chat.completions.create( + model=model, + messages=cast("Any", chat_messages), + max_tokens=1, + extra_body={"return_token_ids": True, "add_generation_prompt": False}, + timeout=timeout, + ) + else: + completion = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + return extract_output( + completion, + token_ids, + allow_prefix_truncation=messages is not None, ) - return extract_output(completion, token_ids) diff --git a/src/speculators/models/utils.py b/src/speculators/models/utils.py index 0ebb8abca..edb025d8a 100644 --- a/src/speculators/models/utils.py +++ b/src/speculators/models/utils.py @@ -21,10 +21,25 @@ def resolve_target_layer_ids( target_layer_ids: list[int] | None, verifier_name_or_path: str, ) -> list[int]: + num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers + if target_layer_ids is not None: - return target_layer_ids + # Offline datagen extracts auxiliary layers plus the verifier's final layer. + # Training stores the final layer separately as `verifier_last_hidden_states`, + # so the draft config should only keep the auxiliary layers. + aux_target_layer_ids = [ + layer_id for layer_id in target_layer_ids if layer_id != num_layers + ] + if len(aux_target_layer_ids) != len(target_layer_ids): + warnings.warn( + "Stripping the verifier's final layer " + f"({num_layers}) from --target-layer-ids for training. " + "The last extracted layer is consumed separately as " + "`verifier_last_hidden_states`.", + stacklevel=2, + ) + return aux_target_layer_ids - num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers target_layer_ids = [2, num_layers // 2, num_layers - 3] warnings.warn( DEFAULT_TARGET_LAYER_IDS_WARNING.format(target_layer_ids=target_layer_ids), diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 97ec87828..c9747a653 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -65,7 +65,9 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]] -def create_empty_sample(hidden_size: int): +def create_empty_sample( + hidden_size: int, hidden_states_dtype: torch.dtype = torch.bfloat16 +): # data structure: { # "hidden_states": [seq_len, 3 * hidden_size], # "input_ids": [seq_len], @@ -76,10 +78,12 @@ def create_empty_sample(hidden_size: int): # } return { - "hidden_states": torch.empty(0, 3 * hidden_size), - "input_ids": torch.empty(0), - "verifier_last_hidden_states": torch.empty(0, hidden_size), - "loss_mask": torch.empty(0), + "hidden_states": torch.empty(0, 3 * hidden_size, dtype=hidden_states_dtype), + "input_ids": torch.empty(0, dtype=torch.long), + "verifier_last_hidden_states": torch.empty( + 0, hidden_size, dtype=hidden_states_dtype + ), + "loss_mask": torch.empty(0, dtype=torch.long), "lengths": torch.tensor([0], dtype=torch.long), "position_ids": torch.arange(0, dtype=torch.long), } @@ -258,12 +262,17 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: if not self.client: self._setup_client() - input_ids = self.data[index]["input_ids"].tolist() + sample = self.data[index] + input_ids = sample["input_ids"] + if hasattr(input_ids, "tolist"): + input_ids = input_ids.tolist() + messages = sample.get("messages") try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] input_ids, + messages=messages, timeout=self.request_timeout, max_retries=self.max_retries, ) @@ -431,6 +440,7 @@ def _get_raw_data(self, index): def create_collate_fn( max_len: int, hidden_size: int, + hidden_states_dtype: torch.dtype = torch.bfloat16, preprocess: Callable[[BatchType], BatchType] | None = None, ): def collate_fn(batch: list[BatchType | None]) -> BatchType: @@ -440,7 +450,7 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: if not batch: # Create empty sample which then gets padded to full # batch size if no valid samples are found - batch = [create_empty_sample(hidden_size)] + batch = [create_empty_sample(hidden_size, hidden_states_dtype)] collated_data = {} for key in batch[0]: # type: ignore[union-attr] diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index a9184c18f..c280eaae2 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -3,15 +3,18 @@ """ import re +from typing import Any, cast import pytest import torch from datasets import Dataset as HFDataset -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from speculators.data_generation import preprocessing from speculators.data_generation.preprocessing import ( _create_loss_mask_from_offsets, _detect_assistant_pattern, + _expand_loss_mask_for_multimodal_tokens, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, @@ -77,6 +80,140 @@ def test_normalize_conversation_unknown_role(): assert result[1]["role"] == "assistant" +@pytest.mark.sanity +def test_normalize_conversation_multimodal_content(): + """Test normalization of multimodal content items.""" + conv: list[dict[str, Any]] = [ + { + "role": "user", + "content": [ + "Look at this ", + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"text": "Caption "}, + ], + }, + {"role": "assistant", "content": "Nice cat."}, + ] + + result = _normalize_conversation(conv) + + assert result[0]["content"] == [ + {"type": "text", "text": "Look at this"}, + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "Caption"}, + ] + + +@pytest.mark.sanity +def test_preprocess_batch_supports_messages_schema(): + """Test that preprocessing accepts datasets using the messages field.""" + + class DummyTokenizer: + chat_template = "dummy-template" + + def apply_chat_template(self, *args, **kwargs): + if kwargs.get("tokenize"): + return {"input_ids": [[11, 12, 13]], "assistant_mask": [[0, 1, 1]]} + return "formatted" + + results = _preprocess_batch( + { + "messages": [ + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + ] + }, + cast("PreTrainedTokenizerBase", DummyTokenizer()), + max_length=64, + assistant_pattern=None, + ) + + assert len(results["input_ids"]) == 1 + assert len(results["loss_mask"]) == 1 + + +@pytest.mark.sanity +def test_expand_loss_mask_for_multimodal_tokens_allows_truncation(): + """Test that expanded multimodal alignment tolerates processor truncation.""" + loss_mask = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + expanded = _expand_loss_mask_for_multimodal_tokens( + input_ids=[10, 999, 20, 30], + loss_mask=loss_mask, + expanded_input_ids=[10, 999, 999, 999, 20], + image_token_ids={999}, + ) + + assert expanded.tolist() == [0, 0, 0, 0, 1] + + +@pytest.mark.sanity +def test_load_and_preprocess_dataset_defaults_to_text_mode(monkeypatch): + """Test that omitted multimodal flag falls back to text-only processing.""" + raw_dataset = HFDataset.from_dict( + { + "conversations": [ + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + ] + } + ) + + class DummyTokenizer: + def __init__(self): + self.pad_token = None + self.eos_token = "" + self.chat_template = "dummy-template" + + def apply_chat_template(self, *args, **kwargs): + return "formatted" + + dummy_tokenizer = DummyTokenizer() + captured: dict[str, object] = {} + + def fake_build_eagle3_dataset(*args, **kwargs): + captured["processor"] = kwargs.get("processor") + return raw_dataset + + monkeypatch.setattr( + preprocessing, "load_raw_dataset", lambda *args, **kwargs: raw_dataset + ) + monkeypatch.setattr( + preprocessing.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: dummy_tokenizer, + ) + monkeypatch.setattr( + preprocessing, "build_eagle3_dataset", fake_build_eagle3_dataset + ) + monkeypatch.setattr( + preprocessing, "save_token_frequency_distribution", lambda *args, **kwargs: None + ) + monkeypatch.setattr( + preprocessing, "_visualize_sample", lambda *args, **kwargs: None + ) + + dataset, tokenizer = preprocessing.load_and_preprocess_dataset( + target_model_path="dummy-model", + train_data_paths=["dummy-data"], + seq_length=128, + build_dataset_num_proc=1, + is_multimodal=None, + ) + + assert len(dataset) == len(raw_dataset) + assert tokenizer is dummy_tokenizer + assert captured["processor"] is None + assert dummy_tokenizer.pad_token == dummy_tokenizer.eos_token + + # Tests for _detect_assistant_pattern @pytest.mark.sanity def test_detect_assistant_pattern_structure(): @@ -926,3 +1063,108 @@ def test_build_eagle3_dataset_with_custom_pattern(): # Should successfully build dataset with custom pattern assert isinstance(result, HFDataset) assert len(result) > 0 + + +@pytest.mark.sanity +def test_build_eagle3_dataset_multimodal_expands_image_tokens_and_preserves_messages(): + """Test multimodal preprocessing expands image tokens and preserves messages.""" + image_url = ( + "data:image/png;base64," + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/" + "x8AAwMCAO+/p9sAAAAASUVORK5CYII=" + ) + + class DummyTokenizer: + pad_token = "" + eos_token = "" + unk_token_id = -1 + + def convert_tokens_to_ids(self, token): + if token == "<|image_pad|>": + return 999 + return self.unk_token_id + + def __call__( + self, + text, + return_offsets_mapping=False, + max_length=None, + truncation=False, + add_special_tokens=False, + ): + del text, max_length, truncation, add_special_tokens + if return_offsets_mapping: + return { + "input_ids": [101, 102, 103], + "offset_mapping": [(0, 1), (1, 2), (2, 3)], + } + return {"input_ids": [101, 102, 103]} + + class DummyProcessor: + def __init__(self): + self.tokenizer = DummyTokenizer() + + def __call__(self, *args, **kwargs): + images = kwargs["images"] + assert len(images) == 1 + assert getattr(images[0], "mode", None) == "RGB" + return {"input_ids": [[101, 999, 999, 999, 103]]} + + def apply_chat_template(self, *args, **kwargs): + if kwargs.get("tokenize"): + return {"input_ids": [[101, 999, 103]], "assistant_mask": [[0, 0, 1]]} + return "formatted multimodal prompt" + + dataset = HFDataset.from_dict( + { + "messages": [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe "}, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "A cat."}], + }, + ] + ] + } + ) + + processor = DummyProcessor() + result = build_eagle3_dataset( + dataset, + cast("PreTrainedTokenizerBase", processor.tokenizer), + max_length=64, + num_proc=1, + processor=processor, + ) + + assert "messages" in result.column_names + assert result[0]["input_ids"].tolist() == [101, 999, 999, 999, 103] + assert result[0]["loss_mask"].tolist() == [0, 0, 0, 0, 1] + assert result[0]["seq_len"] == 5 + assert result[0]["messages"] == [ + { + "content": [ + {"image": None, "text": "Describe", "type": "text"}, + { + "image": image_url, + "text": None, + "type": "image", + }, + ], + "role": "user", + }, + { + "content": [{"image": None, "text": "A cat.", "type": "text"}], + "role": "assistant", + }, + ] diff --git a/tests/unit/convert/test_eagle3_converter.py b/tests/unit/convert/test_eagle3_converter.py index edb2f19d4..d8f46b136 100644 --- a/tests/unit/convert/test_eagle3_converter.py +++ b/tests/unit/convert/test_eagle3_converter.py @@ -259,15 +259,16 @@ def test_nm_testing_2layer_eagle3_model_config(self, tmp_path): # TODO: Use a real model in the future # noqa: FIX002 checkpoint_path = "nm-testing/testing-llama3.1.8b-2layer-eagle3" base_model = "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic" + output_path = tmp_path / checkpoint_path.rsplit("/", maxsplit=1)[-1] converter.convert( checkpoint_path, - tmp_path / checkpoint_path.split("/")[-1], + output_path, base_model, norm_before_residual=False, ) - config = load_checkpoint_config(tmp_path / checkpoint_path.split("/")[-1]) + config = load_checkpoint_config(output_path) # Verify that num_hidden_layers is correctly set to 2 assert config["transformer_layer_config"]["num_hidden_layers"] == 2 diff --git a/tests/unit/data_generation/__init__.py b/tests/unit/data_generation/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/unit/data_generation/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/data_generation/test_vllm_client.py b/tests/unit/data_generation/test_vllm_client.py new file mode 100644 index 000000000..30393cc68 --- /dev/null +++ b/tests/unit/data_generation/test_vllm_client.py @@ -0,0 +1,329 @@ +import asyncio + +import pytest +import torch +from safetensors.torch import load_file, save_file + +from speculators.data_generation.vllm_client import ( + InvalidResponseError, + generate_hidden_states, + generate_hidden_states_async, +) + + +class _DummyChoice: + def __init__(self, prompt_token_ids): + self.prompt_token_ids = prompt_token_ids + + +class _DummyCompletion: + def __init__( + self, + prompt_token_ids, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.choices = [_DummyChoice(prompt_token_ids)] + self.kv_transfer_params = {"hidden_states_path": hidden_states_path} + + +class _DummyChatCompletion: + def __init__( + self, + prompt_token_ids, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.prompt_token_ids = prompt_token_ids + self.kv_transfer_params = {"hidden_states_path": hidden_states_path} + + +class _DummySyncCompletions: + def __init__(self): + self.calls = [] + + def create(self, **kwargs): + self.calls.append(kwargs) + prompt = kwargs["prompt"] + prompt_token_ids = ( + prompt["prompt_token_ids"] if isinstance(prompt, dict) else prompt + ) + return _DummyCompletion(prompt_token_ids) + + +class _DummyAsyncCompletions: + def __init__(self): + self.calls = [] + + async def create(self, **kwargs): + self.calls.append(kwargs) + prompt = kwargs["prompt"] + prompt_token_ids = ( + prompt["prompt_token_ids"] if isinstance(prompt, dict) else prompt + ) + return _DummyCompletion(prompt_token_ids) + + +class _DummySyncChatCompletions: + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.calls = [] + self.prompt_token_ids = prompt_token_ids or [4, 5, 6] + self.hidden_states_path = hidden_states_path + + def create(self, **kwargs): + self.calls.append(kwargs) + return _DummyChatCompletion(self.prompt_token_ids, self.hidden_states_path) + + +class _DummyAsyncChatCompletions: + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.calls = [] + self.prompt_token_ids = prompt_token_ids or [7, 8, 9] + self.hidden_states_path = hidden_states_path + + async def create(self, **kwargs): + self.calls.append(kwargs) + return _DummyChatCompletion(self.prompt_token_ids, self.hidden_states_path) + + +class _DummySyncChat: + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummySyncChatCompletions( + prompt_token_ids, hidden_states_path + ) + + +class _DummyAsyncChat: + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummyAsyncChatCompletions( + prompt_token_ids, hidden_states_path + ) + + +class _DummySyncClient: + def __init__( + self, + chat_prompt_token_ids=None, + chat_hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummySyncCompletions() + self.chat = _DummySyncChat(chat_prompt_token_ids, chat_hidden_states_path) + + +class _DummyAsyncClient: + def __init__( + self, + chat_prompt_token_ids=None, + chat_hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummyAsyncCompletions() + self.chat = _DummyAsyncChat(chat_prompt_token_ids, chat_hidden_states_path) + + +def test_generate_hidden_states_text_prompt(): + client = _DummySyncClient() + + result = generate_hidden_states(client, "dummy-model", [1, 2, 3], timeout=1) + + assert result == "/tmp/hs_0.safetensors" + assert client.completions.calls[0]["prompt"] == [1, 2, 3] + + +def test_generate_hidden_states_multimodal_messages(): + client = _DummySyncClient() + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + expected_messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + + result = generate_hidden_states( + client, + "dummy-model", + [4, 5, 6], + messages=messages, + timeout=1, + ) + + assert result == "/tmp/hs_0.safetensors" + assert client.chat.completions.calls[0]["messages"] == expected_messages + assert client.chat.completions.calls[0]["extra_body"] == { + "return_token_ids": True, + "add_generation_prompt": False, + } + + +def test_generate_hidden_states_multimodal_messages_inlines_local_image(tmp_path): + client = _DummySyncClient() + image_path = tmp_path / "cat.png" + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": str(image_path)}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": "A cat."}, + ] + + result = generate_hidden_states( + client, + "dummy-model", + [4, 5, 6], + messages=messages, + timeout=1, + ) + + sent_content = client.chat.completions.calls[0]["messages"][0]["content"] + assert result == "/tmp/hs_0.safetensors" + assert sent_content[0]["type"] == "image_url" + assert sent_content[0]["image_url"]["url"].startswith("data:image/png;base64,") + assert sent_content[1] == {"type": "text", "text": "describe"} + + +def test_generate_hidden_states_async_multimodal_messages(): + client = _DummyAsyncClient() + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + expected_messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + + result = asyncio.run( + generate_hidden_states_async( + client, + "dummy-model", + [7, 8, 9], + messages=messages, + timeout=1, + ) + ) + + assert result == "/tmp/hs_0.safetensors" + assert client.chat.completions.calls[0]["messages"] == expected_messages + assert client.chat.completions.calls[0]["extra_body"] == { + "return_token_ids": True, + "add_generation_prompt": False, + } + + +def test_generate_hidden_states_truncates_multimodal_prefix_match(tmp_path): + hs_path = tmp_path / "hs_0.safetensors" + save_file( + { + "token_ids": torch.tensor([1, 2, 3, 4, 5], dtype=torch.long), + "hidden_states": torch.arange(5 * 2 * 3, dtype=torch.float32).reshape( + 5, 2, 3 + ), + }, + hs_path, + ) + client = _DummySyncClient( + chat_prompt_token_ids=[1, 2, 3, 4, 5], + chat_hidden_states_path=str(hs_path), + ) + + result = generate_hidden_states( + client, + "dummy-model", + [1, 2, 3], + messages=[{"role": "user", "content": "describe"}], + timeout=1, + ) + + tensors = load_file(result) + assert result == str(hs_path) + assert tensors["token_ids"].tolist() == [1, 2, 3] + assert tensors["hidden_states"].shape == (3, 2, 3) + assert torch.equal( + tensors["hidden_states"], + torch.arange(5 * 2 * 3, dtype=torch.float32).reshape(5, 2, 3)[:3], + ) + + +def test_generate_hidden_states_rejects_multimodal_non_prefix_mismatch(tmp_path): + hs_path = tmp_path / "hs_0.safetensors" + save_file( + { + "token_ids": torch.tensor([1, 9, 3, 4, 5], dtype=torch.long), + "hidden_states": torch.zeros(5, 2, 3), + }, + hs_path, + ) + client = _DummySyncClient( + chat_prompt_token_ids=[1, 9, 3, 4, 5], + chat_hidden_states_path=str(hs_path), + ) + + with pytest.raises(InvalidResponseError, match="Prompt token IDs mismatch"): + generate_hidden_states( + client, + "dummy-model", + [1, 2, 3], + messages=[{"role": "user", "content": "describe"}], + timeout=1, + ) + + +def test_generate_hidden_states_text_path_rejects_prefix_mismatch(): + class _PrefixTextCompletions: + def create(self, **kwargs): + return _DummyCompletion([1, 2, 3, 4]) + + client = _DummySyncClient() + client.completions = _PrefixTextCompletions() + + with pytest.raises(InvalidResponseError, match="Prompt token IDs mismatch"): + generate_hidden_states(client, "dummy-model", [1, 2, 3], timeout=1) diff --git a/tests/unit/models/test_utils.py b/tests/unit/models/test_utils.py new file mode 100644 index 000000000..c4c99fa81 --- /dev/null +++ b/tests/unit/models/test_utils.py @@ -0,0 +1,34 @@ +from types import SimpleNamespace + +import pytest + +from speculators.models import utils + + +def test_resolve_target_layer_ids_keeps_aux_layers_only(monkeypatch): + monkeypatch.setattr( + utils, + "get_verifier_config", + lambda _name_or_path: SimpleNamespace(num_hidden_layers=36), + ) + + with pytest.warns(UserWarning, match="Stripping the verifier's final layer"): + layer_ids = utils.resolve_target_layer_ids( + [2, 18, 33, 36], "unused-verifier-path" + ) + + assert layer_ids == [2, 18, 33] + + +def test_resolve_target_layer_ids_preserves_custom_aux_layers(monkeypatch): + monkeypatch.setattr( + utils, + "get_verifier_config", + lambda _name_or_path: SimpleNamespace(num_hidden_layers=36), + ) + + assert utils.resolve_target_layer_ids([2, 18, 33], "unused-verifier-path") == [ + 2, + 18, + 33, + ] diff --git a/tests/unit/train/test_data.py b/tests/unit/train/test_data.py index 4142568d1..ac0b5f6ea 100644 --- a/tests/unit/train/test_data.py +++ b/tests/unit/train/test_data.py @@ -243,6 +243,26 @@ def test_collate_fn_length_truncation(): assert collated[key].shape[1] == max_len +def test_collate_fn_empty_batch_uses_training_dtypes(): + """Test that all-skipped batches keep training-compatible dtypes.""" + max_len = 8 + hidden_size = 4 + collate_fn = create_collate_fn( + max_len, + hidden_size, + hidden_states_dtype=torch.bfloat16, + ) + + collated = collate_fn([None]) + + assert collated["hidden_states"].dtype == torch.bfloat16 + assert collated["verifier_last_hidden_states"].dtype == torch.bfloat16 + assert collated["input_ids"].dtype == torch.long + assert collated["loss_mask"].dtype == torch.long + assert collated["position_ids"].dtype == torch.long + assert collated["lengths"].dtype == torch.long + + def test_dataset_getitem_v1_format(tmp_path: Path): """Test dataset __getitem__ with v1 data format and dtype conversion."""