From f26504774868698bba90f88218a63dbbbe25c5ae Mon Sep 17 00:00:00 2001 From: Haoxiang Sun Date: Sat, 25 Apr 2026 14:54:34 +0800 Subject: [PATCH 1/5] Enable multimodal E2E online training for Qwen3-VL - add multimodal preprocessing support for dataset preparation and hidden-state generation - pass prompt and multi_modal_data through offline datagen and vLLM client flows - fix text-only regressions, target-layer alignment, and empty-batch dtype handling - add a configurable 5k Qwen3-VL online training example with runtime notes - extend preprocessing, vLLM client, model utils, and training data tests Signed-off-by: Haoxiang Sun --- docs/api/index.md | 3 +- docs/index.md | 3 +- .../eagle3_qwen3_vl_4b_llava_cot_5k_online.sh | 222 ++++++++++++ scripts/data_generation_offline.py | 32 +- scripts/launch_vllm.py | 32 +- scripts/prepare_data.py | 9 + scripts/train.py | 10 +- .../data_generation/preprocessing.py | 331 ++++++++++++++++-- .../data_generation/vllm_client.py | 45 ++- src/speculators/models/utils.py | 19 +- src/speculators/train/data.py | 24 +- .../integration/datagen/test_preprocessing.py | 194 +++++++++- tests/unit/convert/test_eagle3_converter.py | 5 +- tests/unit/data_generation/__init__.py | 1 + .../unit/data_generation/test_vllm_client.py | 106 ++++++ tests/unit/models/test_utils.py | 34 ++ tests/unit/train/test_data.py | 20 ++ 17 files changed, 1031 insertions(+), 59 deletions(-) create mode 100644 examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh create mode 100644 tests/unit/data_generation/__init__.py create mode 100644 tests/unit/data_generation/test_vllm_client.py create mode 100644 tests/unit/models/test_utils.py diff --git a/docs/api/index.md b/docs/api/index.md index 79738bfb5..f7496e11d 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -2,5 +2,4 @@ This section contains the auto-generated Python API documentation for Speculators. -!!! warning - Using the Python API directly is **not officially supported** at this time and there are **no guarantees of backward compatibility** between releases. We recommend using the [CLI commands](../cli/index.md) as the primary entrypoints for interacting with Speculators. +!!! warning Using the Python API directly is **not officially supported** at this time and there are **no guarantees of backward compatibility** between releases. We recommend using the [CLI commands](../cli/index.md) as the primary entrypoints for interacting with Speculators. diff --git a/docs/index.md b/docs/index.md index ee740d25c..57e2e580e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,8 +31,7 @@ Speculators standardizes this process by providing a productionized end-to-end f - **Standardized, Extensible Format:** Provides a Hugging Face-compatible format for defining speculative models, with tools to convert from external research repositories into a standard speculators format for easy adoption. - **Seamless vLLM Integration:** Built for direct deployment into vLLM, enabling low-latency, production-grade inference with minimal overhead. -!!! tip - Read more about Speculators features in this [vLLM blog post](https://blog.vllm.ai/2025/12/13/speculators-v030.html). +!!! tip Read more about Speculators features in this [vLLM blog post](https://blog.vllm.ai/2025/12/13/speculators-v030.html). ## Quick Start diff --git a/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh new file mode 100644 index 000000000..6f9431483 --- /dev/null +++ b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh @@ -0,0 +1,222 @@ +#!/bin/bash +# Online Eagle3 Training Script for Qwen3-VL-4B on hao05/llava-cot-5k-reannotated +# +# Runs the full online training pipeline: +# 1. Download the multimodal Parquet dataset snapshot from Hugging Face +# 2. Materialize local image files and an absolute-path JSONL +# 3. Prepare arrow data with multimodal preprocessing +# 4. Launch a hidden-state extraction vLLM server +# 5. Train Eagle3 with on-the-fly hidden-state generation +# +# Usage: +# bash examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh +# +# Notes: +# - `prepare_data.py` currently accepts local json/jsonl files or built-in dataset +# aliases. This example snapshots the public HF Parquet dataset locally first. +# - The uploaded dataset stores image bytes in Parquet and preserves original +# relative paths in `image_path`. This script materializes image files and a +# JSONL with absolute image paths so vLLM can load images reliably during +# online training. +# - For more data and a longer training run that can improve accuracy, replace +# `hao05/llava-cot-5k-reannotated` with `hao05/llava-cot-48k-reannotated` +# and adjust `MAX_SAMPLES` / `EPOCHS` as needed. +# +# ### Example E2E run for Qwen3-VL-4B on 5k samples from LLaVA-CoT ### +# +# Note: This 5k setup is primarily a pipeline sanity check. It is enough to +# verify that multimodal online training, hidden-state generation, and +# checkpointing all work end-to-end, but it is not intended to represent final +# model quality. +# +# Timing from an observed run on 4x NVIDIA GeForce RTX 4090 24GB GPUs +# (vLLM on GPUs 0,1 and training on GPUs 2,3): +# Data preprocessing: 8 seconds +# vLLM server startup: 54 seconds +# Training (5 epochs): 1337 seconds (22 mins 17 secs) +# Total (prepare_data start to checkpoint save): 1427 seconds (23 mins 47 secs) +# +# Final validation metrics from that run: +# val/loss_epoch: 8.4479 +# val/full_acc_0_epoch: 58.55% +# val/full_acc_1_epoch: 32.46% +# val/full_acc_2_epoch: 18.22% + +set -euo pipefail + +# ============ Configuration ============ +MODEL="${MODEL:-Qwen/Qwen3-VL-4B-Instruct}" +DATASET_REPO="${DATASET_REPO:-hao05/llava-cot-5k-reannotated}" +DATASET_DIR="${DATASET_DIR:-./data/llava-cot-5k-reannotated}" +DATASET_JSONL="$DATASET_DIR/train.absolute_paths.jsonl" +OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_vl_4b_llava_cot_online}" +HIDDEN_STATES_DIR="$OUTPUT_DIR/hidden_states_online" +CHECKPOINT_DIR="$OUTPUT_DIR/checkpoints" +VLLM_PORT="${VLLM_PORT:-8000}" +MAX_SAMPLES="${MAX_SAMPLES:-5000}" +SEQ_LENGTH="${SEQ_LENGTH:-4096}" +VLLM_MAX_MODEL_LEN="${VLLM_MAX_MODEL_LEN:-5120}" +VLLM_TP="${VLLM_TP:-2}" +EPOCHS="${EPOCHS:-5}" +LR="${LR:-1e-4}" +VLLM_EXTRA_ARGS="${VLLM_EXTRA_ARGS:-}" + +# GPU assignments +VLLM_GPUS="${VLLM_GPUS:-0,1}" +TRAIN_GPUS="${TRAIN_GPUS:-2,3}" +NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-2}" +# ======================================= + +# Optional mirror for environments without direct access to huggingface.co +# export HF_ENDPOINT=https://hf-mirror.com + +mkdir -p "$DATASET_DIR" "$OUTPUT_DIR" +read -r -a VLLM_EXTRA_ARR <<< "$VLLM_EXTRA_ARGS" + +echo "=== Step 1: Downloading dataset snapshot ===" +hf download "$DATASET_REPO" \ + --repo-type dataset \ + --local-dir "$DATASET_DIR" \ + --include "README.md" \ + --include "data/*.parquet" + +echo "=== Step 2: Materializing Parquet dataset to absolute-path JSONL ===" +python - "$DATASET_DIR" "$MAX_SAMPLES" <<'PY' +import json +import sys +from pathlib import Path + +from datasets import Image, load_dataset + +dataset_dir = Path(sys.argv[1]).resolve() +max_samples_arg = sys.argv[2] +max_samples = None +if max_samples_arg and max_samples_arg.lower() not in {"0", "all", "none"}: + max_samples = int(max_samples_arg) +dst = dataset_dir / "train.absolute_paths.jsonl" +parquet_files = sorted((dataset_dir / "data").glob("train-*.parquet")) +if not parquet_files: + raise FileNotFoundError(f"No Parquet shards found under {dataset_dir / 'data'}") + + +def absolutize_image_ref(image_ref: object) -> object: + if not isinstance(image_ref, str): + return image_ref + if image_ref.startswith(("http://", "https://", "/")): + return image_ref + return str((dataset_dir / image_ref).resolve()) + + +def safe_relative_path(image_path: object, row_idx: int) -> Path: + path_text = str(image_path) if isinstance(image_path, str) else f"images/{row_idx:08d}.jpg" + path = Path(path_text) + if path.is_absolute() or ".." in path.parts: + path = Path("images") / path.name + return path + + +def materialize_image(sample: dict, row_idx: int) -> str: + image = sample.get("image") + image_path = sample.get("image_path") + image_bytes = None + if isinstance(image, dict): + image_path = image_path or image.get("path") + image_bytes = image.get("bytes") + rel_path = safe_relative_path(image_path, row_idx) + image_file = dataset_dir / rel_path + if image_bytes is not None: + image_file.parent.mkdir(parents=True, exist_ok=True) + image_file.write_bytes(image_bytes) + elif not image_file.exists(): + raise FileNotFoundError(f"Missing image bytes and file for row {row_idx}: {image_path}") + return str(image_file.resolve()) + + +ds = load_dataset( + "parquet", + data_files={"train": [str(path) for path in parquet_files]}, + split="train", +).cast_column("image", Image(decode=False)) + +count = 0 +with dst.open("w", encoding="utf-8") as fout: + for row_idx, sample in enumerate(ds): + if max_samples is not None and count >= max_samples: + break + + sample["image"] = materialize_image(sample, row_idx) + sample.pop("image_path", None) + + for turn in sample.get("conversations", []): + content = turn.get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") in {"image", "image_url"}: + if "image" in item: + item["image"] = absolutize_image_ref(item["image"]) + elif isinstance(item.get("image_url"), dict): + url = item["image_url"].get("url") + item["image_url"]["url"] = absolutize_image_ref(url) + + fout.write(json.dumps(sample, ensure_ascii=False) + "\n") + count += 1 + +print(f"Wrote {count} rows to {dst}") +PY + +echo "=== Step 3: Preparing multimodal data ===" +python scripts/prepare_data.py \ + --model "$MODEL" \ + --data "$DATASET_JSONL" \ + --output "$OUTPUT_DIR" \ + --max-samples "$MAX_SAMPLES" \ + --seq-length "$SEQ_LENGTH" \ + --multimodal + +echo "=== Step 4: Launching vLLM server ===" +CUDA_VISIBLE_DEVICES="$VLLM_GPUS" python scripts/launch_vllm.py "$MODEL" \ + --hidden-states-path "$HIDDEN_STATES_DIR" \ + -- \ + --port "$VLLM_PORT" \ + --tensor-parallel-size "$VLLM_TP" \ + --max-model-len "$VLLM_MAX_MODEL_LEN" \ + --limit-mm-per-prompt '{"image":1}' \ + "${VLLM_EXTRA_ARR[@]}" & +VLLM_PID=$! + +cleanup() { + echo "Stopping vLLM server..." + kill "$VLLM_PID" 2>/dev/null || true + wait "$VLLM_PID" 2>/dev/null || true +} +trap cleanup EXIT + +echo "Waiting for vLLM server to be ready..." +until curl -sf "http://localhost:${VLLM_PORT}/health" > /dev/null 2>&1; do + sleep 2 +done +echo "vLLM server ready." + +echo "=== Step 5: Online training ===" +CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" torchrun \ + --standalone --nproc_per_node "$NUM_TRAIN_GPUS" \ + scripts/train.py \ + --verifier-name-or-path "$MODEL" \ + --data-path "$OUTPUT_DIR" \ + --hidden-states-path "$HIDDEN_STATES_DIR" \ + --vllm-endpoint "http://localhost:${VLLM_PORT}/v1" \ + --save-path "$CHECKPOINT_DIR" \ + --epochs "$EPOCHS" \ + --lr "$LR" \ + --total-seq-len "$SEQ_LENGTH" \ + --num-layers 1 \ + --ttt-steps 3 \ + --ttt-step-loss-decay 1.0 \ + --on-missing generate \ + --on-generate cache \ + --run-name eagle3_qwen3_vl_4b_llava_cot_5k_online + +echo "Done. Checkpoints saved to $CHECKPOINT_DIR" diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 04882c09b..517b61a74 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -246,6 +246,26 @@ def check_safetensors_file(path: Path, tokens: list[int]): ) +def _to_token_id_list(value: Any) -> list[int]: + if hasattr(value, "tolist"): + return value.tolist() + return list(value) + + +def _build_queue_item(idx: int, item: dict[str, Any]) -> dict[str, Any]: + queue_item: dict[str, Any] = { + "idx": idx, + "input_ids": _to_token_id_list(item["input_ids"]), + } + + if "prompt" in item: + queue_item["prompt"] = item["prompt"] + if "multi_modal_data" in item: + queue_item["multi_modal_data"] = item["multi_modal_data"] + + return queue_item + + async def worker( client, model: str, @@ -276,7 +296,9 @@ async def worker( queue.task_done() continue - input_ids = item["input_ids"].tolist() + input_ids = item["input_ids"] + prompt = item.get("prompt") + multi_modal_data = item.get("multi_modal_data") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" @@ -286,6 +308,8 @@ async def worker( client, model, input_ids, + prompt=prompt, + multi_modal_data=multi_modal_data, timeout=request_timeout, max_retries=max_retries, ) @@ -330,7 +354,7 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): # deadlocking when all workers have died. while not cancel_event.is_set(): try: - queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) + queue.put_nowait(_build_queue_item(i, item)) break except asyncio.QueueFull: await asyncio.sleep(0.1) @@ -368,6 +392,8 @@ async def generate_and_save_hidden_states(args, dataset): existing_file_indices = get_existing_hidden_state_indices(hidden_states_dir) num_samples = len(dataset) + if "multi_modal_data" in dataset.column_names: + logger.info("Detected multimodal preprocessed dataset") to_process = get_indices_to_process( num_samples, args.max_samples, existing_file_indices @@ -428,7 +454,7 @@ async def generate_and_save_hidden_states(args, dataset): await _shutdown_workers(workers, queue, cancel_event) num_saved = len(to_process) - len(skipped_indices) - logger.info(f"Saved {num_saved} new data points to {args.output}") + logger.info(f"Saved {num_saved} new data points to {hidden_states_dir}") if skipped_indices: logger.warning( f"Skipped {len(skipped_indices)} samples due to errors: {skipped_indices}" diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py index a9bbb8573..e744122ad 100644 --- a/scripts/launch_vllm.py +++ b/scripts/launch_vllm.py @@ -57,9 +57,8 @@ def main(): from transformers import AutoConfig # noqa: PLC0415 config = AutoConfig.from_pretrained(args.model) - if hasattr(config, "text_config"): - config = config.text_config - num_hidden_layers = config.num_hidden_layers + text_config = config.text_config if hasattr(config, "text_config") else config + num_hidden_layers = text_config.num_hidden_layers if args.target_layer_ids: target_layer_ids = args.target_layer_ids @@ -67,7 +66,9 @@ def main(): target_layer_ids.append(num_hidden_layers) warnings.warn( f"Using custom target layer ids {target_layer_ids}. These " - "must also be explicitly passed into the training script.", + "must also be explicitly aligned in the training script. " + "If the final verifier layer is included here, pass only the " + "auxiliary layers to training.", stacklevel=2, ) else: @@ -78,12 +79,29 @@ def main(): num_hidden_layers, ] + draft_hf_config = {"eagle_aux_hidden_state_layer_ids": target_layer_ids} + if text_config is not config: + # vLLM's ExtractHiddenStatesConfig flattens the draft config and does not + # preserve nested text_config for multimodal verifiers. Clear the nested + # text_config and copy the text-only shape fields onto the draft config so + # hidden-state extraction can derive the draft model shape from the + # flattened config. + draft_hf_config["text_config"] = None + for field_name in ( + "num_attention_heads", + "num_hidden_layers", + "hidden_size", + "num_key_value_heads", + "head_dim", + ): + field_value = getattr(text_config, field_name, None) + if field_value is not None: + draft_hf_config[field_name] = field_value + speculative_config = { "method": "extract_hidden_states", "num_speculative_tokens": 1, - "draft_model_config": { - "hf_config": {"eagle_aux_hidden_state_layer_ids": target_layer_ids} - }, + "draft_model_config": {"hf_config": draft_hf_config}, } kv_transfer_config = { "kv_connector": "ExampleHiddenStatesConnector", diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 3390d9eef..89a56d997 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -134,6 +134,14 @@ def parse_args(): "trainable tokens." ), ) + parser.add_argument( + "--multimodal", + action="store_true", + help=( + "Enable multimodal preprocessing with AutoProcessor and preserve " + "the prompt/multi_modal_data needed for hidden-state generation." + ), + ) return parser.parse_args() @@ -177,6 +185,7 @@ def main(): assistant_pattern=args.assistant_pattern, turn_dropout=args.turn_dropout, minimum_valid_tokens=args.minimum_valid_tokens, + is_multimodal=args.multimodal, ) log.info("Done preparing data") diff --git a/scripts/train.py b/scripts/train.py index cbcfabcae..508a73958 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -66,6 +66,7 @@ def setup_dataloader( world_size: int, local_rank: int, hidden_size: int, + hidden_states_dtype: torch.dtype, num_workers: int = 12, prefetch_factor: int = 4, preprocess=None, @@ -96,7 +97,12 @@ def setup_dataloader( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=create_collate_fn(args.total_seq_len, hidden_size, preprocess), + collate_fn=create_collate_fn( + args.total_seq_len, + hidden_size, + hidden_states_dtype=hidden_states_dtype, + preprocess=preprocess, + ), persistent_workers=True, ) @@ -339,6 +345,7 @@ def main(args: argparse.Namespace): world_size, local_rank, transformer_layer_config.hidden_size, + hidden_states_dtype, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor, preprocess=preprocess, @@ -348,6 +355,7 @@ def main(args: argparse.Namespace): world_size, local_rank, transformer_layer_config.hidden_size, + hidden_states_dtype, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor, preprocess=preprocess, diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 3adf0c18b..9d413d2a4 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -1,6 +1,7 @@ import bisect import random import re +from functools import partial from pathlib import Path from re import Pattern from typing import Any, cast @@ -8,7 +9,7 @@ import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger @@ -58,6 +59,103 @@ def _visualize_sample(preprocessed, tokenizer, idx: int = 0): log.info(highlighted) +def _normalize_content(content: Any) -> Any: + """Normalize multimodal content into a processor-friendly representation.""" + + def _strip_image_tokens(text: str) -> str: + return re.sub(r"", "", text, flags=re.IGNORECASE).strip() + + if isinstance(content, list): + normalized_items = [] + for item in content: + if isinstance(item, str): + normalized_items.append( + {"type": "text", "text": _strip_image_tokens(item)} + ) + continue + if isinstance(item, dict): + image_ref = _get_image_ref(item) + if image_ref is not None: + normalized_items.append({"type": "image", "image": image_ref}) + continue + text_val = item.get("text") + if text_val is not None: + normalized_items.append( + {"type": "text", "text": _strip_image_tokens(text_val)} + ) + continue + normalized_items.append(item) + return normalized_items + return content + + +def _get_conversations_from_examples(examples: dict) -> list: + """Extract conversations/messages from a batched dataset example.""" + if "conversations" in examples: + return examples.get("conversations", []) + if "messages" in examples: + return examples.get("messages", []) + return [] + + +def _flatten_singleton_batch(value: Any, *, field_name: str) -> Any: + """Collapse a singleton batch dimension from HF outputs.""" + if isinstance(value, torch.Tensor): + value = value.tolist() + + if isinstance(value, list) and value and isinstance(value[0], list): + if len(value) != 1: + raise ValueError( + f"{field_name} returned non-singleton batch: batch_size={len(value)}" + ) + return value[0] + + return value + + +def _get_tokenizer_from_processor(processor: Any) -> PreTrainedTokenizerBase: + """Extract the tokenizer from a processor.""" + tokenizer = getattr(processor, "tokenizer", None) + if tokenizer is None: + raise ValueError("Processor does not provide a tokenizer attribute.") + return cast("PreTrainedTokenizerBase", tokenizer) + + +def _get_image_ref(item: dict) -> Any | None: + """Extract a serializable image reference from a multimodal content item.""" + if item.get("type") not in ("image", "image_url"): + return None + + image_ref = item.get("image") + if image_ref is not None: + return image_ref + + image_url = item.get("image_url") + if isinstance(image_url, dict): + return image_url.get("url") + return image_url + + +def _extract_multimodal_data_from_conversation( + conv: list[dict], +) -> dict[str, list[Any]]: + """Build the minimal multi_modal_data payload expected by vLLM.""" + mm_data: dict[str, list[Any]] = {} + for turn in conv: + content = turn.get("content") + if not isinstance(content, list): + continue + + for item in content: + if not isinstance(item, dict): + continue + image_ref = _get_image_ref(item) + if image_ref is not None: + mm_data.setdefault("image", []).append(image_ref) + + return mm_data + + def _normalize_conversation( conv: list[dict], turn_dropout: bool = False, @@ -78,6 +176,7 @@ def _normalize_conversation( for i, turn in enumerate(conv): role = turn.get("from", turn.get("role", "")) content = turn.get("value") or turn.get("content") or "" + content = _normalize_content(content) # Map various role names to standard user/assistant if role in ("human", "user"): @@ -107,13 +206,13 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: - """Check if tokenizer truly supports HF assistant token mask. +def _supports_assistant_mask(caller: Any) -> bool: + """Check if tokenizer/processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. """ try: - res_any = tokenizer.apply_chat_template( + res_any = caller.apply_chat_template( [{"role": "assistant", "content": "test"}], tokenize=True, return_assistant_tokens_mask=True, @@ -126,13 +225,14 @@ def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: return False # Verify the mask is not all zeros + mask = _flatten_singleton_batch(mask, field_name="assistant mask") return any(m == 1 for m in mask) except (TypeError, ValueError, KeyError, AttributeError): return False -def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: - """Auto-detect the assistant message pattern from the tokenizer's chat template. +def _detect_assistant_pattern(caller: Any) -> str: + """Auto-detect the assistant message pattern from a chat template caller. Uses multi-turn conversation but extracts pattern from the LAST assistant message only. @@ -144,7 +244,7 @@ def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: {"role": "assistant", "content": "ASSISTANT_MSG_2"}, ] - formatted = tokenizer.apply_chat_template( + formatted = caller.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) assert isinstance(formatted, str), "Expected string from apply_chat_template" @@ -266,7 +366,7 @@ def _preprocess_batch( """Process a batch of conversations into tokenized format with loss masks.""" results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} - conversations = examples.get("conversations", []) + conversations = _get_conversations_from_examples(examples) if not conversations: log.warning(f"No conversations key found. Keys: {list(examples.keys())}") @@ -294,14 +394,23 @@ def _preprocess_batch( encoded = cast("dict[str, Any]", encoded_any) # input IDs and loss mask - input_ids = encoded["input_ids"] + input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Text apply_chat_template input_ids", + ) # HF uses 'assistant_masks' in recent versions mask_key = ( "assistant_masks" if "assistant_masks" in encoded else "assistant_mask" ) - loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) + loss_mask = torch.tensor( + _flatten_singleton_batch( + encoded[mask_key], + field_name="Text apply_chat_template assistant mask", + ), + dtype=torch.long, + ) else: # Fallback: regex-based detection @@ -325,8 +434,13 @@ def _preprocess_batch( ) # input IDs and loss mask - input_ids = encoding["input_ids"] - offsets = encoding["offset_mapping"] + input_ids = _flatten_singleton_batch( + encoding["input_ids"], field_name="Text tokenizer input_ids" + ) + offsets = _flatten_singleton_batch( + encoding["offset_mapping"], + field_name="Text tokenizer offset_mapping", + ) loss_mask = _create_loss_mask_from_offsets( formatted_raw, offsets, assistant_pattern @@ -359,6 +473,121 @@ def _preprocess_batch( return results +def _preprocess_batch_multimodal( # noqa: PLR0912, PLR0915 + examples: dict, + processor: Any, + max_length: int, + assistant_pattern: str | Pattern[str] | None, + turn_dropout: bool = False, + minimum_valid_tokens: int | None = None, +) -> dict[str, list]: + """Process a batch of multimodal conversations into token IDs and loss masks.""" + + results: dict[str, list] = { + "input_ids": [], + "loss_mask": [], + "multi_modal_data": [], + "prompt": [], + "seq_len": [], + } + conversations = _get_conversations_from_examples(examples) + + if not conversations: + log.warning(f"No conversations key found. Keys: {list(examples.keys())}") + return results + + tokenizer = _get_tokenizer_from_processor(processor) + + for idx, conv in enumerate(conversations): + if not conv or not isinstance(conv, list): + continue + + normalized_conv = _normalize_conversation(conv, turn_dropout) + if not normalized_conv: + continue + + try: + formatted_raw = processor.apply_chat_template( + normalized_conv, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(formatted_raw, str) + + if assistant_pattern is None: + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + return_assistant_tokens_mask=True, + return_dict=True, + ) + encoded = cast("dict[str, Any]", encoded_any) + + input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Multimodal apply_chat_template input_ids", + ) + mask_key = ( + "assistant_masks" + if "assistant_masks" in encoded + else "assistant_mask" + ) + loss_mask = torch.tensor( + _flatten_singleton_batch( + encoded[mask_key], + field_name="Multimodal apply_chat_template assistant mask", + ), + dtype=torch.long, + ) + else: + encoding = tokenizer( + formatted_raw, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + add_special_tokens=False, + ) + + input_ids = _flatten_singleton_batch( + encoding["input_ids"], field_name="Multimodal tokenizer input_ids" + ) + offsets = _flatten_singleton_batch( + encoding["offset_mapping"], + field_name="Multimodal tokenizer offset_mapping", + ) + loss_mask = _create_loss_mask_from_offsets( + formatted_raw, offsets, assistant_pattern + ) + + assert len(input_ids) == len(loss_mask), ( + f"Shape mismatch: input_ids={len(input_ids)}, " + f"loss_mask={len(loss_mask)}" + ) + + if minimum_valid_tokens is not None: + num_valid_tokens = int(loss_mask.sum().item()) + if num_valid_tokens < minimum_valid_tokens: + continue + + mm_data = _extract_multimodal_data_from_conversation(normalized_conv) + + results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) + results["loss_mask"].append(loss_mask) + results["multi_modal_data"].append(mm_data) + results["prompt"].append(formatted_raw) + results["seq_len"].append(len(input_ids)) + + except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: + log.error( + f"Failed to process conversation {idx} " + f"(assistant_pattern={assistant_pattern is not None}): {e}" + ) + continue + + return results + + def build_eagle3_dataset( dataset: HFDataset, tokenizer: PreTrainedTokenizerBase, @@ -367,6 +596,7 @@ def build_eagle3_dataset( assistant_pattern: str | Pattern[str] | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + processor: Any | None = None, ) -> HFDataset: """Build EAGLE3 dataset by tokenizing conversations and creating loss masks. @@ -384,27 +614,43 @@ def build_eagle3_dataset( conversation minimum_valid_tokens: Number of tokens to consider for a valid sample """ + is_multimodal = processor is not None + caller = processor if is_multimodal else tokenizer + # Detect and use provided assistant message pattern if assistant_pattern is not None: log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...") - elif _supports_assistant_mask(tokenizer): + elif _supports_assistant_mask(caller): assistant_pattern = None # Signal to use HF mask in _preprocess_batch - log.info("Using HF assistant token mask for loss masking") + suffix = " (multimodal)" if is_multimodal else "" + log.info(f"Using HF assistant token mask for loss masking{suffix}") else: - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(caller) log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...") original_cols = dataset.column_names + if is_multimodal: + map_fn = partial( + _preprocess_batch_multimodal, + processor=processor, + max_length=max_length, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + else: + map_fn = partial( + _preprocess_batch, + tokenizer=tokenizer, + max_length=max_length, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + dataset = dataset.map( - lambda examples: _preprocess_batch( - examples, - tokenizer, - max_length, - assistant_pattern, - turn_dropout, - minimum_valid_tokens, - ), + map_fn, batched=True, num_proc=num_proc, batch_size=1000, @@ -412,7 +658,12 @@ def build_eagle3_dataset( keep_in_memory=True, # skip caching ) - dataset.set_format(type="torch") + if is_multimodal: + dataset.set_format( + type="torch", columns=["input_ids", "loss_mask"], output_all_columns=True + ) + else: + dataset.set_format(type="torch") return dataset @@ -436,7 +687,7 @@ def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: return raw_dataset -def load_and_preprocess_dataset( +def load_and_preprocess_dataset( # noqa: PLR0912 target_model_path: str, train_data_paths: list[str], seq_length: int, @@ -447,6 +698,7 @@ def load_and_preprocess_dataset( assistant_pattern: str | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, + is_multimodal: bool | None = None, ) -> tuple[HFDataset, PreTrainedTokenizerBase]: """Load, tokenize, and preprocess a dataset for EAGLE3 training. @@ -480,12 +732,34 @@ def load_and_preprocess_dataset( f"Filtering samples with fewer than {minimum_valid_tokens} valid tokens" ) + if is_multimodal is None: + is_multimodal = False + log.info(f"Using multimodal mode: {is_multimodal}") + log.subsection("Loading tokenizer") - tokenizer = AutoTokenizer.from_pretrained(target_model_path, trust_remote_code=True) + if is_multimodal: + processor = AutoProcessor.from_pretrained( + target_model_path, trust_remote_code=True + ) + tokenizer = _get_tokenizer_from_processor(processor) + else: + processor = None + tokenizer = AutoTokenizer.from_pretrained( + target_model_path, trust_remote_code=True + ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: + if is_multimodal: + if not hasattr(processor, "apply_chat_template"): + raise ValueError( + f"Processor for {target_model_path} does not support " + "apply_chat_template. Please use a model with a pre-configured " + "chat template." + ) + elif ( + not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None + ): raise ValueError( f"Tokenizer for {target_model_path} does not support chat templates. " "Please use a model with a pre-configured chat template." @@ -516,13 +790,14 @@ def load_and_preprocess_dataset( assistant_pattern=assistant_pattern, turn_dropout=turn_dropout, minimum_valid_tokens=minimum_valid_tokens, + processor=processor, ) if minimum_valid_tokens is not None: log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") processed_datasets.append(preprocessed_dataset) combined_dataset = concatenate_datasets(processed_datasets) - combined_dataset.shuffle(seed=seed) + combined_dataset = combined_dataset.shuffle(seed=seed) if max_samples is not None and len(combined_dataset) > max_samples: combined_dataset = combined_dataset.select(range(max_samples)) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index da766a656..9679a5884 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -2,6 +2,7 @@ import functools import logging import time +from typing import Any, cast import openai @@ -99,11 +100,37 @@ def extract_output(completion, token_ids) -> str: return completion.kv_transfer_params.get("hidden_states_path") +def _build_request_payload( + token_ids: list[int], + prompt: str | None = None, + multi_modal_data: dict[str, Any] | None = None, +) -> tuple[list[int] | dict[str, Any], dict[str, Any]]: + """Build a vLLM completion request payload. + + Text-only requests can pass token IDs directly. Requests that need the + original prompt text or multimodal inputs use vLLM's structured prompt form. + """ + extra_body: dict[str, Any] = {"return_token_ids": True} + + if prompt is None and multi_modal_data is None: + return token_ids, extra_body + + request_prompt: dict[str, Any] = {"prompt_token_ids": token_ids} + if prompt is not None: + request_prompt["prompt"] = prompt + if multi_modal_data is not None: + request_prompt["multi_modal_data"] = multi_modal_data + + return request_prompt, extra_body + + @with_retries async def generate_hidden_states_async( client: openai.AsyncClient, model: str, token_ids: list[int], + prompt: str | None = None, + multi_modal_data: dict[str, Any] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -114,13 +141,18 @@ async def generate_hidden_states_async( client: The async OpenAI client. model: The model ID. token_ids: The input token IDs. + prompt: Optional prompt text corresponding to ``token_ids``. + multi_modal_data: Optional multimodal payload expected by vLLM. timeout: Timeout in seconds for each request attempt. None for no timeout. """ + request_prompt, extra_body = _build_request_payload( + token_ids, prompt, multi_modal_data + ) coro = client.completions.create( model=model, - prompt=token_ids, + prompt=cast("Any", request_prompt), max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body=extra_body, timeout=timeout, ) if timeout is not None: @@ -136,17 +168,22 @@ def generate_hidden_states( client: openai.Client, model: str, token_ids: list[int], + prompt: str | None = None, + multi_modal_data: dict[str, Any] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ + request_prompt, extra_body = _build_request_payload( + token_ids, prompt, multi_modal_data + ) completion = client.completions.create( model=model, - prompt=token_ids, + prompt=cast("Any", request_prompt), max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body=extra_body, timeout=timeout, ) return extract_output(completion, token_ids) diff --git a/src/speculators/models/utils.py b/src/speculators/models/utils.py index 0ebb8abca..edb025d8a 100644 --- a/src/speculators/models/utils.py +++ b/src/speculators/models/utils.py @@ -21,10 +21,25 @@ def resolve_target_layer_ids( target_layer_ids: list[int] | None, verifier_name_or_path: str, ) -> list[int]: + num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers + if target_layer_ids is not None: - return target_layer_ids + # Offline datagen extracts auxiliary layers plus the verifier's final layer. + # Training stores the final layer separately as `verifier_last_hidden_states`, + # so the draft config should only keep the auxiliary layers. + aux_target_layer_ids = [ + layer_id for layer_id in target_layer_ids if layer_id != num_layers + ] + if len(aux_target_layer_ids) != len(target_layer_ids): + warnings.warn( + "Stripping the verifier's final layer " + f"({num_layers}) from --target-layer-ids for training. " + "The last extracted layer is consumed separately as " + "`verifier_last_hidden_states`.", + stacklevel=2, + ) + return aux_target_layer_ids - num_layers = get_verifier_config(verifier_name_or_path).num_hidden_layers target_layer_ids = [2, num_layers // 2, num_layers - 3] warnings.warn( DEFAULT_TARGET_LAYER_IDS_WARNING.format(target_layer_ids=target_layer_ids), diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 97ec87828..5351320a4 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -65,7 +65,9 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]] -def create_empty_sample(hidden_size: int): +def create_empty_sample( + hidden_size: int, hidden_states_dtype: torch.dtype = torch.bfloat16 +): # data structure: { # "hidden_states": [seq_len, 3 * hidden_size], # "input_ids": [seq_len], @@ -76,10 +78,12 @@ def create_empty_sample(hidden_size: int): # } return { - "hidden_states": torch.empty(0, 3 * hidden_size), - "input_ids": torch.empty(0), - "verifier_last_hidden_states": torch.empty(0, hidden_size), - "loss_mask": torch.empty(0), + "hidden_states": torch.empty(0, 3 * hidden_size, dtype=hidden_states_dtype), + "input_ids": torch.empty(0, dtype=torch.long), + "verifier_last_hidden_states": torch.empty( + 0, hidden_size, dtype=hidden_states_dtype + ), + "loss_mask": torch.empty(0, dtype=torch.long), "lengths": torch.tensor([0], dtype=torch.long), "position_ids": torch.arange(0, dtype=torch.long), } @@ -258,12 +262,17 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: if not self.client: self._setup_client() - input_ids = self.data[index]["input_ids"].tolist() + sample = self.data[index] + input_ids = sample["input_ids"] + if hasattr(input_ids, "tolist"): + input_ids = input_ids.tolist() try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] input_ids, + prompt=sample.get("prompt"), + multi_modal_data=sample.get("multi_modal_data"), timeout=self.request_timeout, max_retries=self.max_retries, ) @@ -431,6 +440,7 @@ def _get_raw_data(self, index): def create_collate_fn( max_len: int, hidden_size: int, + hidden_states_dtype: torch.dtype = torch.bfloat16, preprocess: Callable[[BatchType], BatchType] | None = None, ): def collate_fn(batch: list[BatchType | None]) -> BatchType: @@ -440,7 +450,7 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: if not batch: # Create empty sample which then gets padded to full # batch size if no valid samples are found - batch = [create_empty_sample(hidden_size)] + batch = [create_empty_sample(hidden_size, hidden_states_dtype)] collated_data = {} for key in batch[0]: # type: ignore[union-attr] diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index a9184c18f..f7aa3a0d2 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -3,12 +3,14 @@ """ import re +from typing import Any, cast import pytest import torch from datasets import Dataset as HFDataset -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from speculators.data_generation import preprocessing from speculators.data_generation.preprocessing import ( _create_loss_mask_from_offsets, _detect_assistant_pattern, @@ -77,6 +79,125 @@ def test_normalize_conversation_unknown_role(): assert result[1]["role"] == "assistant" +@pytest.mark.sanity +def test_normalize_conversation_multimodal_content(): + """Test normalization of multimodal content items.""" + conv: list[dict[str, Any]] = [ + { + "role": "user", + "content": [ + "Look at this ", + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"text": "Caption "}, + ], + }, + {"role": "assistant", "content": "Nice cat."}, + ] + + result = _normalize_conversation(conv) + + assert result[0]["content"] == [ + {"type": "text", "text": "Look at this"}, + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "Caption"}, + ] + + +@pytest.mark.sanity +def test_preprocess_batch_supports_messages_schema(): + """Test that preprocessing accepts datasets using the messages field.""" + + class DummyTokenizer: + chat_template = "dummy-template" + + def apply_chat_template(self, *args, **kwargs): + if kwargs.get("tokenize"): + return {"input_ids": [[11, 12, 13]], "assistant_mask": [[0, 1, 1]]} + return "formatted" + + results = _preprocess_batch( + { + "messages": [ + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + ] + }, + cast("PreTrainedTokenizerBase", DummyTokenizer()), + max_length=64, + assistant_pattern=None, + ) + + assert len(results["input_ids"]) == 1 + assert len(results["loss_mask"]) == 1 + + +@pytest.mark.sanity +def test_load_and_preprocess_dataset_defaults_to_text_mode(monkeypatch): + """Test that omitted multimodal flag falls back to text-only processing.""" + raw_dataset = HFDataset.from_dict( + { + "conversations": [ + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + ] + } + ) + + class DummyTokenizer: + def __init__(self): + self.pad_token = None + self.eos_token = "" + self.chat_template = "dummy-template" + + def apply_chat_template(self, *args, **kwargs): + return "formatted" + + dummy_tokenizer = DummyTokenizer() + captured: dict[str, object] = {} + + def fake_build_eagle3_dataset(*args, **kwargs): + captured["processor"] = kwargs.get("processor") + return raw_dataset + + monkeypatch.setattr( + preprocessing, "load_raw_dataset", lambda *args, **kwargs: raw_dataset + ) + monkeypatch.setattr( + preprocessing.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: dummy_tokenizer, + ) + monkeypatch.setattr( + preprocessing, "build_eagle3_dataset", fake_build_eagle3_dataset + ) + monkeypatch.setattr( + preprocessing, "save_token_frequency_distribution", lambda *args, **kwargs: None + ) + monkeypatch.setattr( + preprocessing, "_visualize_sample", lambda *args, **kwargs: None + ) + + dataset, tokenizer = preprocessing.load_and_preprocess_dataset( + target_model_path="dummy-model", + train_data_paths=["dummy-data"], + seq_length=128, + build_dataset_num_proc=1, + is_multimodal=None, + ) + + assert len(dataset) == len(raw_dataset) + assert tokenizer is dummy_tokenizer + assert captured["processor"] is None + assert dummy_tokenizer.pad_token == dummy_tokenizer.eos_token + + # Tests for _detect_assistant_pattern @pytest.mark.sanity def test_detect_assistant_pattern_structure(): @@ -926,3 +1047,74 @@ def test_build_eagle3_dataset_with_custom_pattern(): # Should successfully build dataset with custom pattern assert isinstance(result, HFDataset) assert len(result) > 0 + + +@pytest.mark.sanity +def test_build_eagle3_dataset_multimodal_preserves_prompt_and_mm_data(): + """Test multimodal preprocessing preserves prompt text and image metadata.""" + + class DummyTokenizer: + pad_token = "" + eos_token = "" + + def __call__( + self, + text, + return_offsets_mapping=False, + max_length=None, + truncation=False, + add_special_tokens=False, + ): + del text, max_length, truncation, add_special_tokens + if return_offsets_mapping: + return { + "input_ids": [101, 102, 103], + "offset_mapping": [(0, 1), (1, 2), (2, 3)], + } + return {"input_ids": [101, 102, 103]} + + class DummyProcessor: + def __init__(self): + self.tokenizer = DummyTokenizer() + + def apply_chat_template(self, *args, **kwargs): + if kwargs.get("tokenize"): + return {"input_ids": [[101, 102, 103]], "assistant_mask": [[0, 1, 1]]} + return "formatted multimodal prompt" + + dataset = HFDataset.from_dict( + { + "messages": [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe "}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "A cat."}], + }, + ] + ] + } + ) + + processor = DummyProcessor() + result = build_eagle3_dataset( + dataset, + cast("PreTrainedTokenizerBase", processor.tokenizer), + max_length=64, + num_proc=1, + processor=processor, + ) + + assert "prompt" in result.column_names + assert "multi_modal_data" in result.column_names + assert result[0]["prompt"] == "formatted multimodal prompt" + assert result[0]["multi_modal_data"] == {"image": ["https://example.com/cat.png"]} diff --git a/tests/unit/convert/test_eagle3_converter.py b/tests/unit/convert/test_eagle3_converter.py index edb2f19d4..d8f46b136 100644 --- a/tests/unit/convert/test_eagle3_converter.py +++ b/tests/unit/convert/test_eagle3_converter.py @@ -259,15 +259,16 @@ def test_nm_testing_2layer_eagle3_model_config(self, tmp_path): # TODO: Use a real model in the future # noqa: FIX002 checkpoint_path = "nm-testing/testing-llama3.1.8b-2layer-eagle3" base_model = "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic" + output_path = tmp_path / checkpoint_path.rsplit("/", maxsplit=1)[-1] converter.convert( checkpoint_path, - tmp_path / checkpoint_path.split("/")[-1], + output_path, base_model, norm_before_residual=False, ) - config = load_checkpoint_config(tmp_path / checkpoint_path.split("/")[-1]) + config = load_checkpoint_config(output_path) # Verify that num_hidden_layers is correctly set to 2 assert config["transformer_layer_config"]["num_hidden_layers"] == 2 diff --git a/tests/unit/data_generation/__init__.py b/tests/unit/data_generation/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/unit/data_generation/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/data_generation/test_vllm_client.py b/tests/unit/data_generation/test_vllm_client.py new file mode 100644 index 000000000..1e1870d2a --- /dev/null +++ b/tests/unit/data_generation/test_vllm_client.py @@ -0,0 +1,106 @@ +import asyncio + +from speculators.data_generation.vllm_client import ( + generate_hidden_states, + generate_hidden_states_async, +) + + +class _DummyChoice: + def __init__(self, prompt_token_ids): + self.prompt_token_ids = prompt_token_ids + + +class _DummyCompletion: + def __init__(self, prompt_token_ids): + self.choices = [_DummyChoice(prompt_token_ids)] + self.kv_transfer_params = {"hidden_states_path": "/tmp/hs_0.safetensors"} + + +class _DummySyncCompletions: + def __init__(self): + self.calls = [] + + def create(self, **kwargs): + self.calls.append(kwargs) + prompt = kwargs["prompt"] + prompt_token_ids = ( + prompt["prompt_token_ids"] if isinstance(prompt, dict) else prompt + ) + return _DummyCompletion(prompt_token_ids) + + +class _DummyAsyncCompletions: + def __init__(self): + self.calls = [] + + async def create(self, **kwargs): + self.calls.append(kwargs) + prompt = kwargs["prompt"] + prompt_token_ids = ( + prompt["prompt_token_ids"] if isinstance(prompt, dict) else prompt + ) + return _DummyCompletion(prompt_token_ids) + + +class _DummySyncClient: + def __init__(self): + self.completions = _DummySyncCompletions() + + +class _DummyAsyncClient: + def __init__(self): + self.completions = _DummyAsyncCompletions() + + +def test_generate_hidden_states_text_prompt(): + client = _DummySyncClient() + + result = generate_hidden_states(client, "dummy-model", [1, 2, 3], timeout=1) + + assert result == "/tmp/hs_0.safetensors" + assert client.completions.calls[0]["prompt"] == [1, 2, 3] + + +def test_generate_hidden_states_multimodal_prompt(): + client = _DummySyncClient() + multi_modal_data = {"image": ["https://example.com/cat.png"]} + + result = generate_hidden_states( + client, + "dummy-model", + [4, 5, 6], + prompt="formatted prompt", + multi_modal_data=multi_modal_data, + timeout=1, + ) + + assert result == "/tmp/hs_0.safetensors" + assert client.completions.calls[0]["prompt"] == { + "prompt_token_ids": [4, 5, 6], + "prompt": "formatted prompt", + "multi_modal_data": multi_modal_data, + } + + +def test_generate_hidden_states_async_multimodal_prompt(): + client = _DummyAsyncClient() + multi_modal_data = {"image": ["https://example.com/cat.png"]} + + result = asyncio.run( + generate_hidden_states_async( + client, + "dummy-model", + [7, 8, 9], + prompt="formatted prompt", + multi_modal_data=multi_modal_data, + timeout=1, + ) + ) + + assert result == "/tmp/hs_0.safetensors" + assert client.completions.calls[0]["prompt"] == { + "prompt_token_ids": [7, 8, 9], + "prompt": "formatted prompt", + "multi_modal_data": multi_modal_data, + } diff --git a/tests/unit/models/test_utils.py b/tests/unit/models/test_utils.py new file mode 100644 index 000000000..c4c99fa81 --- /dev/null +++ b/tests/unit/models/test_utils.py @@ -0,0 +1,34 @@ +from types import SimpleNamespace + +import pytest + +from speculators.models import utils + + +def test_resolve_target_layer_ids_keeps_aux_layers_only(monkeypatch): + monkeypatch.setattr( + utils, + "get_verifier_config", + lambda _name_or_path: SimpleNamespace(num_hidden_layers=36), + ) + + with pytest.warns(UserWarning, match="Stripping the verifier's final layer"): + layer_ids = utils.resolve_target_layer_ids( + [2, 18, 33, 36], "unused-verifier-path" + ) + + assert layer_ids == [2, 18, 33] + + +def test_resolve_target_layer_ids_preserves_custom_aux_layers(monkeypatch): + monkeypatch.setattr( + utils, + "get_verifier_config", + lambda _name_or_path: SimpleNamespace(num_hidden_layers=36), + ) + + assert utils.resolve_target_layer_ids([2, 18, 33], "unused-verifier-path") == [ + 2, + 18, + 33, + ] diff --git a/tests/unit/train/test_data.py b/tests/unit/train/test_data.py index 4142568d1..ac0b5f6ea 100644 --- a/tests/unit/train/test_data.py +++ b/tests/unit/train/test_data.py @@ -243,6 +243,26 @@ def test_collate_fn_length_truncation(): assert collated[key].shape[1] == max_len +def test_collate_fn_empty_batch_uses_training_dtypes(): + """Test that all-skipped batches keep training-compatible dtypes.""" + max_len = 8 + hidden_size = 4 + collate_fn = create_collate_fn( + max_len, + hidden_size, + hidden_states_dtype=torch.bfloat16, + ) + + collated = collate_fn([None]) + + assert collated["hidden_states"].dtype == torch.bfloat16 + assert collated["verifier_last_hidden_states"].dtype == torch.bfloat16 + assert collated["input_ids"].dtype == torch.long + assert collated["loss_mask"].dtype == torch.long + assert collated["position_ids"].dtype == torch.long + assert collated["lengths"].dtype == torch.long + + def test_dataset_getitem_v1_format(tmp_path: Path): """Test dataset __getitem__ with v1 data format and dtype conversion.""" From 84e43c94d6e08f3e098e08e6d2b2b3af1f81359e Mon Sep 17 00:00:00 2001 From: Haoxiang Sun Date: Wed, 29 Apr 2026 22:36:07 +0800 Subject: [PATCH 2/5] Adapt multimodal datagen for vLLM 0.20 Signed-off-by: Haoxiang Sun --- scripts/data_generation_offline.py | 14 +- scripts/launch_vllm.py | 46 +++++ scripts/prepare_data.py | 2 +- .../data_generation/preprocessing.py | 163 +++++++++++++-- .../data_generation/vllm_client.py | 192 +++++++++++++----- src/speculators/train/data.py | 4 +- .../integration/datagen/test_preprocessing.py | 66 +++++- .../unit/data_generation/test_vllm_client.py | 140 +++++++++++-- 8 files changed, 522 insertions(+), 105 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 517b61a74..d041eccd7 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -258,10 +258,8 @@ def _build_queue_item(idx: int, item: dict[str, Any]) -> dict[str, Any]: "input_ids": _to_token_id_list(item["input_ids"]), } - if "prompt" in item: - queue_item["prompt"] = item["prompt"] - if "multi_modal_data" in item: - queue_item["multi_modal_data"] = item["multi_modal_data"] + if "messages" in item: + queue_item["messages"] = item["messages"] return queue_item @@ -297,8 +295,7 @@ async def worker( continue input_ids = item["input_ids"] - prompt = item.get("prompt") - multi_modal_data = item.get("multi_modal_data") + messages = item.get("messages") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" @@ -308,8 +305,7 @@ async def worker( client, model, input_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, + messages=messages, timeout=request_timeout, max_retries=max_retries, ) @@ -392,7 +388,7 @@ async def generate_and_save_hidden_states(args, dataset): existing_file_indices = get_existing_hidden_state_indices(hidden_states_dir) num_samples = len(dataset) - if "multi_modal_data" in dataset.column_names: + if "messages" in dataset.column_names: logger.info("Detected multimodal preprocessed dataset") to_process = get_indices_to_process( diff --git a/scripts/launch_vllm.py b/scripts/launch_vllm.py index e744122ad..13ac3ae4e 100644 --- a/scripts/launch_vllm.py +++ b/scripts/launch_vllm.py @@ -49,6 +49,42 @@ def parse_args(): return parser.parse_known_args() +def _is_multimodal_config(config) -> bool: + if any( + hasattr(config, field_name) + for field_name in ( + "vision_config", + "visual_config", + "image_token_id", + "video_token_id", + ) + ): + return True + + model_type = str(getattr(config, "model_type", "")).lower() + if any( + marker in model_type + for marker in ("vl", "vision", "llava", "mllama", "paligemma", "pixtral") + ): + return True + + architectures = getattr(config, "architectures", []) or [] + return any( + any( + marker in str(architecture).lower() + for marker in ( + "vl", + "vision", + "llava", + "mllama", + "paligemma", + "pixtral", + ) + ) + for architecture in architectures + ) + + def main(): args, vllm_args = parse_args() if "--" in vllm_args: @@ -60,6 +96,16 @@ def main(): text_config = config.text_config if hasattr(config, "text_config") else config num_hidden_layers = text_config.num_hidden_layers + if _is_multimodal_config(config) and "--enforce-eager" not in vllm_args: + # vLLM 0.20 multimodal hidden-state extraction can hit CUDA graph + # shape mismatches after image/video token expansion. Eager mode avoids + # that runtime crash. + vllm_args.append("--enforce-eager") + warnings.warn( + "Adding --enforce-eager for multimodal hidden-state extraction.", + stacklevel=2, + ) + if args.target_layer_ids: target_layer_ids = args.target_layer_ids if args.include_last_layer and num_hidden_layers not in target_layer_ids: diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 89a56d997..24fc2bd1a 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -139,7 +139,7 @@ def parse_args(): action="store_true", help=( "Enable multimodal preprocessing with AutoProcessor and preserve " - "the prompt/multi_modal_data needed for hidden-state generation." + "messages needed for vLLM chat hidden-state generation." ), ) return parser.parse_args() diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 9d413d2a4..33fdb59dc 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -5,11 +5,13 @@ from pathlib import Path from re import Pattern from typing import Any, cast +from urllib.parse import unquote, urlparse import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase +from transformers.image_utils import load_image from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger @@ -121,9 +123,18 @@ def _get_tokenizer_from_processor(processor: Any) -> PreTrainedTokenizerBase: return cast("PreTrainedTokenizerBase", tokenizer) +def _load_image_for_processor(image_ref: Any) -> Any: + """Load a local/data image reference for HF multimodal processors.""" + if isinstance(image_ref, str): + parsed = urlparse(image_ref) + if parsed.scheme == "file": + image_ref = unquote(parsed.path) + return load_image(image_ref) + + def _get_image_ref(item: dict) -> Any | None: """Extract a serializable image reference from a multimodal content item.""" - if item.get("type") not in ("image", "image_url"): + if item.get("type") not in ("image", "image_url", "input_image"): return None image_ref = item.get("image") @@ -136,11 +147,9 @@ def _get_image_ref(item: dict) -> Any | None: return image_url -def _extract_multimodal_data_from_conversation( - conv: list[dict], -) -> dict[str, list[Any]]: - """Build the minimal multi_modal_data payload expected by vLLM.""" - mm_data: dict[str, list[Any]] = {} +def _extract_processor_images_from_conversation(conv: list[dict]) -> list[Any]: + """Extract image inputs in the same order as the chat template placeholders.""" + images = [] for turn in conv: content = turn.get("content") if not isinstance(content, list): @@ -151,9 +160,114 @@ def _extract_multimodal_data_from_conversation( continue image_ref = _get_image_ref(item) if image_ref is not None: - mm_data.setdefault("image", []).append(image_ref) + images.append(_load_image_for_processor(image_ref)) + + return images + - return mm_data +def _get_image_token_ids(tokenizer: PreTrainedTokenizerBase) -> set[int]: + """Return known image placeholder token IDs for VL token expansion.""" + image_token_ids = set() + convert_tokens_to_ids = getattr(tokenizer, "convert_tokens_to_ids", None) + if not callable(convert_tokens_to_ids): + return image_token_ids + + unk_token_id = getattr(tokenizer, "unk_token_id", None) + for token in ("<|image_pad|>", "", ""): + token_id = convert_tokens_to_ids(token) + if isinstance(token_id, int) and token_id >= 0 and token_id != unk_token_id: + image_token_ids.add(token_id) + + return image_token_ids + + +def _expand_loss_mask_for_multimodal_tokens( + input_ids: list[int], + loss_mask: torch.Tensor, + expanded_input_ids: list[int], + image_token_ids: set[int], +) -> torch.Tensor: + """Expand loss masks when a processor expands image placeholders.""" + if input_ids == expanded_input_ids: + return loss_mask + + if not image_token_ids: + raise ValueError("Cannot align expanded multimodal input IDs without image IDs") + + mask_values = loss_mask.tolist() + expanded_mask: list[int] = [] + src_idx = 0 + dst_idx = 0 + + while src_idx < len(input_ids): + token_id = input_ids[src_idx] + if dst_idx >= len(expanded_input_ids): + break + + if token_id in image_token_ids: + start_idx = dst_idx + while ( + dst_idx < len(expanded_input_ids) + and expanded_input_ids[dst_idx] == token_id + ): + dst_idx += 1 + if dst_idx == start_idx: + raise ValueError( + f"Unable to align image token {token_id} at position {src_idx}" + ) + expanded_mask.extend([0] * (dst_idx - start_idx)) + src_idx += 1 + continue + + if expanded_input_ids[dst_idx] != token_id: + raise ValueError( + "Unable to align expanded multimodal input IDs at " + f"source={src_idx}, expanded={dst_idx}: " + f"{token_id} != {expanded_input_ids[dst_idx]}" + ) + expanded_mask.append(int(mask_values[src_idx])) + src_idx += 1 + dst_idx += 1 + + if dst_idx != len(expanded_input_ids): + raise ValueError( + "Expanded multimodal input IDs contain trailing tokens after alignment" + ) + + return torch.tensor(expanded_mask, dtype=loss_mask.dtype) + + +def _expand_multimodal_inputs_with_images( + processor: Any, + tokenizer: PreTrainedTokenizerBase, + formatted_text: str, + normalized_conv: list[dict], + input_ids: list[int], + loss_mask: torch.Tensor, + max_length: int, +) -> tuple[list[int], torch.Tensor]: + """Use actual images so HF preprocessing matches vLLM VL token expansion.""" + images = _extract_processor_images_from_conversation(normalized_conv) + if not images: + return input_ids, loss_mask + + encoded = processor( + text=[formatted_text], + images=images, + max_length=max_length, + truncation=True, + ) + expanded_input_ids = _flatten_singleton_batch( + encoded["input_ids"], + field_name="Multimodal processor input_ids with images", + ) + expanded_loss_mask = _expand_loss_mask_for_multimodal_tokens( + input_ids, + loss_mask, + expanded_input_ids, + _get_image_token_ids(tokenizer), + ) + return expanded_input_ids, expanded_loss_mask def _normalize_conversation( @@ -486,8 +600,7 @@ def _preprocess_batch_multimodal( # noqa: PLR0912, PLR0915 results: dict[str, list] = { "input_ids": [], "loss_mask": [], - "multi_modal_data": [], - "prompt": [], + "messages": [], "seq_len": [], } conversations = _get_conversations_from_examples(examples) @@ -565,20 +678,34 @@ def _preprocess_batch_multimodal( # noqa: PLR0912, PLR0915 f"loss_mask={len(loss_mask)}" ) + input_ids, loss_mask = _expand_multimodal_inputs_with_images( + processor, + tokenizer, + formatted_raw, + normalized_conv, + input_ids, + loss_mask, + max_length, + ) + if minimum_valid_tokens is not None: num_valid_tokens = int(loss_mask.sum().item()) if num_valid_tokens < minimum_valid_tokens: continue - mm_data = _extract_multimodal_data_from_conversation(normalized_conv) - results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) results["loss_mask"].append(loss_mask) - results["multi_modal_data"].append(mm_data) - results["prompt"].append(formatted_raw) + results["messages"].append(normalized_conv) results["seq_len"].append(len(input_ids)) - except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: + except ( + TypeError, + ValueError, + KeyError, + AttributeError, + RuntimeError, + OSError, + ) as e: log.error( f"Failed to process conversation {idx} " f"(assistant_pattern={assistant_pattern is not None}): {e}" @@ -653,7 +780,11 @@ def build_eagle3_dataset( map_fn, batched=True, num_proc=num_proc, - batch_size=1000, + # Multimodal preprocessing loads each image and asks the processor to + # expand image placeholders into the real visual-token length. Large + # batches can look stuck because a worker must finish the whole batch + # before datasets updates progress. + batch_size=32 if is_multimodal else 1000, remove_columns=original_cols, keep_in_memory=True, # skip caching ) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 9679a5884..ed9f1dec0 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -1,8 +1,12 @@ import asyncio +import base64 import functools import logging +import mimetypes import time +from pathlib import Path from typing import Any, cast +from urllib.parse import urlparse import openai @@ -17,6 +21,93 @@ class InvalidResponseError(Exception): pass +def _get_field(obj: Any, key: str) -> Any: + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def _image_ref_to_chat_url(image_ref: Any) -> str: + """Convert a dataset image reference to an OpenAI-compatible image URL.""" + ref = str(image_ref) + parsed = urlparse(ref) + if parsed.scheme in {"http", "https", "data", "file"}: + return ref + + path = Path(ref).expanduser() + if path.exists() and path.is_file(): + mime_type = mimetypes.guess_type(path.name)[0] or "image/jpeg" + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:{mime_type};base64,{encoded}" + + if path.is_absolute(): + return path.as_uri() + + return ref + + +def _get_image_ref(part: dict[str, Any]) -> Any | None: + if part.get("type") not in ("image", "image_url", "input_image"): + return None + + image_ref = part.get("image") + if image_ref is not None: + return image_ref + + image_url = part.get("image_url") + if isinstance(image_url, dict): + return image_url.get("url") + return image_url + + +def _prepare_chat_message_content(content: Any) -> Any: + if not isinstance(content, list): + return content + + prepared = [] + for part in content: + if isinstance(part, str): + prepared.append({"type": "text", "text": part}) + continue + + if not isinstance(part, dict): + prepared.append(part) + continue + + image_ref = _get_image_ref(part) + if image_ref is not None: + prepared.append( + { + "type": "image_url", + "image_url": {"url": _image_ref_to_chat_url(image_ref)}, + } + ) + continue + + text = part.get("text") + if text is not None: + prepared.append({"type": "text", "text": str(text)}) + continue + + prepared.append(part) + + return prepared + + +def _prepare_chat_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert processor-style multimodal messages to vLLM chat messages.""" + prepared = [] + for message in messages: + prepared_message: dict[str, Any] = { + "role": message.get("role"), + "content": _prepare_chat_message_content(message.get("content", "")), + } + if "name" in message: + prepared_message["name"] = message["name"] + prepared.append(prepared_message) + return prepared + + def _handle_retry_error( error: Exception, attempt: int, total_attempts: int ) -> float | None: @@ -84,7 +175,11 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): def extract_output(completion, token_ids) -> str: - prompt_token_ids = getattr(completion.choices[0], "prompt_token_ids", None) + prompt_token_ids = _get_field(completion, "prompt_token_ids") + if prompt_token_ids is None: + choices = _get_field(completion, "choices") + if choices: + prompt_token_ids = _get_field(choices[0], "prompt_token_ids") if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") @@ -94,34 +189,14 @@ def extract_output(completion, token_ids) -> str: f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" ) - if not hasattr(completion, "kv_transfer_params"): + kv_transfer_params = _get_field(completion, "kv_transfer_params") + if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") - return completion.kv_transfer_params.get("hidden_states_path") - - -def _build_request_payload( - token_ids: list[int], - prompt: str | None = None, - multi_modal_data: dict[str, Any] | None = None, -) -> tuple[list[int] | dict[str, Any], dict[str, Any]]: - """Build a vLLM completion request payload. - - Text-only requests can pass token IDs directly. Requests that need the - original prompt text or multimodal inputs use vLLM's structured prompt form. - """ - extra_body: dict[str, Any] = {"return_token_ids": True} - - if prompt is None and multi_modal_data is None: - return token_ids, extra_body - - request_prompt: dict[str, Any] = {"prompt_token_ids": token_ids} - if prompt is not None: - request_prompt["prompt"] = prompt - if multi_modal_data is not None: - request_prompt["multi_modal_data"] = multi_modal_data - - return request_prompt, extra_body + hidden_states_path = _get_field(kv_transfer_params, "hidden_states_path") + if hidden_states_path is None: + raise InvalidResponseError("Response missing hidden_states_path") + return hidden_states_path @with_retries @@ -129,8 +204,7 @@ async def generate_hidden_states_async( client: openai.AsyncClient, model: str, token_ids: list[int], - prompt: str | None = None, - multi_modal_data: dict[str, Any] | None = None, + messages: list[dict[str, Any]] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -141,20 +215,26 @@ async def generate_hidden_states_async( client: The async OpenAI client. model: The model ID. token_ids: The input token IDs. - prompt: Optional prompt text corresponding to ``token_ids``. - multi_modal_data: Optional multimodal payload expected by vLLM. + messages: Optional chat messages for vLLM multimodal requests. timeout: Timeout in seconds for each request attempt. None for no timeout. """ - request_prompt, extra_body = _build_request_payload( - token_ids, prompt, multi_modal_data - ) - coro = client.completions.create( - model=model, - prompt=cast("Any", request_prompt), - max_tokens=1, - extra_body=extra_body, - timeout=timeout, - ) + if messages is not None: + chat_messages = _prepare_chat_messages(messages) + coro = client.chat.completions.create( + model=model, + messages=cast("Any", chat_messages), + max_tokens=1, + extra_body={"return_token_ids": True, "add_generation_prompt": False}, + timeout=timeout, + ) + else: + coro = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) if timeout is not None: completion = await asyncio.wait_for(coro, timeout=timeout) else: @@ -168,22 +248,28 @@ def generate_hidden_states( client: openai.Client, model: str, token_ids: list[int], - prompt: str | None = None, - multi_modal_data: dict[str, Any] | None = None, + messages: list[dict[str, Any]] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ - request_prompt, extra_body = _build_request_payload( - token_ids, prompt, multi_modal_data - ) - completion = client.completions.create( - model=model, - prompt=cast("Any", request_prompt), - max_tokens=1, - extra_body=extra_body, - timeout=timeout, - ) + if messages is not None: + chat_messages = _prepare_chat_messages(messages) + completion = client.chat.completions.create( + model=model, + messages=cast("Any", chat_messages), + max_tokens=1, + extra_body={"return_token_ids": True, "add_generation_prompt": False}, + timeout=timeout, + ) + else: + completion = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) return extract_output(completion, token_ids) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 5351320a4..c9747a653 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -266,13 +266,13 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: input_ids = sample["input_ids"] if hasattr(input_ids, "tolist"): input_ids = input_ids.tolist() + messages = sample.get("messages") try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] input_ids, - prompt=sample.get("prompt"), - multi_modal_data=sample.get("multi_modal_data"), + messages=messages, timeout=self.request_timeout, max_retries=self.max_retries, ) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index f7aa3a0d2..c280eaae2 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -14,6 +14,7 @@ from speculators.data_generation.preprocessing import ( _create_loss_mask_from_offsets, _detect_assistant_pattern, + _expand_loss_mask_for_multimodal_tokens, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, @@ -136,6 +137,21 @@ def apply_chat_template(self, *args, **kwargs): assert len(results["loss_mask"]) == 1 +@pytest.mark.sanity +def test_expand_loss_mask_for_multimodal_tokens_allows_truncation(): + """Test that expanded multimodal alignment tolerates processor truncation.""" + loss_mask = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + expanded = _expand_loss_mask_for_multimodal_tokens( + input_ids=[10, 999, 20, 30], + loss_mask=loss_mask, + expanded_input_ids=[10, 999, 999, 999, 20], + image_token_ids={999}, + ) + + assert expanded.tolist() == [0, 0, 0, 0, 1] + + @pytest.mark.sanity def test_load_and_preprocess_dataset_defaults_to_text_mode(monkeypatch): """Test that omitted multimodal flag falls back to text-only processing.""" @@ -1050,12 +1066,23 @@ def test_build_eagle3_dataset_with_custom_pattern(): @pytest.mark.sanity -def test_build_eagle3_dataset_multimodal_preserves_prompt_and_mm_data(): - """Test multimodal preprocessing preserves prompt text and image metadata.""" +def test_build_eagle3_dataset_multimodal_expands_image_tokens_and_preserves_messages(): + """Test multimodal preprocessing expands image tokens and preserves messages.""" + image_url = ( + "data:image/png;base64," + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/" + "x8AAwMCAO+/p9sAAAAASUVORK5CYII=" + ) class DummyTokenizer: pad_token = "" eos_token = "" + unk_token_id = -1 + + def convert_tokens_to_ids(self, token): + if token == "<|image_pad|>": + return 999 + return self.unk_token_id def __call__( self, @@ -1077,9 +1104,15 @@ class DummyProcessor: def __init__(self): self.tokenizer = DummyTokenizer() + def __call__(self, *args, **kwargs): + images = kwargs["images"] + assert len(images) == 1 + assert getattr(images[0], "mode", None) == "RGB" + return {"input_ids": [[101, 999, 999, 999, 103]]} + def apply_chat_template(self, *args, **kwargs): if kwargs.get("tokenize"): - return {"input_ids": [[101, 102, 103]], "assistant_mask": [[0, 1, 1]]} + return {"input_ids": [[101, 999, 103]], "assistant_mask": [[0, 0, 1]]} return "formatted multimodal prompt" dataset = HFDataset.from_dict( @@ -1092,7 +1125,7 @@ def apply_chat_template(self, *args, **kwargs): {"type": "text", "text": "Describe "}, { "type": "image_url", - "image_url": {"url": "https://example.com/cat.png"}, + "image_url": {"url": image_url}, }, ], }, @@ -1114,7 +1147,24 @@ def apply_chat_template(self, *args, **kwargs): processor=processor, ) - assert "prompt" in result.column_names - assert "multi_modal_data" in result.column_names - assert result[0]["prompt"] == "formatted multimodal prompt" - assert result[0]["multi_modal_data"] == {"image": ["https://example.com/cat.png"]} + assert "messages" in result.column_names + assert result[0]["input_ids"].tolist() == [101, 999, 999, 999, 103] + assert result[0]["loss_mask"].tolist() == [0, 0, 0, 0, 1] + assert result[0]["seq_len"] == 5 + assert result[0]["messages"] == [ + { + "content": [ + {"image": None, "text": "Describe", "type": "text"}, + { + "image": image_url, + "text": None, + "type": "image", + }, + ], + "role": "user", + }, + { + "content": [{"image": None, "text": "A cat.", "type": "text"}], + "role": "assistant", + }, + ] diff --git a/tests/unit/data_generation/test_vllm_client.py b/tests/unit/data_generation/test_vllm_client.py index 1e1870d2a..72b22e239 100644 --- a/tests/unit/data_generation/test_vllm_client.py +++ b/tests/unit/data_generation/test_vllm_client.py @@ -17,6 +17,12 @@ def __init__(self, prompt_token_ids): self.kv_transfer_params = {"hidden_states_path": "/tmp/hs_0.safetensors"} +class _DummyChatCompletion: + def __init__(self, prompt_token_ids): + self.prompt_token_ids = prompt_token_ids + self.kv_transfer_params = {"hidden_states_path": "/tmp/hs_0.safetensors"} + + class _DummySyncCompletions: def __init__(self): self.calls = [] @@ -43,14 +49,44 @@ async def create(self, **kwargs): return _DummyCompletion(prompt_token_ids) +class _DummySyncChatCompletions: + def __init__(self): + self.calls = [] + + def create(self, **kwargs): + self.calls.append(kwargs) + return _DummyChatCompletion([4, 5, 6]) + + +class _DummyAsyncChatCompletions: + def __init__(self): + self.calls = [] + + async def create(self, **kwargs): + self.calls.append(kwargs) + return _DummyChatCompletion([7, 8, 9]) + + +class _DummySyncChat: + def __init__(self): + self.completions = _DummySyncChatCompletions() + + +class _DummyAsyncChat: + def __init__(self): + self.completions = _DummyAsyncChatCompletions() + + class _DummySyncClient: def __init__(self): self.completions = _DummySyncCompletions() + self.chat = _DummySyncChat() class _DummyAsyncClient: def __init__(self): self.completions = _DummyAsyncCompletions() + self.chat = _DummyAsyncChat() def test_generate_hidden_states_text_prompt(): @@ -62,45 +98,117 @@ def test_generate_hidden_states_text_prompt(): assert client.completions.calls[0]["prompt"] == [1, 2, 3] -def test_generate_hidden_states_multimodal_prompt(): +def test_generate_hidden_states_multimodal_messages(): client = _DummySyncClient() - multi_modal_data = {"image": ["https://example.com/cat.png"]} + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + expected_messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] result = generate_hidden_states( client, "dummy-model", [4, 5, 6], - prompt="formatted prompt", - multi_modal_data=multi_modal_data, + messages=messages, timeout=1, ) assert result == "/tmp/hs_0.safetensors" - assert client.completions.calls[0]["prompt"] == { - "prompt_token_ids": [4, 5, 6], - "prompt": "formatted prompt", - "multi_modal_data": multi_modal_data, + assert client.chat.completions.calls[0]["messages"] == expected_messages + assert client.chat.completions.calls[0]["extra_body"] == { + "return_token_ids": True, + "add_generation_prompt": False, } -def test_generate_hidden_states_async_multimodal_prompt(): +def test_generate_hidden_states_multimodal_messages_inlines_local_image(tmp_path): + client = _DummySyncClient() + image_path = tmp_path / "cat.png" + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": str(image_path)}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": "A cat."}, + ] + + result = generate_hidden_states( + client, + "dummy-model", + [4, 5, 6], + messages=messages, + timeout=1, + ) + + sent_content = client.chat.completions.calls[0]["messages"][0]["content"] + assert result == "/tmp/hs_0.safetensors" + assert sent_content[0]["type"] == "image_url" + assert sent_content[0]["image_url"]["url"].startswith("data:image/png;base64,") + assert sent_content[1] == {"type": "text", "text": "describe"} + + +def test_generate_hidden_states_async_multimodal_messages(): client = _DummyAsyncClient() - multi_modal_data = {"image": ["https://example.com/cat.png"]} + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://example.com/cat.png"}, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] + expected_messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.png"}, + }, + {"type": "text", "text": "describe"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}, + ] result = asyncio.run( generate_hidden_states_async( client, "dummy-model", [7, 8, 9], - prompt="formatted prompt", - multi_modal_data=multi_modal_data, + messages=messages, timeout=1, ) ) assert result == "/tmp/hs_0.safetensors" - assert client.completions.calls[0]["prompt"] == { - "prompt_token_ids": [7, 8, 9], - "prompt": "formatted prompt", - "multi_modal_data": multi_modal_data, + assert client.chat.completions.calls[0]["messages"] == expected_messages + assert client.chat.completions.calls[0]["extra_body"] == { + "return_token_ids": True, + "add_generation_prompt": False, } From 75f0c192fe88360e10763fd34f793be3a880a435 Mon Sep 17 00:00:00 2001 From: Haoxiang Sun Date: Fri, 1 May 2026 13:42:17 +0800 Subject: [PATCH 3/5] Handle truncated multimodal prompt hidden states Signed-off-by: Haoxiang Sun --- .../data_generation/vllm_client.py | 83 ++++++++-- .../unit/data_generation/test_vllm_client.py | 147 ++++++++++++++++-- 2 files changed, 205 insertions(+), 25 deletions(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index ed9f1dec0..9d41e392a 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -9,6 +9,7 @@ from urllib.parse import urlparse import openai +from safetensors.torch import load_file, save_file logger = logging.getLogger(__name__) @@ -27,6 +28,12 @@ def _get_field(obj: Any, key: str) -> Any: return getattr(obj, key, None) +def _to_token_id_list(token_ids: Any) -> list[int]: + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + return list(token_ids) + + def _image_ref_to_chat_url(image_ref: Any) -> str: """Convert a dataset image reference to an OpenAI-compatible image URL.""" ref = str(image_ref) @@ -108,6 +115,39 @@ def _prepare_chat_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any return prepared +def _truncate_hidden_states_file(hidden_states_path: str, token_ids: list[int]) -> str: + """Trim vLLM's full multimodal prompt output to the preprocessed prefix.""" + tensors = load_file(hidden_states_path) + + try: + file_token_ids = _to_token_id_list(tensors["token_ids"]) + hidden_states = tensors["hidden_states"] + except KeyError as exc: + raise InvalidResponseError( + f"Hidden states file missing {exc.args[0]}: {hidden_states_path}" + ) from exc + + expected_len = len(token_ids) + if ( + file_token_ids[:expected_len] != token_ids + or hidden_states.shape[0] < expected_len + ): + raise InvalidResponseError( + "Hidden states file does not match preprocessed prompt prefix: " + f"expected {token_ids}, got {file_token_ids}" + ) + + # Safe for causal models: prefix hidden states cannot attend to future tokens + # that only exist in vLLM's full chat-rendered prompt. + tensors["token_ids"] = tensors["token_ids"][:expected_len].contiguous() + tensors["hidden_states"] = hidden_states[:expected_len].contiguous() + save_file( + {key: value.contiguous() for key, value in tensors.items()}, + hidden_states_path, + ) + return hidden_states_path + + def _handle_retry_error( error: Exception, attempt: int, total_attempts: int ) -> float | None: @@ -174,7 +214,13 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): return sync_wrapper -def extract_output(completion, token_ids) -> str: +def extract_output( + completion, + token_ids, + *, + allow_prefix_truncation: bool = False, +) -> str: + token_ids = _to_token_id_list(token_ids) prompt_token_ids = _get_field(completion, "prompt_token_ids") if prompt_token_ids is None: choices = _get_field(completion, "choices") @@ -184,11 +230,6 @@ def extract_output(completion, token_ids) -> str: if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") - if prompt_token_ids != token_ids: - raise InvalidResponseError( - f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" - ) - kv_transfer_params = _get_field(completion, "kv_transfer_params") if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") @@ -196,7 +237,23 @@ def extract_output(completion, token_ids) -> str: hidden_states_path = _get_field(kv_transfer_params, "hidden_states_path") if hidden_states_path is None: raise InvalidResponseError("Response missing hidden_states_path") - return hidden_states_path + + prompt_token_ids = _to_token_id_list(prompt_token_ids) + if prompt_token_ids == token_ids: + return hidden_states_path + + if allow_prefix_truncation and prompt_token_ids[: len(token_ids)] == token_ids: + logger.debug( + "vLLM returned %d prompt tokens for a %d-token preprocessed prompt; " + "truncating hidden states to match prepare_data output.", + len(prompt_token_ids), + len(token_ids), + ) + return _truncate_hidden_states_file(hidden_states_path, token_ids) + + raise InvalidResponseError( + f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" + ) @with_retries @@ -240,7 +297,11 @@ async def generate_hidden_states_async( else: completion = await coro - return extract_output(completion, token_ids) + return extract_output( + completion, + token_ids, + allow_prefix_truncation=messages is not None, + ) @with_retries @@ -272,4 +333,8 @@ def generate_hidden_states( extra_body={"return_token_ids": True}, timeout=timeout, ) - return extract_output(completion, token_ids) + return extract_output( + completion, + token_ids, + allow_prefix_truncation=messages is not None, + ) diff --git a/tests/unit/data_generation/test_vllm_client.py b/tests/unit/data_generation/test_vllm_client.py index 72b22e239..30393cc68 100644 --- a/tests/unit/data_generation/test_vllm_client.py +++ b/tests/unit/data_generation/test_vllm_client.py @@ -1,6 +1,11 @@ import asyncio +import pytest +import torch +from safetensors.torch import load_file, save_file + from speculators.data_generation.vllm_client import ( + InvalidResponseError, generate_hidden_states, generate_hidden_states_async, ) @@ -12,15 +17,23 @@ def __init__(self, prompt_token_ids): class _DummyCompletion: - def __init__(self, prompt_token_ids): + def __init__( + self, + prompt_token_ids, + hidden_states_path="/tmp/hs_0.safetensors", + ): self.choices = [_DummyChoice(prompt_token_ids)] - self.kv_transfer_params = {"hidden_states_path": "/tmp/hs_0.safetensors"} + self.kv_transfer_params = {"hidden_states_path": hidden_states_path} class _DummyChatCompletion: - def __init__(self, prompt_token_ids): + def __init__( + self, + prompt_token_ids, + hidden_states_path="/tmp/hs_0.safetensors", + ): self.prompt_token_ids = prompt_token_ids - self.kv_transfer_params = {"hidden_states_path": "/tmp/hs_0.safetensors"} + self.kv_transfer_params = {"hidden_states_path": hidden_states_path} class _DummySyncCompletions: @@ -50,43 +63,75 @@ async def create(self, **kwargs): class _DummySyncChatCompletions: - def __init__(self): + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): self.calls = [] + self.prompt_token_ids = prompt_token_ids or [4, 5, 6] + self.hidden_states_path = hidden_states_path def create(self, **kwargs): self.calls.append(kwargs) - return _DummyChatCompletion([4, 5, 6]) + return _DummyChatCompletion(self.prompt_token_ids, self.hidden_states_path) class _DummyAsyncChatCompletions: - def __init__(self): + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): self.calls = [] + self.prompt_token_ids = prompt_token_ids or [7, 8, 9] + self.hidden_states_path = hidden_states_path async def create(self, **kwargs): self.calls.append(kwargs) - return _DummyChatCompletion([7, 8, 9]) + return _DummyChatCompletion(self.prompt_token_ids, self.hidden_states_path) class _DummySyncChat: - def __init__(self): - self.completions = _DummySyncChatCompletions() + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummySyncChatCompletions( + prompt_token_ids, hidden_states_path + ) class _DummyAsyncChat: - def __init__(self): - self.completions = _DummyAsyncChatCompletions() + def __init__( + self, + prompt_token_ids=None, + hidden_states_path="/tmp/hs_0.safetensors", + ): + self.completions = _DummyAsyncChatCompletions( + prompt_token_ids, hidden_states_path + ) class _DummySyncClient: - def __init__(self): + def __init__( + self, + chat_prompt_token_ids=None, + chat_hidden_states_path="/tmp/hs_0.safetensors", + ): self.completions = _DummySyncCompletions() - self.chat = _DummySyncChat() + self.chat = _DummySyncChat(chat_prompt_token_ids, chat_hidden_states_path) class _DummyAsyncClient: - def __init__(self): + def __init__( + self, + chat_prompt_token_ids=None, + chat_hidden_states_path="/tmp/hs_0.safetensors", + ): self.completions = _DummyAsyncCompletions() - self.chat = _DummyAsyncChat() + self.chat = _DummyAsyncChat(chat_prompt_token_ids, chat_hidden_states_path) def test_generate_hidden_states_text_prompt(): @@ -212,3 +257,73 @@ def test_generate_hidden_states_async_multimodal_messages(): "return_token_ids": True, "add_generation_prompt": False, } + + +def test_generate_hidden_states_truncates_multimodal_prefix_match(tmp_path): + hs_path = tmp_path / "hs_0.safetensors" + save_file( + { + "token_ids": torch.tensor([1, 2, 3, 4, 5], dtype=torch.long), + "hidden_states": torch.arange(5 * 2 * 3, dtype=torch.float32).reshape( + 5, 2, 3 + ), + }, + hs_path, + ) + client = _DummySyncClient( + chat_prompt_token_ids=[1, 2, 3, 4, 5], + chat_hidden_states_path=str(hs_path), + ) + + result = generate_hidden_states( + client, + "dummy-model", + [1, 2, 3], + messages=[{"role": "user", "content": "describe"}], + timeout=1, + ) + + tensors = load_file(result) + assert result == str(hs_path) + assert tensors["token_ids"].tolist() == [1, 2, 3] + assert tensors["hidden_states"].shape == (3, 2, 3) + assert torch.equal( + tensors["hidden_states"], + torch.arange(5 * 2 * 3, dtype=torch.float32).reshape(5, 2, 3)[:3], + ) + + +def test_generate_hidden_states_rejects_multimodal_non_prefix_mismatch(tmp_path): + hs_path = tmp_path / "hs_0.safetensors" + save_file( + { + "token_ids": torch.tensor([1, 9, 3, 4, 5], dtype=torch.long), + "hidden_states": torch.zeros(5, 2, 3), + }, + hs_path, + ) + client = _DummySyncClient( + chat_prompt_token_ids=[1, 9, 3, 4, 5], + chat_hidden_states_path=str(hs_path), + ) + + with pytest.raises(InvalidResponseError, match="Prompt token IDs mismatch"): + generate_hidden_states( + client, + "dummy-model", + [1, 2, 3], + messages=[{"role": "user", "content": "describe"}], + timeout=1, + ) + + +def test_generate_hidden_states_text_path_rejects_prefix_mismatch(): + class _PrefixTextCompletions: + def create(self, **kwargs): + return _DummyCompletion([1, 2, 3, 4]) + + client = _DummySyncClient() + client.completions = _PrefixTextCompletions() + + with pytest.raises(InvalidResponseError, match="Prompt token IDs mismatch"): + generate_hidden_states(client, "dummy-model", [1, 2, 3], timeout=1) From 4d06511bbf04d1093b67712c0bec7deaede45a94 Mon Sep 17 00:00:00 2001 From: Haoxiang Sun Date: Fri, 1 May 2026 13:55:14 +0800 Subject: [PATCH 4/5] Restore docs indexes from upstream Signed-off-by: Haoxiang Sun --- docs/api/index.md | 3 ++- docs/index.md | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/api/index.md b/docs/api/index.md index f7496e11d..79738bfb5 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -2,4 +2,5 @@ This section contains the auto-generated Python API documentation for Speculators. -!!! warning Using the Python API directly is **not officially supported** at this time and there are **no guarantees of backward compatibility** between releases. We recommend using the [CLI commands](../cli/index.md) as the primary entrypoints for interacting with Speculators. +!!! warning + Using the Python API directly is **not officially supported** at this time and there are **no guarantees of backward compatibility** between releases. We recommend using the [CLI commands](../cli/index.md) as the primary entrypoints for interacting with Speculators. diff --git a/docs/index.md b/docs/index.md index 57e2e580e..ee740d25c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,8 @@ Speculators standardizes this process by providing a productionized end-to-end f - **Standardized, Extensible Format:** Provides a Hugging Face-compatible format for defining speculative models, with tools to convert from external research repositories into a standard speculators format for easy adoption. - **Seamless vLLM Integration:** Built for direct deployment into vLLM, enabling low-latency, production-grade inference with minimal overhead. -!!! tip Read more about Speculators features in this [vLLM blog post](https://blog.vllm.ai/2025/12/13/speculators-v030.html). +!!! tip + Read more about Speculators features in this [vLLM blog post](https://blog.vllm.ai/2025/12/13/speculators-v030.html). ## Quick Start From 331be30987e7b9ebf1091f57227003a64c55b417 Mon Sep 17 00:00:00 2001 From: Haoxiang Sun Date: Fri, 1 May 2026 15:52:32 +0800 Subject: [PATCH 5/5] Update evaluation results Signed-off-by: Haoxiang Sun --- .../eagle3_qwen3_vl_4b_llava_cot_5k_online.sh | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh index 6f9431483..80fbf3b45 100644 --- a/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh +++ b/examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh @@ -29,18 +29,18 @@ # checkpointing all work end-to-end, but it is not intended to represent final # model quality. # -# Timing from an observed run on 4x NVIDIA GeForce RTX 4090 24GB GPUs +# Timing from an observed run on 4x NVIDIA GeForce RTX 5090 32GB GPUs # (vLLM on GPUs 0,1 and training on GPUs 2,3): -# Data preprocessing: 8 seconds -# vLLM server startup: 54 seconds -# Training (5 epochs): 1337 seconds (22 mins 17 secs) -# Total (prepare_data start to checkpoint save): 1427 seconds (23 mins 47 secs) +# Data preprocessing: 460 seconds (7 mins 40 secs) +# vLLM server startup: 45 seconds +# Training (5 epochs): 1110 seconds (18 mins 30 secs) +# Total (prepare_data start to checkpoint save): 1615 seconds (26 mins 55 secs) # # Final validation metrics from that run: -# val/loss_epoch: 8.4479 -# val/full_acc_0_epoch: 58.55% -# val/full_acc_1_epoch: 32.46% -# val/full_acc_2_epoch: 18.22% +# val/loss_epoch: 8.676 +# val/full_acc_0_epoch: 57.7% +# val/full_acc_1_epoch: 31.9% +# val/full_acc_2_epoch: 17.9% set -euo pipefail @@ -60,6 +60,7 @@ VLLM_TP="${VLLM_TP:-2}" EPOCHS="${EPOCHS:-5}" LR="${LR:-1e-4}" VLLM_EXTRA_ARGS="${VLLM_EXTRA_ARGS:-}" +VLLM_LOG_FILE="${VLLM_LOG_FILE:-./vllm_server.log}" # GPU assignments VLLM_GPUS="${VLLM_GPUS:-0,1}" @@ -177,6 +178,8 @@ python scripts/prepare_data.py \ --multimodal echo "=== Step 4: Launching vLLM server ===" +echo "vLLM logs will be written to: $VLLM_LOG_FILE" + CUDA_VISIBLE_DEVICES="$VLLM_GPUS" python scripts/launch_vllm.py "$MODEL" \ --hidden-states-path "$HIDDEN_STATES_DIR" \ -- \ @@ -184,7 +187,8 @@ CUDA_VISIBLE_DEVICES="$VLLM_GPUS" python scripts/launch_vllm.py "$MODEL" \ --tensor-parallel-size "$VLLM_TP" \ --max-model-len "$VLLM_MAX_MODEL_LEN" \ --limit-mm-per-prompt '{"image":1}' \ - "${VLLM_EXTRA_ARR[@]}" & + "${VLLM_EXTRA_ARR[@]}" \ + > "$VLLM_LOG_FILE" 2>&1 & VLLM_PID=$! cleanup() {