From f07f57526484efbbacae036d9f7f43e9b5855f3d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 30 Apr 2026 17:52:35 +0000 Subject: [PATCH 01/95] Support multi-modal datasets (preprocessing part) Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 2 + scripts/prepare_data.py | 11 +- src/speculators/data_generation/configs.py | 36 ++++++ .../data_generation/preprocessing.py | 121 ++++++++++++------ .../data_generation/vllm_client.py | 55 ++++++-- src/speculators/train/data.py | 6 +- 6 files changed, 178 insertions(+), 53 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 04882c09b..37650a0fa 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -277,6 +277,7 @@ async def worker( continue input_ids = item["input_ids"].tolist() + messages = item.get("messages") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" @@ -286,6 +287,7 @@ async def worker( client, model, input_ids, + messages=messages, timeout=request_timeout, max_retries=max_retries, ) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 3390d9eef..64103278d 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -49,6 +49,14 @@ def parse_args(): required=True, help="HuggingFace model ID or local path for target model", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help=( + "Allow executing code from HF Hub when loading the target model's. " + "processor. Also applies to datasets." 
+ ), + ) # Data arguments parser.add_argument( @@ -75,7 +83,7 @@ def parse_args(): type=str, default=None, help=( - "Path to save token frequency distribution" + "Path to save token frequency distribution " "(default: args.output / 'token_freq.pt')" ), ) @@ -177,6 +185,7 @@ def main(): assistant_pattern=args.assistant_pattern, turn_dropout=args.turn_dropout, minimum_valid_tokens=args.minimum_valid_tokens, + trust_remote_code=args.trust_remote_code, ) log.info("Done preparing data") diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 915823bb7..2d8900082 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -25,6 +25,34 @@ def _normalize_ultrachat(example: dict) -> dict: return example +COCO_TASKS = [ + "Locate each object in this image.", + "Describe the image with a brief caption.", +] + + +def _normalize_coco(example: dict) -> dict: + image_path_local = example["image"] + + conversations = [ + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": f"file://{image_path_local}"}, + { + "type": "text", + "text": task, + }, + ], + } + for task in COCO_TASKS + ] + ] + + return {"conversations": conversations} + + DATASET_CONFIGS: dict[str, DatasetConfig] = { "sharegpt": DatasetConfig( name="sharegpt", @@ -37,4 +65,12 @@ def _normalize_ultrachat(example: dict) -> dict: split="train_sft", normalize_fn=_normalize_ultrachat, ), + # NOTE: `datasets<4` is needed to run custom script + # You also need to pass `--allowed-local-media-path` to `launch_vllm.py` + "coco": DatasetConfig( + name="coco", + hf_path="HuggingFaceM4/COCO", + split="train", + normalize_fn=_normalize_coco, + ), } diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 3adf0c18b..f7f5a6c78 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -3,12 
+3,19 @@ import re from pathlib import Path from re import Pattern -from typing import Any, cast +from typing import cast +import aiohttp import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import ( + AutoProcessor, + BatchEncoding, + BatchFeature, + PreTrainedTokenizerBase, + ProcessorMixin, +) from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger @@ -23,7 +30,10 @@ log = PipelineLogger(__name__) -def _visualize_sample(preprocessed, tokenizer, idx: int = 0): +ProcessorLike = PreTrainedTokenizerBase | ProcessorMixin + + +def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: int = 0): """Visualize a single sample with color-coded trainable regions.""" # Get preprocessed sample prep_sample = preprocessed[idx] @@ -43,7 +53,7 @@ def _visualize_sample(preprocessed, tokenizer, idx: int = 0): for i in range(len(input_ids)): is_train = loss_mask[i].item() == 1 - token = tokenizer.decode([input_ids[i].item()]) + token = processor.decode([input_ids[i].item()]) # Switch colors when state changes if is_train != prev_state: @@ -107,19 +117,19 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: - """Check if tokenizer truly supports HF assistant token mask. +def _supports_assistant_mask(processor: ProcessorLike) -> bool: + """Check if processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. 
""" try: - res_any = tokenizer.apply_chat_template( + res_any = processor.apply_chat_template( [{"role": "assistant", "content": "test"}], tokenize=True, return_assistant_tokens_mask=True, return_dict=True, ) - res = cast("dict[str, Any]", res_any) + res = cast("BatchEncoding | BatchFeature", res_any) # Check both singular and plural key names mask = res.get("assistant_masks", res.get("assistant_mask")) if mask is None: @@ -131,8 +141,8 @@ def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: return False -def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: - """Auto-detect the assistant message pattern from the tokenizer's chat template. +def _detect_assistant_pattern(processor: ProcessorLike) -> str: + """Auto-detect the assistant message pattern from the processor's chat template. Uses multi-turn conversation but extracts pattern from the LAST assistant message only. @@ -144,7 +154,7 @@ def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: {"role": "assistant", "content": "ASSISTANT_MSG_2"}, ] - formatted = tokenizer.apply_chat_template( + formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) assert isinstance(formatted, str), "Expected string from apply_chat_template" @@ -255,9 +265,14 @@ def _create_loss_mask_from_offsets( return loss_mask +# Keys in processor outputs that indicate that the conversation has multi-modal data +# In that case, we must use Chat Completions API to send the multi-modal inputs to vLLM +MM_KEYS = {"pixel_values", "pixel_values_videos", "input_features"} + + def _preprocess_batch( examples: dict, - tokenizer: PreTrainedTokenizerBase, + processor: ProcessorLike, max_length: int, assistant_pattern: str | Pattern[str] | None, turn_dropout: bool = False, @@ -281,17 +296,18 @@ def _preprocess_batch( if not normalized_conv: continue + encoded: BatchEncoding | BatchFeature try: if assistant_pattern is None: # HF assistant token mask - 
encoded_any = tokenizer.apply_chat_template( + encoded_any = processor.apply_chat_template( normalized_conv, tokenize=True, add_generation_prompt=False, return_assistant_tokens_mask=True, return_dict=True, ) - encoded = cast("dict[str, Any]", encoded_any) + encoded = cast("BatchEncoding | BatchFeature", encoded_any) # input IDs and loss mask input_ids = encoded["input_ids"] @@ -308,7 +324,7 @@ def _preprocess_batch( assert assistant_pattern is not None, ( "Assistant pattern required for fallback" ) - formatted_raw = tokenizer.apply_chat_template( + formatted_raw = processor.apply_chat_template( normalized_conv, tokenize=False, add_generation_prompt=False, @@ -316,17 +332,18 @@ def _preprocess_batch( assert isinstance(formatted_raw, str) # Tokenize and get offsets - encoding = tokenizer( + encoded_any = processor( formatted_raw, return_offsets_mapping=True, max_length=max_length, truncation=True, add_special_tokens=False, ) + encoded = cast("BatchEncoding | BatchFeature", encoded_any) # input IDs and loss mask - input_ids = encoding["input_ids"] - offsets = encoding["offset_mapping"] + input_ids = encoded["input_ids"] + offsets = encoded["offset_mapping"] loss_mask = _create_loss_mask_from_offsets( formatted_raw, offsets, assistant_pattern @@ -349,6 +366,12 @@ def _preprocess_batch( results["loss_mask"].append(loss_mask) results["seq_len"].append(len(input_ids)) + if MM_KEYS.intersection(results.keys()): + if "messages" not in results: + results["messages"] = [] + + results["messages"].append(conv) + except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: log.error( f"Failed to process conversation {idx} " @@ -361,7 +384,7 @@ def _preprocess_batch( def build_eagle3_dataset( dataset: HFDataset, - tokenizer: PreTrainedTokenizerBase, + processor: ProcessorLike, max_length: int = 2048, num_proc: int = 8, assistant_pattern: str | Pattern[str] | None = None, @@ -370,11 +393,11 @@ def build_eagle3_dataset( ) -> HFDataset: """Build EAGLE3 dataset by 
tokenizing conversations and creating loss masks. - Uses the tokenizer's built-in chat template via apply_chat_template. + Uses the processor's built-in chat template via apply_chat_template. Args: dataset: Raw dataset with conversations - tokenizer: Tokenizer with chat template support + processor: Processor with chat template support max_length: Maximum sequence length num_proc: Number of processes for parallel processing assistant_pattern: Optional custom regex pattern for matching assistant @@ -387,11 +410,11 @@ def build_eagle3_dataset( # Detect and use provided assistant message pattern if assistant_pattern is not None: log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...") - elif _supports_assistant_mask(tokenizer): + elif _supports_assistant_mask(processor): assistant_pattern = None # Signal to use HF mask in _preprocess_batch log.info("Using HF assistant token mask for loss masking") else: - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...") original_cols = dataset.column_names @@ -399,7 +422,7 @@ def build_eagle3_dataset( dataset = dataset.map( lambda examples: _preprocess_batch( examples, - tokenizer, + processor, max_length, assistant_pattern, turn_dropout, @@ -416,7 +439,11 @@ def build_eagle3_dataset( return dataset -def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: +def load_raw_dataset( + train_data_path: str, + num_proc: int = 8, + trust_remote_code: bool = False, +) -> HFDataset: """Load raw dataset from local file or HuggingFace.""" if train_data_path.endswith((".jsonl", ".json")): return load_dataset("json", data_files=train_data_path, split="train") @@ -428,7 +455,15 @@ def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: ) config = DATASET_CONFIGS[train_data_path] - raw_dataset = load_dataset(config.hf_path, split=config.split) + 
raw_dataset = load_dataset( + config.hf_path, + split=config.split, + trust_remote_code=trust_remote_code, + storage_options={ + # Avoid timeout when downloading COCO dataset + "client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)} + }, + ) if config.normalize_fn is not None: raw_dataset = raw_dataset.map(config.normalize_fn, num_proc=num_proc) @@ -447,10 +482,11 @@ def load_and_preprocess_dataset( assistant_pattern: str | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, -) -> tuple[HFDataset, PreTrainedTokenizerBase]: + trust_remote_code: bool = False, +) -> tuple[HFDataset, ProcessorLike]: """Load, tokenize, and preprocess a dataset for EAGLE3 training. - Uses the tokenizer's built-in chat template via apply_chat_template. + Uses the processor's built-in chat template via apply_chat_template. Caching is handled automatically by HuggingFace datasets. Args: @@ -468,9 +504,10 @@ def load_and_preprocess_dataset( turn_dropout: If True, randomly keeps first N consecutive turns per conversation minimum_valid_tokens: Number of tokens to consider for a valid sample + trust_remote_code: If True, allows executing code from HF Hub. 
Returns: - Tuple of (preprocessed_dataset, tokenizer) + Tuple of (preprocessed_dataset, processor) """ if minimum_valid_tokens is not None and minimum_valid_tokens < 0: raise ValueError("minimum_valid_tokens must be >= 0") @@ -480,21 +517,31 @@ def load_and_preprocess_dataset( f"Filtering samples with fewer than {minimum_valid_tokens} valid tokens" ) - log.subsection("Loading tokenizer") - tokenizer = AutoTokenizer.from_pretrained(target_model_path, trust_remote_code=True) + log.subsection("Loading processor") + processor = AutoProcessor.from_pretrained( + target_model_path, + trust_remote_code=trust_remote_code, + ) + tokenizer = ( + processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: raise ValueError( - f"Tokenizer for {target_model_path} does not support chat templates. " + f"Processor for {target_model_path} does not support chat templates. " "Please use a model with a pre-configured chat template." 
) processed_datasets = [] for train_data_path in train_data_paths: log.subsection(f"Processing {train_data_path}") - raw_dataset = load_raw_dataset(train_data_path, num_proc=build_dataset_num_proc) + raw_dataset = load_raw_dataset( + train_data_path, + num_proc=build_dataset_num_proc, + trust_remote_code=trust_remote_code, + ) raw_dataset = raw_dataset.shuffle(seed=seed) if max_samples is not None and len(raw_dataset) > 3 * max_samples: @@ -510,7 +557,7 @@ def load_and_preprocess_dataset( preprocessed_dataset = build_eagle3_dataset( dataset=raw_dataset, - tokenizer=tokenizer, + processor=processor, max_length=seq_length, num_proc=build_dataset_num_proc, assistant_pattern=assistant_pattern, @@ -533,8 +580,8 @@ def load_and_preprocess_dataset( ) log.subsection("Visualizing sample") - _visualize_sample(combined_dataset, tokenizer, idx=0) + _visualize_sample(combined_dataset, processor, idx=0) log.section("Dataset preprocessing complete") - return combined_dataset, tokenizer + return combined_dataset, processor diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index da766a656..2473b7197 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -4,6 +4,7 @@ import time import openai +from openai.types.chat import ChatCompletionMessageParam logger = logging.getLogger(__name__) @@ -104,6 +105,8 @@ async def generate_hidden_states_async( client: openai.AsyncClient, model: str, token_ids: list[int], + *, + messages: list[ChatCompletionMessageParam] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -114,15 +117,27 @@ async def generate_hidden_states_async( client: The async OpenAI client. model: The model ID. token_ids: The input token IDs. + messages: If provided, pass `messages` to Chat Completions API + instead of passing `token_ids` to Completions API. timeout: Timeout in seconds for each request attempt. None for no timeout. 
""" - coro = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, - ) + if messages is None: + coro = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + else: + coro = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + if timeout is not None: completion = await asyncio.wait_for(coro, timeout=timeout) else: @@ -136,17 +151,29 @@ def generate_hidden_states( client: openai.Client, model: str, token_ids: list[int], + *, + messages: list[ChatCompletionMessageParam] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ - completion = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, - ) + if messages is None: + completion = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + else: + completion = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + return extract_output(completion, token_ids) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 97ec87828..ed0fc6f6c 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -258,12 +258,16 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: if not self.client: self._setup_client() - input_ids = self.data[index]["input_ids"].tolist() + item = self.data[index] + input_ids = item["input_ids"].tolist() + messages = self.data.get("messages") + try: hs_filepath = generate_hidden_states( 
self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] input_ids, + messages=messages, timeout=self.request_timeout, max_retries=self.max_retries, ) From 76a7a2d6dcd5b55304b253a46a6ecd018a9e0849 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 30 Apr 2026 17:56:30 +0000 Subject: [PATCH 02/95] Doc Signed-off-by: DarkLight1337 --- scripts/prepare_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 64103278d..4d19a3cd2 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -53,7 +53,7 @@ def parse_args(): "--trust-remote-code", action="store_true", help=( - "Allow executing code from HF Hub when loading the target model's. " + "Allow executing code from HF Hub when loading the target model's " "processor. Also applies to datasets." ), ) From b8f3f4a3287f6ff4ed5c22c6e18dc053d727b856 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 30 Apr 2026 17:59:04 +0000 Subject: [PATCH 03/95] Doc Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 2d8900082..5170a5ef2 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -66,7 +66,7 @@ def _normalize_coco(example: dict) -> dict: normalize_fn=_normalize_ultrachat, ), # NOTE: `datasets<4` is needed to run custom script - # You also need to pass `--allowed-local-media-path` to `launch_vllm.py` + # You also need to pass `--allowed-local-media-path /` to `launch_vllm.py` "coco": DatasetConfig( name="coco", hf_path="HuggingFaceM4/COCO", From 8688e1b76cf4efa727f5a71f20d0be739f20dcfa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 06:43:52 +0000 Subject: [PATCH 04/95] Iterate Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 2 +- 
src/speculators/data_generation/configs.py | 49 +++++++++----- .../data_generation/preprocessing.py | 67 ++++++++++++------- src/speculators/train/data.py | 2 +- 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 37650a0fa..3da17a126 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -277,7 +277,7 @@ async def worker( continue input_ids = item["input_ids"].tolist() - messages = item.get("messages") + messages = item.get("_vllm_messages") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 5170a5ef2..a04ace129 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -1,5 +1,6 @@ """Configuration registries for data generation pipeline.""" +import random from collections.abc import Callable from dataclasses import dataclass @@ -32,25 +33,41 @@ def _normalize_ultrachat(example: dict) -> dict: def _normalize_coco(example: dict) -> dict: - image_path_local = example["image"] + pil_image = example["image"] + selected_task = random.choice(COCO_TASKS) - conversations = [ - [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": f"file://{image_path_local}"}, - { - "type": "text", - "text": task, - }, - ], - } - for task in COCO_TASKS - ] + hf_messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": pil_image, + }, + { + "type": "text", + "text": selected_task, + }, + ], + } + ] + vllm_messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": f"file://{pil_image.filename}", + }, + { + "type": "text", + "text": selected_task, + }, + ], + } ] - return {"conversations": conversations} + return {"conversations": hf_messages, "_vllm_messages": vllm_messages} DATASET_CONFIGS: dict[str, 
DatasetConfig] = { diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index f7f5a6c78..7cd03e79c 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -1,9 +1,10 @@ import bisect import random import re +from collections.abc import Callable from pathlib import Path from re import Pattern -from typing import cast +from typing import Literal, cast import aiohttp import torch @@ -117,14 +118,29 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask(processor: ProcessorLike) -> bool: +def _supports_assistant_mask( + processor: ProcessorLike, + chat_template_content_format: Literal["string", "openai"] | None = None, +) -> bool: """Check if processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. """ + if chat_template_content_format is None: + return any( + _supports_assistant_mask(processor, chat_template_content_format) + for chat_template_content_format in ("string", "openai") + ) + + content = ( + "test" + if chat_template_content_format == "string" + else [{"type": "text", "text": "test"}] + ) + try: res_any = processor.apply_chat_template( - [{"role": "assistant", "content": "test"}], + [{"role": "assistant", "content": content}], tokenize=True, return_assistant_tokens_mask=True, return_dict=True, @@ -137,7 +153,11 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: # Verify the mask is not all zeros return any(m == 1 for m in mask) - except (TypeError, ValueError, KeyError, AttributeError): + except (TypeError, ValueError, KeyError, AttributeError) as e: + log.warning( + f"An error occurred when trying to return assistant mask " + f"({chat_template_content_format=}): {e}" + ) return False @@ -265,11 +285,6 @@ def _create_loss_mask_from_offsets( return loss_mask -# Keys in processor outputs that indicate that the conversation 
has multi-modal data -# In that case, we must use Chat Completions API to send the multi-modal inputs to vLLM -MM_KEYS = {"pixel_values", "pixel_values_videos", "input_features"} - - def _preprocess_batch( examples: dict, processor: ProcessorLike, @@ -283,6 +298,11 @@ def _preprocess_batch( results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} conversations = examples.get("conversations", []) + # A special key defining Chat Completion API messages to pass directly to vLLM + # It should be consistent with the `conversations` passed to `apply_chat_template` + if "_vllm_messages" in examples: + results["_vllm_messages"] = examples["_vllm_messages"] + if not conversations: log.warning(f"No conversations key found. Keys: {list(examples.keys())}") return results @@ -365,13 +385,6 @@ def _preprocess_batch( results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) results["loss_mask"].append(loss_mask) results["seq_len"].append(len(input_ids)) - - if MM_KEYS.intersection(results.keys()): - if "messages" not in results: - results["messages"] = [] - - results["messages"].append(conv) - except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: log.error( f"Failed to process conversation {idx} " @@ -441,12 +454,12 @@ def build_eagle3_dataset( def load_raw_dataset( train_data_path: str, - num_proc: int = 8, + *, trust_remote_code: bool = False, -) -> HFDataset: +) -> tuple[HFDataset, Callable[[dict], dict] | None]: """Load raw dataset from local file or HuggingFace.""" if train_data_path.endswith((".jsonl", ".json")): - return load_dataset("json", data_files=train_data_path, split="train") + return load_dataset("json", data_files=train_data_path, split="train"), None if train_data_path not in DATASET_CONFIGS: raise ValueError( @@ -465,15 +478,13 @@ def load_raw_dataset( }, ) - if config.normalize_fn is not None: - raw_dataset = raw_dataset.map(config.normalize_fn, num_proc=num_proc) - - return raw_dataset + return 
raw_dataset, config.normalize_fn def load_and_preprocess_dataset( target_model_path: str, train_data_paths: list[str], + *, seq_length: int, build_dataset_num_proc: int = 8, seed: int = 0, @@ -537,9 +548,8 @@ def load_and_preprocess_dataset( processed_datasets = [] for train_data_path in train_data_paths: log.subsection(f"Processing {train_data_path}") - raw_dataset = load_raw_dataset( + raw_dataset, normalize_fn = load_raw_dataset( train_data_path, - num_proc=build_dataset_num_proc, trust_remote_code=trust_remote_code, ) raw_dataset = raw_dataset.shuffle(seed=seed) @@ -550,6 +560,13 @@ def load_and_preprocess_dataset( # after combining datasets and shuffling raw_dataset = raw_dataset.select(range(3 * max_samples)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map( + normalize_fn, + num_proc=build_dataset_num_proc, + keep_in_memory=True, # skip caching + ) + log.info(f"Loaded {len(raw_dataset)} samples") if turn_dropout: diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index ed0fc6f6c..7f36449ee 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -260,7 +260,7 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: item = self.data[index] input_ids = item["input_ids"].tolist() - messages = self.data.get("messages") + messages = self.data.get("_vllm_messages") try: hs_filepath = generate_hidden_states( From ddabe1ce6a83858df89a526aa228219d11d8355d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 08:34:12 +0000 Subject: [PATCH 05/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 38 ++-- .../data_generation/preprocessing.py | 170 +++++++++++------- 2 files changed, 128 insertions(+), 80 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index a04ace129..d69588ea9 100644 --- a/src/speculators/data_generation/configs.py +++ 
b/src/speculators/data_generation/configs.py @@ -1,6 +1,5 @@ """Configuration registries for data generation pipeline.""" -import random from collections.abc import Callable from dataclasses import dataclass @@ -26,15 +25,10 @@ def _normalize_ultrachat(example: dict) -> dict: return example -COCO_TASKS = [ - "Locate each object in this image.", - "Describe the image with a brief caption.", -] - - def _normalize_coco(example: dict) -> dict: pil_image = example["image"] - selected_task = random.choice(COCO_TASKS) + task = "Describe the image with a brief caption." + caption = example["sentences"]["raw"] hf_messages = [ { @@ -42,14 +36,23 @@ def _normalize_coco(example: dict) -> dict: "content": [ { "type": "image", - "image": pil_image, + "path": pil_image.filename, }, { "type": "text", - "text": selected_task, + "text": task, }, ], - } + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": caption, + } + ] + }, ] vllm_messages = [ { @@ -61,10 +64,19 @@ def _normalize_coco(example: dict) -> dict: }, { "type": "text", - "text": selected_task, + "text": task, }, ], - } + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": caption, + } + ] + }, ] return {"conversations": hf_messages, "_vllm_messages": vllm_messages} diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 7cd03e79c..ff89ac5a9 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -285,6 +285,87 @@ def _create_loss_mask_from_offsets( return loss_mask +def _get_input_ids_loss_mask( + normalized_conv: list[dict], + processor: ProcessorLike, + max_length: int, + assistant_pattern: str | Pattern[str] | None, +): + encoded: BatchEncoding | BatchFeature + + if assistant_pattern is None: + # HF assistant token mask + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + 
return_assistant_tokens_mask=True, + return_dict=True, + ) + encoded = cast("BatchEncoding | BatchFeature", encoded_any) + + # input IDs and loss mask + input_ids = encoded["input_ids"] + # HF uses 'assistant_masks' in recent versions + mask_key = ( + "assistant_masks" + if "assistant_masks" in encoded + else "assistant_mask" + ) + loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) + + else: + # Fallback: regex-based detection + assert assistant_pattern is not None, ( + "Assistant pattern required for fallback" + ) + + processor_kwargs: dict = { + "return_offsets_mapping": True, + "max_length": max_length, + "truncation": True, + "add_special_tokens": False, + } + + if isinstance(processor, ProcessorMixin): + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + processor_kwargs=processor_kwargs, + return_dict=True, + ) + encoded = cast("BatchFeature", encoded_any) + + # Remove batch dimension + (input_ids,) = encoded["input_ids"] + (offsets,) = encoded["offset_mapping"] + + # MM placeholder tokens are inserted separate from chat template + formatted_text = processor.decode(input_ids) + assert isinstance(formatted_text, str) + else: + formatted_text = processor.apply_chat_template( + normalized_conv, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(formatted_text, str) + + # Tokenize and get offsets + encoded_any = processor(formatted_text, **processor_kwargs) + encoded = cast("BatchEncoding", encoded_any) + + input_ids = encoded["input_ids"] + offsets = encoded["offset_mapping"] + + loss_mask = _create_loss_mask_from_offsets( + formatted_text, offsets, assistant_pattern + ) + + return input_ids, loss_mask + + def _preprocess_batch( examples: dict, processor: ProcessorLike, @@ -316,75 +397,13 @@ def _preprocess_batch( if not normalized_conv: continue - encoded: BatchEncoding | BatchFeature try: - if assistant_pattern is None: - # HF assistant token mask - encoded_any = 
processor.apply_chat_template( - normalized_conv, - tokenize=True, - add_generation_prompt=False, - return_assistant_tokens_mask=True, - return_dict=True, - ) - encoded = cast("BatchEncoding | BatchFeature", encoded_any) - - # input IDs and loss mask - input_ids = encoded["input_ids"] - # HF uses 'assistant_masks' in recent versions - mask_key = ( - "assistant_masks" - if "assistant_masks" in encoded - else "assistant_mask" - ) - loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) - - else: - # Fallback: regex-based detection - assert assistant_pattern is not None, ( - "Assistant pattern required for fallback" - ) - formatted_raw = processor.apply_chat_template( - normalized_conv, - tokenize=False, - add_generation_prompt=False, - ) - assert isinstance(formatted_raw, str) - - # Tokenize and get offsets - encoded_any = processor( - formatted_raw, - return_offsets_mapping=True, - max_length=max_length, - truncation=True, - add_special_tokens=False, - ) - encoded = cast("BatchEncoding | BatchFeature", encoded_any) - - # input IDs and loss mask - input_ids = encoded["input_ids"] - offsets = encoded["offset_mapping"] - - loss_mask = _create_loss_mask_from_offsets( - formatted_raw, offsets, assistant_pattern - ) - - # Assert shapes match - assert len(input_ids) == len(loss_mask), ( - f"Shape mismatch: input_ids={len(input_ids)}, " - f"loss_mask={len(loss_mask)}" + input_ids, loss_mask = _get_input_ids_loss_mask( + normalized_conv, + processor, + max_length=max_length, + assistant_pattern=assistant_pattern, ) - - # Filtering samples out with too few valid tokens - if minimum_valid_tokens is not None: - num_valid_tokens = int(loss_mask.sum().item()) - if num_valid_tokens < minimum_valid_tokens: - continue - - # Append to results - results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) - results["loss_mask"].append(loss_mask) - results["seq_len"].append(len(input_ids)) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: 
log.error( f"Failed to process conversation {idx} " @@ -392,6 +411,23 @@ def _preprocess_batch( ) continue + # Assert shapes match + assert len(input_ids) == len(loss_mask), ( + f"Shape mismatch: input_ids={len(input_ids)}, " + f"loss_mask={len(loss_mask)}" + ) + + # Filtering samples out with too few valid tokens + if minimum_valid_tokens is not None: + num_valid_tokens = int(loss_mask.sum().item()) + if num_valid_tokens < minimum_valid_tokens: + continue + + # Append to results + results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) + results["loss_mask"].append(loss_mask) + results["seq_len"].append(len(input_ids)) + return results From c9f845ee9af14286bb4df521188c5e0f248d41dd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 08:49:07 +0000 Subject: [PATCH 06/95] Simplify Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 47 +++++++++---------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index ff89ac5a9..abf955b18 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -4,7 +4,7 @@ from collections.abc import Callable from pathlib import Path from re import Pattern -from typing import Literal, cast +from typing import cast import aiohttp import torch @@ -118,34 +118,31 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask( - processor: ProcessorLike, - chat_template_content_format: Literal["string", "openai"] | None = None, -) -> bool: +def _supports_assistant_mask(processor: ProcessorLike) -> bool: """Check if processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. 
""" - if chat_template_content_format is None: - return any( - _supports_assistant_mask(processor, chat_template_content_format) - for chat_template_content_format in ("string", "openai") - ) - - content = ( - "test" - if chat_template_content_format == "string" - else [{"type": "text", "text": "test"}] - ) + chat_template_kwargs: dict = { + "tokenize": True, + "return_assistant_tokens_mask": True, + "return_dict": True, + } try: - res_any = processor.apply_chat_template( - [{"role": "assistant", "content": content}], - tokenize=True, - return_assistant_tokens_mask=True, - return_dict=True, - ) - res = cast("BatchEncoding | BatchFeature", res_any) + if isinstance(processor, ProcessorMixin): + res_any = processor.apply_chat_template( + [{"role": "assistant", "content": [{"type": "text", "text": "test"}]}], + **chat_template_kwargs, + ) + res = cast("BatchFeature", res_any) + else: + res_any = processor.apply_chat_template( + [{"role": "assistant", "content": "test"}], + **chat_template_kwargs, + ) + res = cast("BatchEncoding", res_any) + # Check both singular and plural key names mask = res.get("assistant_masks", res.get("assistant_mask")) if mask is None: @@ -155,8 +152,7 @@ def _supports_assistant_mask( return any(m == 1 for m in mask) except (TypeError, ValueError, KeyError, AttributeError) as e: log.warning( - f"An error occurred when trying to return assistant mask " - f"({chat_template_content_format=}): {e}" + f"An error occurred when trying to return assistant mask: {e}" ) return False @@ -345,6 +341,7 @@ def _get_input_ids_loss_mask( formatted_text = processor.decode(input_ids) assert isinstance(formatted_text, str) else: + # More optimized flow for text-only processors (i.e. 
tokenizers) formatted_text = processor.apply_chat_template( normalized_conv, tokenize=False, From 7b44b1aeadbec35ef0c32bb836cc69acc8dbd744 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 08:50:54 +0000 Subject: [PATCH 07/95] Clean Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index abf955b18..5ffeff358 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -328,8 +328,8 @@ def _get_input_ids_loss_mask( normalized_conv, tokenize=True, add_generation_prompt=False, - processor_kwargs=processor_kwargs, return_dict=True, + processor_kwargs=processor_kwargs, ) encoded = cast("BatchFeature", encoded_any) From ce2a07258d74a09fb3dd51fa520ea11272de2738 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 09:49:09 +0000 Subject: [PATCH 08/95] Use ShareGPT4V to avoid outdated version of datasets Signed-off-by: DarkLight1337 --- scripts/prepare_data.py | 2 +- src/speculators/data_generation/configs.py | 114 ++++++++++-------- .../data_generation/preprocessing.py | 33 +---- 3 files changed, 70 insertions(+), 79 deletions(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 4d19a3cd2..69604192d 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -54,7 +54,7 @@ def parse_args(): action="store_true", help=( "Allow executing code from HF Hub when loading the target model's " - "processor. Also applies to datasets." + "processor." 
), ) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index d69588ea9..ae707d05e 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -1,5 +1,6 @@ """Configuration registries for data generation pipeline.""" +import os from collections.abc import Callable from dataclasses import dataclass @@ -9,74 +10,85 @@ ] -@dataclass +@dataclass(kw_only=True) class DatasetConfig: """Configuration for loading a dataset""" name: str hf_path: str + hf_name: str | None = None split: str normalize_fn: Callable[[dict], dict] | None = None +def hf_to_vllm_part(part: str | dict): + if isinstance(part, str): + return {"type": "text", "text": part} + + part_type = part["type"] + + for modality in ("image", "video", "audio"): + if part_type == modality: + if local_path := part.get("path"): + file_url = f"file://{local_path}" + return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} + if url := part.get("url"): + return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} + + fields_expr = {f"part.{k}" for k in part if k != "type"} + + raise NotImplementedError( + f"No handler defined in part.type={part_type!r} " + f"for fields: {fields_expr}" + ) + + raise NotImplementedError(f"No handler defined for part.type={part_type!r}") + + +def get_coco_dir(): + return os.getenv("COCO_DIR") or "coco/" + + def _normalize_ultrachat(example: dict) -> dict: if "messages" in example: return {"conversations": example["messages"]} return example -def _normalize_coco(example: dict) -> dict: - pil_image = example["image"] - task = "Describe the image with a brief caption." 
- caption = example["sentences"]["raw"] +def _unformat_sharegpt4v(part: str, image_path: str): + if part == "": + return {"type": "image", "path": image_path} + + return {"type": "text", "text": part} + + +def _normalize_sharegpt4v(example: dict) -> dict: + image_path: str = example["image"] + image_path = os.path.join(get_coco_dir(), image_path.removeprefix("coco/")) + + if not os.path.exists(image_path): + raise ValueError( + "Please download COCO 2017 Train Images from " + "http://images.cocodataset.org/zips/train2017.zip and " + "place the files under `COCO_DIR` (default: `./coco`)." + ) hf_messages = [ { - "role": "user", + **turn, "content": [ - { - "type": "image", - "path": pil_image.filename, - }, - { - "type": "text", - "text": task, - }, + _unformat_sharegpt4v(part, image_path) + for part in turn.pop("value").split("\n") ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": caption, - } - ] - }, + } + for turn in example["conversations"] ] vllm_messages = [ { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": f"file://{pil_image.filename}", - }, - { - "type": "text", - "text": task, - }, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": caption, - } - ] - }, + **turn, + "content": [hf_to_vllm_part(part) for part in turn["content"]], + } + for turn in hf_messages ] return {"conversations": hf_messages, "_vllm_messages": vllm_messages} @@ -94,12 +106,12 @@ def _normalize_coco(example: dict) -> dict: split="train_sft", normalize_fn=_normalize_ultrachat, ), - # NOTE: `datasets<4` is needed to run custom script - # You also need to pass `--allowed-local-media-path /` to `launch_vllm.py` - "coco": DatasetConfig( - name="coco", - hf_path="HuggingFaceM4/COCO", + # NOTE: You need to pass `--allowed-local-media-path /` to `launch_vllm.py` + "sharegpt4v": DatasetConfig( + name="sharegpt4v", + hf_path="Lin-Chen/ShareGPT4V", + hf_name="ShareGPT4V", split="train", - 
normalize_fn=_normalize_coco, + normalize_fn=_normalize_sharegpt4v, ), } diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 5ffeff358..49608f862 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -6,7 +6,6 @@ from re import Pattern from typing import cast -import aiohttp import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset @@ -151,9 +150,7 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: # Verify the mask is not all zeros return any(m == 1 for m in mask) except (TypeError, ValueError, KeyError, AttributeError) as e: - log.warning( - f"An error occurred when trying to return assistant mask: {e}" - ) + log.warning(f"An error occurred when trying to return assistant mask: {e}") return False @@ -304,17 +301,13 @@ def _get_input_ids_loss_mask( input_ids = encoded["input_ids"] # HF uses 'assistant_masks' in recent versions mask_key = ( - "assistant_masks" - if "assistant_masks" in encoded - else "assistant_mask" + "assistant_masks" if "assistant_masks" in encoded else "assistant_mask" ) loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) else: # Fallback: regex-based detection - assert assistant_pattern is not None, ( - "Assistant pattern required for fallback" - ) + assert assistant_pattern is not None, "Assistant pattern required for fallback" processor_kwargs: dict = { "return_offsets_mapping": True, @@ -410,8 +403,7 @@ def _preprocess_batch( # Assert shapes match assert len(input_ids) == len(loss_mask), ( - f"Shape mismatch: input_ids={len(input_ids)}, " - f"loss_mask={len(loss_mask)}" + f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}" ) # Filtering samples out with too few valid tokens @@ -487,8 +479,6 @@ def build_eagle3_dataset( def load_raw_dataset( train_data_path: str, - *, - trust_remote_code: bool = False, ) -> 
tuple[HFDataset, Callable[[dict], dict] | None]: """Load raw dataset from local file or HuggingFace.""" if train_data_path.endswith((".jsonl", ".json")): @@ -501,15 +491,7 @@ def load_raw_dataset( ) config = DATASET_CONFIGS[train_data_path] - raw_dataset = load_dataset( - config.hf_path, - split=config.split, - trust_remote_code=trust_remote_code, - storage_options={ - # Avoid timeout when downloading COCO dataset - "client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)} - }, - ) + raw_dataset = load_dataset(config.hf_path, config.hf_name, split=config.split) return raw_dataset, config.normalize_fn @@ -581,10 +563,7 @@ def load_and_preprocess_dataset( processed_datasets = [] for train_data_path in train_data_paths: log.subsection(f"Processing {train_data_path}") - raw_dataset, normalize_fn = load_raw_dataset( - train_data_path, - trust_remote_code=trust_remote_code, - ) + raw_dataset, normalize_fn = load_raw_dataset(train_data_path) raw_dataset = raw_dataset.shuffle(seed=seed) if max_samples is not None and len(raw_dataset) > 3 * max_samples: From a81744a4171f34eefab5bc673032bbbce08974aa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 09:52:58 +0000 Subject: [PATCH 09/95] Improve UX Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index ae707d05e..ccf4088ff 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -67,10 +67,12 @@ def _normalize_sharegpt4v(example: dict) -> dict: image_path = os.path.join(get_coco_dir(), image_path.removeprefix("coco/")) if not os.path.exists(image_path): + state_str = "set to" if os.getenv("COCO_DIR") else "default" + raise ValueError( - "Please download COCO 2017 Train Images from " - "http://images.cocodataset.org/zips/train2017.zip and " - "place the files under 
`COCO_DIR` (default: `./coco`)." + f"Please download COCO 2017 Train Images from " + f"http://images.cocodataset.org/zips/train2017.zip and " + f"place the files under `COCO_DIR` ({state_str}: `./coco`)." ) hf_messages = [ From 5f9a87c486db96ef7e682af49896dd3bb99b4073 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 09:54:23 +0000 Subject: [PATCH 10/95] Again Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index ccf4088ff..53b6903a2 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -63,16 +63,16 @@ def _unformat_sharegpt4v(part: str, image_path: str): def _normalize_sharegpt4v(example: dict) -> dict: - image_path: str = example["image"] - image_path = os.path.join(get_coco_dir(), image_path.removeprefix("coco/")) + coco_dir = get_coco_dir() + image_path = os.path.join(coco_dir, example["image"].removeprefix("coco/")) if not os.path.exists(image_path): state_str = "set to" if os.getenv("COCO_DIR") else "default" raise ValueError( f"Please download COCO 2017 Train Images from " - f"http://images.cocodataset.org/zips/train2017.zip and " - f"place the files under `COCO_DIR` ({state_str}: `./coco`)." + f" and " + f"place the files under `COCO_DIR` ({state_str}: `{coco_dir}`)." 
) hf_messages = [ From e827b771042b0a3c81c1034d5fa847ea5967a88c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 09:58:25 +0000 Subject: [PATCH 11/95] Improve Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 53b6903a2..eaa938399 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -71,8 +71,8 @@ def _normalize_sharegpt4v(example: dict) -> dict: raise ValueError( f"Please download COCO 2017 Train Images from " - f" and " - f"place the files under `COCO_DIR` ({state_str}: `{coco_dir}`)." + f" and place the " + f"extracted folder under `COCO_DIR` ({state_str}: `{coco_dir}`)." ) hf_messages = [ From 4f0f9b84444e517e845ae92de18001603c7409a5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 10:13:01 +0000 Subject: [PATCH 12/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 21 +++++++++++++------ .../data_generation/preprocessing.py | 3 +++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index eaa938399..19ff04990 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -18,6 +18,7 @@ class DatasetConfig: hf_path: str hf_name: str | None = None split: str + filter_fn: Callable[[dict], bool] | None = None normalize_fn: Callable[[dict], dict] | None = None @@ -27,6 +28,9 @@ def hf_to_vllm_part(part: str | dict): part_type = part["type"] + if part_type == "text": + return {"type": "text", "text": part["text"]} + for modality in ("image", "video", "audio"): if part_type == modality: if local_path := part.get("path"): @@ -62,7 +66,11 @@ def _unformat_sharegpt4v(part: str, image_path: str): return {"type": "text", 
"text": part} -def _normalize_sharegpt4v(example: dict) -> dict: +def _filter_sharegpt4v_coco(example: dict) -> bool: + return example["image"].startswith("coco/") + + +def _normalize_sharegpt4v_coco(example: dict) -> dict: coco_dir = get_coco_dir() image_path = os.path.join(coco_dir, example["image"].removeprefix("coco/")) @@ -77,18 +85,18 @@ def _normalize_sharegpt4v(example: dict) -> dict: hf_messages = [ { - **turn, "content": [ _unformat_sharegpt4v(part, image_path) for part in turn.pop("value").split("\n") ], + **turn, } for turn in example["conversations"] ] vllm_messages = [ { + "content": [hf_to_vllm_part(part) for part in turn.pop("content")], **turn, - "content": [hf_to_vllm_part(part) for part in turn["content"]], } for turn in hf_messages ] @@ -109,11 +117,12 @@ def _normalize_sharegpt4v(example: dict) -> dict: normalize_fn=_normalize_ultrachat, ), # NOTE: You need to pass `--allowed-local-media-path /` to `launch_vllm.py` - "sharegpt4v": DatasetConfig( - name="sharegpt4v", + "sharegpt4v_coco": DatasetConfig( + name="sharegpt4v_coco", hf_path="Lin-Chen/ShareGPT4V", hf_name="ShareGPT4V", split="train", - normalize_fn=_normalize_sharegpt4v, + filter_fn=_filter_sharegpt4v_coco, + normalize_fn=_normalize_sharegpt4v_coco, ), } diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 49608f862..e8beec779 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -493,6 +493,9 @@ def load_raw_dataset( config = DATASET_CONFIGS[train_data_path] raw_dataset = load_dataset(config.hf_path, config.hf_name, split=config.split) + if config.filter_fn is not None: + raw_dataset = raw_dataset.filter(config.filter_fn) + return raw_dataset, config.normalize_fn From 3101671485e152073ec698b94b3d6eb0fbce12b4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 10:41:58 +0000 Subject: [PATCH 13/95] Fixes Signed-off-by: DarkLight1337 --- 
src/speculators/data_generation/configs.py | 65 ++++++------------- .../data_generation/preprocessing.py | 46 +++++++++++-- 2 files changed, 62 insertions(+), 49 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 19ff04990..6a73dd969 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -22,33 +22,6 @@ class DatasetConfig: normalize_fn: Callable[[dict], dict] | None = None -def hf_to_vllm_part(part: str | dict): - if isinstance(part, str): - return {"type": "text", "text": part} - - part_type = part["type"] - - if part_type == "text": - return {"type": "text", "text": part["text"]} - - for modality in ("image", "video", "audio"): - if part_type == modality: - if local_path := part.get("path"): - file_url = f"file://{local_path}" - return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} - if url := part.get("url"): - return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} - - fields_expr = {f"part.{k}" for k in part if k != "type"} - - raise NotImplementedError( - f"No handler defined in part.type={part_type!r} " - f"for fields: {fields_expr}" - ) - - raise NotImplementedError(f"No handler defined for part.type={part_type!r}") - - def get_coco_dir(): return os.getenv("COCO_DIR") or "coco/" @@ -59,13 +32,21 @@ def _normalize_ultrachat(example: dict) -> dict: return example -def _unformat_sharegpt4v(part: str, image_path: str): +def _parse_sharegpt4v_part(part: str, image_path: str): if part == "": return {"type": "image", "path": image_path} return {"type": "text", "text": part} +def _parse_sharegpt4v_user_content(content: str, image_path: str): + return [_parse_sharegpt4v_part(part, image_path) for part in content.split("\n")] + + +def _parse_sharegpt4v_assistant_content(content: str): + return [{"type": "text", "text": content}] + + def _filter_sharegpt4v_coco(example: dict) -> bool: return 
example["image"].startswith("coco/") @@ -83,25 +64,21 @@ def _normalize_sharegpt4v_coco(example: dict) -> dict: f"extracted folder under `COCO_DIR` ({state_str}: `{coco_dir}`)." ) - hf_messages = [ - { - "content": [ - _unformat_sharegpt4v(part, image_path) - for part in turn.pop("value").split("\n") - ], - **turn, - } + messages = [ + ( + turn + | { + "value": ( + _parse_sharegpt4v_user_content(turn["value"], image_path) + if turn["from"] in ("human", "user") + else _parse_sharegpt4v_assistant_content(turn["value"]) + ) + } + ) for turn in example["conversations"] ] - vllm_messages = [ - { - "content": [hf_to_vllm_part(part) for part in turn.pop("content")], - **turn, - } - for turn in hf_messages - ] - return {"conversations": hf_messages, "_vllm_messages": vllm_messages} + return {"conversations": messages} DATASET_CONFIGS: dict[str, DatasetConfig] = { diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index e8beec779..3dd5f4a11 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -117,6 +117,40 @@ def _normalize_conversation( return normalized +def _hf_to_vllm_part(part: str | dict): + if isinstance(part, str): + return {"type": "text", "text": part} + + part_type = part["type"] + + if part_type == "text": + return {"type": "text", "text": part["text"]} + + for modality in ("image", "video", "audio"): + if part_type == modality: + if local_path := part.get("path"): + file_url = f"file://{local_path}" + return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} + if url := part.get("url"): + return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} + + fields_expr = {f"part.{k}" for k in part if k != "type"} + + raise NotImplementedError( + f"No handler defined in part.type={part_type!r} " + f"for fields: {fields_expr}" + ) + + raise NotImplementedError(f"No handler defined for part.type={part_type!r}") + + +def 
_hf_to_vllm_conv(normalized_conv: list[dict]): + return [ + turn | {"content": [_hf_to_vllm_part(part) for part in turn["content"]]} + for turn in normalized_conv + ] + + def _supports_assistant_mask(processor: ProcessorLike) -> bool: """Check if processor truly supports HF assistant token mask. @@ -367,12 +401,11 @@ def _preprocess_batch( """Process a batch of conversations into tokenized format with loss masks.""" results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} - conversations = examples.get("conversations", []) + conversations: list[dict] = examples.get("conversations", []) - # A special key defining Chat Completion API messages to pass directly to vLLM - # It should be consistent with the `conversations` passed to `apply_chat_template` - if "_vllm_messages" in examples: - results["_vllm_messages"] = examples["_vllm_messages"] + # MM inputs must use Chat Completion API + if isinstance(processor, ProcessorMixin): + results["_vllm_messages"] = [] if not conversations: log.warning(f"No conversations key found. 
Keys: {list(examples.keys())}") @@ -417,6 +450,9 @@ def _preprocess_batch( results["loss_mask"].append(loss_mask) results["seq_len"].append(len(input_ids)) + if "_vllm_messages" in results: + results["_vllm_messages"].append(_hf_to_vllm_conv(normalized_conv)) + return results From c53fb6f6c82fb882fc1800268a02a89fa3d1072c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 13:49:20 +0000 Subject: [PATCH 14/95] Update and cover more edge cases Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 2 +- .../data_generation/preprocessing.py | 138 ++++---- .../integration/datagen/test_preprocessing.py | 313 +++++++++++------- .../datagen/test_regex_patterns.py | 77 +++-- 4 files changed, 319 insertions(+), 211 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 6a73dd969..8f57cd26f 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -93,7 +93,7 @@ def _normalize_sharegpt4v_coco(example: dict) -> dict: split="train_sft", normalize_fn=_normalize_ultrachat, ), - # NOTE: You need to pass `--allowed-local-media-path /` to `launch_vllm.py` + # NOTE: You need to serve vLLM with `--allowed-local-media-path /path/to/coco` "sharegpt4v_coco": DatasetConfig( name="sharegpt4v_coco", hf_path="Lin-Chen/ShareGPT4V", diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 3dd5f4a11..424360591 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -144,11 +144,15 @@ def _hf_to_vllm_part(part: str | dict): raise NotImplementedError(f"No handler defined for part.type={part_type!r}") +def _hf_to_vllm_turn(turn: dict): + if isinstance(turn["content"], str): + return turn + + return turn | {"content": [_hf_to_vllm_part(part) for part in turn["content"]]} + + def _hf_to_vllm_conv(normalized_conv: list[dict]): 
- return [ - turn | {"content": [_hf_to_vllm_part(part) for part in turn["content"]]} - for turn in normalized_conv - ] + return [_hf_to_vllm_turn(turn) for turn in normalized_conv] def _supports_assistant_mask(processor: ProcessorLike) -> bool: @@ -156,6 +160,13 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: Must return a non-zero mask for a conversation containing an assistant message. """ + if isinstance(processor, ProcessorMixin): + test_conv = [ + {"role": "assistant", "content": [{"type": "text", "text": "test"}]} + ] + else: + test_conv = [{"role": "assistant", "content": "test"}] + chat_template_kwargs: dict = { "tokenize": True, "return_assistant_tokens_mask": True, @@ -163,18 +174,8 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: } try: - if isinstance(processor, ProcessorMixin): - res_any = processor.apply_chat_template( - [{"role": "assistant", "content": [{"type": "text", "text": "test"}]}], - **chat_template_kwargs, - ) - res = cast("BatchFeature", res_any) - else: - res_any = processor.apply_chat_template( - [{"role": "assistant", "content": "test"}], - **chat_template_kwargs, - ) - res = cast("BatchEncoding", res_any) + res_any = processor.apply_chat_template(test_conv, **chat_template_kwargs) + res = cast("BatchEncoding | BatchFeature", res_any) # Check both singular and plural key names mask = res.get("assistant_masks", res.get("assistant_mask")) @@ -194,12 +195,20 @@ def _detect_assistant_pattern(processor: ProcessorLike) -> str: Uses multi-turn conversation but extracts pattern from the LAST assistant message only. 
""" - test_conv = [ - {"role": "user", "content": "USER_MSG_1"}, - {"role": "assistant", "content": "ASSISTANT_MSG_1"}, - {"role": "user", "content": "USER_MSG_2"}, - {"role": "assistant", "content": "ASSISTANT_MSG_2"}, - ] + if isinstance(processor, ProcessorMixin): + test_conv = [ + {"role": "user", "content": [{"type": "text", "text": "USER_MSG_1"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "ASSISTANT_MSG_1"}]}, + {"role": "user", "content": [{"type": "text", "text": "USER_MSG_2"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "ASSISTANT_MSG_2"}]}, + ] + else: + test_conv = [ + {"role": "user", "content": "USER_MSG_1"}, + {"role": "assistant", "content": "ASSISTANT_MSG_1"}, + {"role": "user", "content": "USER_MSG_2"}, + {"role": "assistant", "content": "ASSISTANT_MSG_2"}, + ] formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False @@ -318,8 +327,6 @@ def _get_input_ids_loss_mask( max_length: int, assistant_pattern: str | Pattern[str] | None, ): - encoded: BatchEncoding | BatchFeature - if assistant_pattern is None: # HF assistant token mask encoded_any = processor.apply_chat_template( @@ -339,53 +346,54 @@ def _get_input_ids_loss_mask( ) loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) - else: - # Fallback: regex-based detection - assert assistant_pattern is not None, "Assistant pattern required for fallback" - - processor_kwargs: dict = { - "return_offsets_mapping": True, - "max_length": max_length, - "truncation": True, - "add_special_tokens": False, - } - - if isinstance(processor, ProcessorMixin): - encoded_any = processor.apply_chat_template( - normalized_conv, - tokenize=True, - add_generation_prompt=False, - return_dict=True, - processor_kwargs=processor_kwargs, - ) - encoded = cast("BatchFeature", encoded_any) + return input_ids, loss_mask - # Remove batch dimension - (input_ids,) = encoded["input_ids"] - (offsets,) = encoded["offset_mapping"] + # Fallback: 
regex-based detection + assert assistant_pattern is not None, "Assistant pattern required for fallback" - # MM placeholder tokens are inserted separate from chat template - formatted_text = processor.decode(input_ids) - assert isinstance(formatted_text, str) - else: - # More optimized flow for text-only processors (i.e. tokenizers) - formatted_text = processor.apply_chat_template( - normalized_conv, - tokenize=False, - add_generation_prompt=False, - ) - assert isinstance(formatted_text, str) + processor_kwargs: dict = { + "return_offsets_mapping": True, + "max_length": max_length, + "truncation": True, + "add_special_tokens": False, + } - # Tokenize and get offsets - encoded_any = processor(formatted_text, **processor_kwargs) - encoded = cast("BatchEncoding", encoded_any) + if isinstance(processor, ProcessorMixin): + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + return_dict=True, + processor_kwargs=processor_kwargs, + ) + encoded = cast("BatchFeature", encoded_any) - input_ids = encoded["input_ids"] - offsets = encoded["offset_mapping"] + # Remove batch dimension + (input_ids,) = encoded["input_ids"] + (offsets,) = encoded["offset_mapping"] - loss_mask = _create_loss_mask_from_offsets( - formatted_text, offsets, assistant_pattern + # MM placeholder tokens are inserted separate from chat template + formatted_text = processor.decode(input_ids) + assert isinstance(formatted_text, str) + else: + # More optimized flow for text-only processors (i.e. 
tokenizers) + formatted_text = processor.apply_chat_template( + normalized_conv, + tokenize=False, + add_generation_prompt=False, ) + assert isinstance(formatted_text, str) + + # Tokenize and get offsets + encoded_any = processor(formatted_text, **processor_kwargs) + encoded = cast("BatchEncoding", encoded_any) + + input_ids = encoded["input_ids"] + offsets = encoded["offset_mapping"] + + loss_mask = _create_loss_mask_from_offsets( + formatted_text, offsets, assistant_pattern + ) return input_ids, loss_mask diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index a9184c18f..3fbfe24aa 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -7,11 +7,13 @@ import pytest import torch from datasets import Dataset as HFDataset -from transformers import AutoTokenizer +from PIL import Image +from transformers import AutoProcessor from speculators.data_generation.preprocessing import ( _create_loss_mask_from_offsets, _detect_assistant_pattern, + _hf_to_vllm_conv, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, @@ -77,16 +79,79 @@ def test_normalize_conversation_unknown_role(): assert result[1]["role"] == "assistant" +# Tests for _hf_to_vllm_conv +@pytest.mark.sanity +def test_hf_to_vllm_all_content_formats(): + """Test converting from HF-format to vLLM-format messages with each supported content format.""" + conv = [ + { + "role": "system", + "content": "You are a helpful assistant." 
# Content as string + }, + { + "role": "assistant", + "content": [ # Content as list + "Hello,", # Text as string + {"type": "text", "text": "I am"}, # Text as dictionary + {"type": "image", "path": "/path/to/img"}, # Image file path + {"type": "image", "url": "http://path/to/img"} # Image URL + ] + }, + ] + result = _hf_to_vllm_conv(conv) + + assert len(result) == 2 + + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are a helpful assistant." + + assert result[1]["role"] == "assistant" + assert result[1]["content"][0] == {"type": "text", "text": "Hello,"} + assert result[1]["content"][1] == {"type": "text", "text": "I am"} + assert result[1]["content"][2] == {"type": "image_url", "image_url": {"url": "file:///path/to/img"}} + assert result[1]["content"][3] == {"type": "image_url", "image_url": {"url": "http://path/to/img"}} + + +@pytest.mark.sanity +def test_hf_to_vllm_invalid_content_formats(): + """Test converting from HF-format to vLLM-format messages with unsupported content formats.""" + # Image object is not supported to discourage copying images when saving the preprocessed dataset + with pytest.raises(NotImplementedError, match=r"No handler .* for fields: \{'part\.image'\}"): + _hf_to_vllm_conv( + [ + { + "role": "assistant", + "content": [ + {"type": "image", "image": Image.new("RGB", (256, 256))}, + ] + }, + ] + ) + + # Image base64 is not supported to discourage copying images when saving the preprocessed dataset + with pytest.raises(NotImplementedError, match=r"No handler .* for fields: \{'part\.base64'\}"): + _hf_to_vllm_conv( + [ + { + "role": "assistant", + "content": [ + {"type": "image", "base64": "abcdef"}, + ] + }, + ] + ) + + # Tests for _detect_assistant_pattern @pytest.mark.sanity def test_detect_assistant_pattern_structure(): """Test that the detected pattern has the correct regex structure.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = 
AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) # Pattern should be a valid regex string assert isinstance(pattern, str) @@ -105,20 +170,20 @@ def test_detect_assistant_pattern_structure(): @pytest.mark.sanity def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): """Test that pattern correctly distinguishes assistant from user content.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") # Get the pattern - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) # Format a conversation manually to test the pattern test_conv = [ {"role": "user", "content": "USER_MSG"}, {"role": "assistant", "content": "ASSISTANT_MSG"}, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] test_conv, tokenize=False, add_generation_prompt=False ) @@ -141,12 +206,12 @@ def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): @pytest.mark.sanity def test_detect_assistant_pattern_extracts_correct_content(): """Test that the pattern's capture group extracts only assistant message content.""" - 
tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) # Test with a multi-turn conversation test_conv = [ @@ -156,7 +221,7 @@ def test_detect_assistant_pattern_extracts_correct_content(): {"role": "assistant", "content": "Second answer"}, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore [assignment] + formatted: str = processor.apply_chat_template( # type: ignore [assignment] test_conv, tokenize=False, add_generation_prompt=False ) @@ -261,13 +326,13 @@ def test_create_loss_mask_empty_offsets(): @pytest.mark.sanity def test_preprocess_batch_basic(): """Test preprocessing a basic batch of conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token examples = { "conversations": [ @@ -282,9 +347,9 @@ def test_preprocess_batch_basic(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - 
examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern ) assert "input_ids" in results @@ -302,15 +367,15 @@ def test_preprocess_batch_basic(): @pytest.mark.sanity def test_preprocess_batch_empty_conversations(): """Test preprocessing batch with no conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples: dict[str, list] = {"conversations": []} - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern ) assert results["input_ids"] == [] @@ -320,13 +385,13 @@ def test_preprocess_batch_empty_conversations(): @pytest.mark.sanity def test_preprocess_batch_invalid_conversation(): """Test preprocessing batch with invalid conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + 
processor.pad_token = processor.eos_token examples = { "conversations": [ @@ -336,9 +401,9 @@ def test_preprocess_batch_invalid_conversation(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern ) # Should only process the valid conversation @@ -349,13 +414,13 @@ def test_preprocess_batch_invalid_conversation(): @pytest.mark.sanity def test_preprocess_batch_truncation(): """Test that long sequences are truncated to max_length.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token # Create a very long message long_content = "word " * 1000 @@ -370,9 +435,9 @@ def test_preprocess_batch_truncation(): } max_length = 100 - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=max_length, assistant_pattern=assistant_pattern + examples, processor, max_length=max_length, assistant_pattern=assistant_pattern ) if len(results["input_ids"]) > 0: @@ -384,17 +449,17 @@ def test_preprocess_batch_truncation(): @pytest.mark.sanity def test_preprocess_batch_uses_hf_assistant_mask(): """Test that HF assistant token mask is used when supported.""" - tokenizer = 
AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token - # Skip test if assistant mask is not supported/functional for this tokenizer - if not _supports_assistant_mask(tokenizer): - pytest.skip("Tokenizer does not support assistant token mask") + # Skip test if assistant mask is not supported/functional for this processor + if not _supports_assistant_mask(processor): + pytest.skip("Processor does not support assistant token mask") examples = { "conversations": [ @@ -408,7 +473,7 @@ def test_preprocess_batch_uses_hf_assistant_mask(): # Pass None to trigger masking path results = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=None, ) @@ -427,23 +492,23 @@ def test_preprocess_batch_falls_back_to_regex(): """Test that preprocessing falls back to regex-based detection when HF mask is unavailable. 
""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token # Monkeypatch apply_chat_template to force HF mask failure - original_apply_chat_template = tokenizer.apply_chat_template + original_apply_chat_template = processor.apply_chat_template def patched_apply_chat_template(*args, **kwargs): if kwargs.get("return_assistant_tokens_mask", False): raise ValueError("Forcing fallback to regex path") return original_apply_chat_template(*args, **kwargs) - tokenizer.apply_chat_template = patched_apply_chat_template # type: ignore [method-assign] + processor.apply_chat_template = patched_apply_chat_template # type: ignore [method-assign] examples = { "conversations": [ @@ -454,11 +519,11 @@ def patched_apply_chat_template(*args, **kwargs): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, ) @@ -475,13 +540,13 @@ def patched_apply_chat_template(*args, **kwargs): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): """Test that minimum_valid_tokens drops short samples on regex path.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, 
"apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token examples = { "conversations": [ @@ -492,11 +557,11 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) baseline = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, ) @@ -507,7 +572,7 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): filtered = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, minimum_valid_tokens=valid_count + 1, @@ -521,13 +586,13 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): """Test that a sample is kept when valid tokens equal the threshold.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token examples = { "conversations": [ @@ -544,11 +609,11 @@ def 
test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) baseline = _preprocess_batch( examples, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) @@ -559,7 +624,7 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): kept = _preprocess_batch( examples, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, minimum_valid_tokens=valid_count, @@ -576,13 +641,13 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): @pytest.mark.sanity def test_build_eagle3_dataset_basic(): """Test building EAGLE3 dataset from a simple HuggingFace dataset.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token # Create a simple dataset data = { @@ -599,7 +664,7 @@ def test_build_eagle3_dataset_basic(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) assert isinstance(result, HFDataset) assert len(result) <= len(dataset) @@ -613,13 +678,13 @@ def test_build_eagle3_dataset_basic(): @pytest.mark.sanity def test_build_eagle3_dataset_preserves_format(): """Test that build_eagle3_dataset sets the correct format.""" - tokenizer = 
AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token data = { "conversations": [ @@ -631,7 +696,7 @@ def test_build_eagle3_dataset_preserves_format(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) # Dataset should be in torch format assert result.format["type"] == "torch" @@ -640,13 +705,13 @@ def test_build_eagle3_dataset_preserves_format(): @pytest.mark.sanity def test_build_eagle3_dataset_removes_original_columns(): """Test that original columns are removed after processing.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token data = { "conversations": [ @@ -659,7 +724,7 @@ def test_build_eagle3_dataset_removes_original_columns(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, 
tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) # Original columns should be removed if len(result) > 0: @@ -670,13 +735,13 @@ def test_build_eagle3_dataset_removes_original_columns(): @pytest.mark.sanity def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): """Test that build_eagle3_dataset removes samples below the token threshold.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token short_conv = [ {"role": "user", "content": "Hi"}, @@ -693,17 +758,17 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): }, ] - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) short_baseline = _preprocess_batch( {"conversations": [short_conv]}, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) long_baseline = _preprocess_batch( {"conversations": [long_conv]}, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) @@ -722,7 +787,7 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): dataset = HFDataset.from_dict({"conversations": [short_conv, long_conv]}) result = build_eagle3_dataset( dataset, - tokenizer, + processor, max_length=256, num_proc=1, assistant_pattern=assistant_pattern, @@ -742,13 +807,13 @@ def 
test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): @pytest.mark.sanity def test_preprocess_batch_with_turn_dropout(): """Test preprocessing batch with turn dropout enabled.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token examples = { "conversations": [ @@ -761,10 +826,10 @@ def test_preprocess_batch_with_turn_dropout(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( examples, - tokenizer, + processor, max_length=512, assistant_pattern=assistant_pattern, turn_dropout=True, @@ -788,8 +853,8 @@ def test_detect_assistant_pattern_thinking_model(): but the pattern must still match real conversations where the think block contains substantial content. 
""" - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) - pattern = _detect_assistant_pattern(tokenizer) + processor = AutoProcessor.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) + pattern = _detect_assistant_pattern(processor) # Format a multi-turn conversation with thinking content injected # directly into the formatted string (as it would appear in real data) @@ -807,7 +872,7 @@ def test_detect_assistant_pattern_thinking_model(): "reasoning_content": "We are adding 3 and 3.", }, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] test_conv, tokenize=False, add_generation_prompt=False, enable_thinking=True ) @@ -848,8 +913,8 @@ def test_create_loss_mask_thinking_model(thinking_content): Verifies correct masking both with and without thinking content in the block. """ - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) - pattern = _detect_assistant_pattern(tokenizer) + processor = AutoProcessor.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) + pattern = _detect_assistant_pattern(processor) # Build formatted text using the real chat template conv = [ @@ -858,7 +923,7 @@ def test_create_loss_mask_thinking_model(thinking_content): ] if thinking_content: conv[-1]["reasoning_content"] = thinking_content - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] conv, tokenize=False, add_generation_prompt=False, @@ -866,7 +931,7 @@ def test_create_loss_mask_thinking_model(thinking_content): ) # Tokenize with offsets - encoding = tokenizer( + encoding = processor( formatted, return_offsets_mapping=True, add_special_tokens=False, @@ -880,8 +945,8 @@ def test_create_loss_mask_thinking_model(thinking_content): # Decode masked vs unmasked regions input_ids = torch.tensor(encoding["input_ids"]) - 
trainable_text = tokenizer.decode(input_ids[mask == 1]) - masked_text = tokenizer.decode(input_ids[mask == 0]) + trainable_text = processor.decode(input_ids[mask == 1]) + masked_text = processor.decode(input_ids[mask == 0]) # Assistant response must be in the trainable region assert "Paris is the capital" in trainable_text @@ -898,13 +963,13 @@ def test_create_loss_mask_thinking_model(thinking_content): @pytest.mark.sanity def test_build_eagle3_dataset_with_custom_pattern(): """Test building dataset with custom assistant pattern.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if processor.pad_token is None: + processor.pad_token = processor.eos_token data = { "conversations": [ @@ -920,7 +985,7 @@ def test_build_eagle3_dataset_with_custom_pattern(): dataset = HFDataset.from_dict(data) result = build_eagle3_dataset( - dataset, tokenizer, max_length=512, num_proc=1, assistant_pattern=custom_pattern + dataset, processor, max_length=512, num_proc=1, assistant_pattern=custom_pattern ) # Should successfully build dataset with custom pattern diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index 4a671f12a..e3c87a9dd 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -6,7 +6,8 @@ import pytest from loguru import logger as log -from transformers import AutoTokenizer +from PIL import Image +from transformers import AutoProcessor, ProcessorMixin from 
speculators.data_generation.preprocessing import ( _detect_assistant_pattern, @@ -27,30 +28,33 @@ "microsoft/Phi-3-mini-4k-instruct", # GPT-OSS "openai/gpt-oss-20b", + # Multimodal + "google/gemma-4-E2B-it", ] @pytest.fixture(scope="module", params=MODELS) -def tokenizer(request): +def processor(request): model_id = request.param try: # Using trust_remote_code=True for variety of templates - return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + return AutoProcessor.from_pretrained(model_id, trust_remote_code=True) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: - pytest.skip(f"Failed to load tokenizer for {model_id}: {e}") + pytest.skip(f"Failed to load processor for {model_id}: {e}") -def test_regex_detection_across_models(tokenizer): +def test_regex_detection_across_models(tmp_path, processor): """ Verify that _detect_assistant_pattern and _preprocess_batch (regex path) work correctly for a variety of model families. """ + tokenizer = processor.tokenizer if isinstance(processor, ProcessorMixin) else processor model_name = tokenizer.name_or_path log.info(f"Testing family: {model_name}") # 1. Detect pattern try: - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) except (ValueError, RuntimeError) as e: pytest.fail(f"Failed to detect assistant pattern for {model_name}: {e}") @@ -58,20 +62,55 @@ def test_regex_detection_across_models(tokenizer): assert isinstance(pattern, (str, Pattern)), "Pattern must be str or regex object" # 2. 
Preprocess a simple multi-turn conversation using REGEX path - examples = { - "conversations": [ - [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "The capital of France is Paris."}, - ] + if isinstance(processor, ProcessorMixin): + img_path = str(tmp_path / "blank.png") + Image.new("RGB", (256, 256)).save(img_path) + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello, how are you?"}, + {"type": "image", "path": img_path}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I am a helpful assistant."}, + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the capital"}, + {"type": "image", "path": img_path}, + {"type": "text", "text": "of France?"} + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The capital of France is Paris.", + }, + ], + }, ] - } + else: + conversation = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + ] + + examples = {"conversations": [conversation]} # Regex path by passing the explicit pattern results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=pattern + examples, processor, max_length=2048, assistant_pattern=pattern ) assert len(results["input_ids"]) == 1 @@ -86,7 +125,7 @@ def test_regex_detection_across_models(tokenizer): # 3. 
Qualitative check: Assistant content should be masked as 1 trainable_tokens = input_ids[loss_mask == 1] - decoded_assistant = tokenizer.decode(trainable_tokens) + decoded_assistant = processor.decode(trainable_tokens) log.info(f"Decoded trainable regions: {decoded_assistant}") @@ -97,7 +136,3 @@ def test_regex_detection_across_models(tokenizer): # It should NOT contain user message content assert "Hello" not in decoded_assistant assert "France?" not in decoded_assistant - - -if __name__ == "__main__": - pass From df2c542b904ba64635292afba1c499a123b97d1a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 13:51:26 +0000 Subject: [PATCH 15/95] Add trust remote code Signed-off-by: DarkLight1337 --- scripts/train.py | 6 ++++++ src/speculators/train/utils.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index cbcfabcae..cdc766367 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -259,6 +259,7 @@ def main(args: argparse.Namespace): args.verifier_name_or_path, transformer_layer_config.vocab_size, args.mask_token_id, + trust_remote_code=args.trust_remote_code, ) registry = SpeculatorModel.registry @@ -393,6 +394,11 @@ def _checkpoint_freq(value: str) -> int: def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--verifier-name-or-path", type=str, required=True) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow executing code from HF Hub when loading the verifier's tokenizer.", + ) parser.add_argument( "--speculator-type", type=str, diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index 61c368309..ef1ba264e 100644 --- a/src/speculators/train/utils.py +++ b/src/speculators/train/utils.py @@ -60,6 +60,8 @@ def resolve_mask_token_id( verifier_name_or_path: str, vocab_size: int, mask_token_id: int | None = None, + *, + trust_remote_code: bool = False, ) -> int: """Resolve mask_token_id from explicit value, 
tokenizer, or fallback. @@ -73,7 +75,10 @@ def resolve_mask_token_id( logger.info(f"Using explicit mask_token_id={mask_token_id}") return mask_token_id - tokenizer = AutoTokenizer.from_pretrained(verifier_name_or_path) + tokenizer = AutoTokenizer.from_pretrained( + verifier_name_or_path, + trust_remote_code=trust_remote_code, + ) if tokenizer.mask_token_id is not None: logger.info(f"Using tokenizer mask_token_id={tokenizer.mask_token_id}") From a0da6221ce03edca19b550fc28f9ea27866fffa4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 13:52:56 +0000 Subject: [PATCH 16/95] Update CLI reference Signed-off-by: DarkLight1337 --- docs/cli/prepare_data.md | 2 ++ docs/cli/train.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/cli/prepare_data.md b/docs/cli/prepare_data.md index f4fc40066..fab7cb788 100644 --- a/docs/cli/prepare_data.md +++ b/docs/cli/prepare_data.md @@ -26,6 +26,8 @@ python scripts/prepare_data.py \ Example: `meta-llama/Llama-3.1-8B-Instruct` +- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the target model's processor. + ### Data Arguments - **`--data`** (str, required, repeatable) Path to training data. Can be a HuggingFace dataset name or local path. Use multiple times to specify multiple datasets. diff --git a/docs/cli/train.md b/docs/cli/train.md index 570fc4aaf..dee322121 100644 --- a/docs/cli/train.md +++ b/docs/cli/train.md @@ -32,6 +32,8 @@ torchrun --standalone --nproc_per_node=4 scripts/train.py \ - **`--verifier-name-or-path`** (str, required) HuggingFace model ID or local path for the verifier/target model. +- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the verifier's tokenizer. + - **`--speculator-type`** (str, default: `"eagle3"`) Type of speculator model to train. Options: `eagle3`, `dflash` - **`--from-pretrained`** (str, default: `""`) Path to a pretrained draft model to finetune. 
From 2dfa0f504980ba4fa6e8ff795bc439851158ba8e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 13:56:33 +0000 Subject: [PATCH 17/95] Format Signed-off-by: DarkLight1337 --- pyproject.toml | 1 + .../data_generation/preprocessing.py | 66 +++++++++++-------- .../integration/datagen/test_preprocessing.py | 50 +++++++++----- .../datagen/test_regex_patterns.py | 8 ++- 4 files changed, 78 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b3f28abff..596c50d17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,6 +250,7 @@ select = [ "INP001", # allow implicit namespace packages in scripts "PTH", # os.path is acceptable in scripts "T201", # print statements are acceptable in scripts + "PLR0915", # allow long parse_args functions ] "examples/**/*.py" = [ diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 424360591..a40d5b69b 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -155,26 +155,27 @@ def _hf_to_vllm_conv(normalized_conv: list[dict]): return [_hf_to_vllm_turn(turn) for turn in normalized_conv] +def _get_assistant_mask_test_conv(processor: ProcessorLike): + if isinstance(processor, ProcessorMixin): + return [{"role": "assistant", "content": [{"type": "text", "text": "test"}]}] + + return [{"role": "assistant", "content": "test"}] + + def _supports_assistant_mask(processor: ProcessorLike) -> bool: """Check if processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message.
""" - if isinstance(processor, ProcessorMixin): - test_conv = [ - {"role": "assistant", "content": [{"type": "text", "text": "test"}]} - ] - else: - test_conv = [{"role": "assistant", "content": "test"}] - - chat_template_kwargs: dict = { - "tokenize": True, - "return_assistant_tokens_mask": True, - "return_dict": True, - } + test_conv = _get_assistant_mask_test_conv(processor) try: - res_any = processor.apply_chat_template(test_conv, **chat_template_kwargs) + res_any = processor.apply_chat_template( + test_conv, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + ) res = cast("BatchEncoding | BatchFeature", res_any) # Check both singular and plural key names @@ -189,27 +190,36 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: return False +def _get_assistant_pattern_test_conv(processor: ProcessorLike): + if isinstance(processor, ProcessorMixin): + return [ + {"role": "user", "content": [{"type": "text", "text": "USER_MSG_1"}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "ASSISTANT_MSG_1"}], + }, + {"role": "user", "content": [{"type": "text", "text": "USER_MSG_2"}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "ASSISTANT_MSG_2"}], + }, + ] + + return [ + {"role": "user", "content": "USER_MSG_1"}, + {"role": "assistant", "content": "ASSISTANT_MSG_1"}, + {"role": "user", "content": "USER_MSG_2"}, + {"role": "assistant", "content": "ASSISTANT_MSG_2"}, + ] + + def _detect_assistant_pattern(processor: ProcessorLike) -> str: """Auto-detect the assistant message pattern from the processor's chat template. Uses multi-turn conversation but extracts pattern from the LAST assistant message only.
""" - if isinstance(processor, ProcessorMixin): - test_conv = [ - {"role": "user", "content": [{"type": "text", "text": "USER_MSG_1"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "ASSISTANT_MSG_1"}]}, - {"role": "user", "content": [{"type": "text", "text": "USER_MSG_2"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "ASSISTANT_MSG_2"}]}, - ] - else: - test_conv = [ - {"role": "user", "content": "USER_MSG_1"}, - {"role": "assistant", "content": "ASSISTANT_MSG_1"}, - {"role": "user", "content": "USER_MSG_2"}, - {"role": "assistant", "content": "ASSISTANT_MSG_2"}, - ] - + test_conv = _get_assistant_pattern_test_conv(processor) formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index 3fbfe24aa..f5eebc29d 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -82,11 +82,14 @@ def test_normalize_conversation_unknown_role(): # Tests for _hf_to_vllm_conv @pytest.mark.sanity def test_hf_to_vllm_all_content_formats(): - """Test converting from HF-format to vLLM-format messages with each supported content format.""" + """ + Test converting from HF-format to vLLM-format messages with + each supported content format. + """ conv = [ { "role": "system", - "content": "You are a helpful assistant." 
# Content as string + "content": "You are a helpful assistant.", # Content as string }, { "role": "assistant", @@ -94,9 +97,9 @@ def test_hf_to_vllm_all_content_formats(): "Hello,", # Text as string {"type": "text", "text": "I am"}, # Text as dictionary {"type": "image", "path": "/path/to/img"}, # Image file path - {"type": "image", "url": "http://path/to/img"} # Image URL - ] - }, + {"type": "image", "url": "http://path/to/img"}, # Image URL + ], + }, ] result = _hf_to_vllm_conv(conv) @@ -108,36 +111,51 @@ def test_hf_to_vllm_all_content_formats(): assert result[1]["role"] == "assistant" assert result[1]["content"][0] == {"type": "text", "text": "Hello,"} assert result[1]["content"][1] == {"type": "text", "text": "I am"} - assert result[1]["content"][2] == {"type": "image_url", "image_url": {"url": "file:///path/to/img"}} - assert result[1]["content"][3] == {"type": "image_url", "image_url": {"url": "http://path/to/img"}} + assert result[1]["content"][2] == { + "type": "image_url", + "image_url": {"url": "file:///path/to/img"}, + } + assert result[1]["content"][3] == { + "type": "image_url", + "image_url": {"url": "http://path/to/img"}, + } @pytest.mark.sanity def test_hf_to_vllm_invalid_content_formats(): - """Test converting from HF-format to vLLM-format messages with unsupported content formats.""" - # Image object is not supported to discourage copying images when saving the preprocessed dataset - with pytest.raises(NotImplementedError, match=r"No handler .* for fields: \{'part\.image'\}"): + """ + Test converting from HF-format to vLLM-format messages with + unsupported content formats. 
+ """ + # Image object is not supported to discourage copying images + # when saving the preprocessed dataset + with pytest.raises( + NotImplementedError, match=r"No handler .* for fields: \{'part\.image'\}" + ): _hf_to_vllm_conv( [ { "role": "assistant", "content": [ {"type": "image", "image": Image.new("RGB", (256, 256))}, - ] - }, + ], + }, ] ) - # Image base64 is not supported to discourage copying images when saving the preprocessed dataset - with pytest.raises(NotImplementedError, match=r"No handler .* for fields: \{'part\.base64'\}"): + # Image base64 is not supported to discourage copying images + # when saving the preprocessed dataset + with pytest.raises( + NotImplementedError, match=r"No handler .* for fields: \{'part\.base64'\}" + ): _hf_to_vllm_conv( [ { "role": "assistant", "content": [ {"type": "image", "base64": "abcdef"}, - ] - }, + ], + }, ] ) diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index e3c87a9dd..b3ca1eebc 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -48,7 +48,9 @@ def test_regex_detection_across_models(tmp_path, processor): Verify that _detect_assistant_pattern and _preprocess_batch (regex path) work correctly for a variety of model families. 
""" - tokenizer = processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + tokenizer = ( + processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + ) model_name = tokenizer.name_or_path log.info(f"Testing family: {model_name}") @@ -78,14 +80,14 @@ def test_regex_detection_across_models(tmp_path, processor): "role": "assistant", "content": [ {"type": "text", "text": "I am a helpful assistant."}, - ] + ], }, { "role": "user", "content": [ {"type": "text", "text": "What is the capital"}, {"type": "image", "path": img_path}, - {"type": "text", "text": "of France?"} + {"type": "text", "text": "of France?"}, ], }, { From cf2631c00e0626c2a304678055f058983da89043 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:08:02 +0000 Subject: [PATCH 18/95] More cleanup and tests Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 28 ++- .../integration/datagen/test_preprocessing.py | 179 ++++++++++++------ .../datagen/test_regex_patterns.py | 5 +- 3 files changed, 141 insertions(+), 71 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index a40d5b69b..de1e74fc2 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -553,6 +553,24 @@ def load_raw_dataset( return raw_dataset, config.normalize_fn +def _resolve_eos_token(processor: ProcessorLike): + tokenizer = ( + processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + +def _load_processor(target_model_path: str, *, trust_remote_code: bool = False): + processor = AutoProcessor.from_pretrained( + target_model_path, + trust_remote_code=trust_remote_code, + ) + _resolve_eos_token(processor) + + return processor + + def load_and_preprocess_dataset( target_model_path: str, train_data_paths: list[str], @@ -601,15 
+619,7 @@ def load_and_preprocess_dataset( ) log.subsection("Loading processor") - processor = AutoProcessor.from_pretrained( - target_model_path, - trust_remote_code=trust_remote_code, - ) - tokenizer = ( - processor.tokenizer if isinstance(processor, ProcessorMixin) else processor - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + processor = _load_processor(target_model_path, trust_remote_code=trust_remote_code) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: raise ValueError( diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index f5eebc29d..be1fdf782 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -8,12 +8,12 @@ import torch from datasets import Dataset as HFDataset from PIL import Image -from transformers import AutoProcessor from speculators.data_generation.preprocessing import ( _create_loss_mask_from_offsets, _detect_assistant_pattern, _hf_to_vllm_conv, + _load_processor, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, @@ -23,7 +23,9 @@ # Test model from HuggingFace with chat template # Using Qwen2-0.5B-Instruct: small (0.5B params), fast model with proper # chat template support -TEST_MODEL_REPO = "Qwen/Qwen2-0.5B-Instruct" +TEXT_MODEL_REPO = "Qwen/Qwen2-0.5B-Instruct" +# For testing multi-modal support +MM_MODEL_REPO = "Qwen/Qwen3-VL-2B-Instruct" # Tests for _normalize_conversation @@ -164,7 +166,7 @@ def test_hf_to_vllm_invalid_content_formats(): @pytest.mark.sanity def test_detect_assistant_pattern_structure(): """Test that the detected pattern has the correct regex structure.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: 
pytest.skip("Processor does not support chat templates") @@ -188,7 +190,7 @@ def test_detect_assistant_pattern_structure(): @pytest.mark.sanity def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): """Test that pattern correctly distinguishes assistant from user content.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -224,7 +226,7 @@ def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): @pytest.mark.sanity def test_detect_assistant_pattern_extracts_correct_content(): """Test that the pattern's capture group extracts only assistant message content.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -344,14 +346,11 @@ def test_create_loss_mask_empty_offsets(): @pytest.mark.sanity def test_preprocess_batch_basic(): """Test preprocessing a basic batch of conversations.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - examples = { "conversations": [ [ @@ -382,10 +381,106 @@ def test_preprocess_batch_basic(): assert isinstance(results["loss_mask"][i], torch.Tensor) +@pytest.mark.sanity +def test_preprocess_batch_multimodal(tmp_path): + """Test preprocessing a batch of multimodal conversations.""" + 
processor = _load_processor(MM_MODEL_REPO, trust_remote_code=True) + + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") + + img_path = str(tmp_path / "blank.png") + Image.new("RGB", (256, 256)).save(img_path) + + examples = { + "conversations": [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello, how are you?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I am a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the capital of France?"}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The capital of France is Paris.", + }, + ], + }, + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is the difference between these two images?", + }, + {"type": "image", "path": img_path}, + {"type": "image", "path": img_path}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "They are the exact same image."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Why?"}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "They are both blank.", + }, + ], + }, + ], + ] + } + + assistant_pattern = _detect_assistant_pattern(processor) + results = _preprocess_batch( + examples, processor, max_length=2048, assistant_pattern=assistant_pattern + ) + + assert "input_ids" in results + assert "loss_mask" in results + assert len(results["input_ids"]) == 2 + assert len(results["loss_mask"]) == 2 + + # Check that input_ids and loss_mask have same length for each example + for i in range(2): + assert len(results["input_ids"][i]) == len(results["loss_mask"][i]) + assert isinstance(results["input_ids"][i], torch.Tensor) + assert isinstance(results["loss_mask"][i], torch.Tensor) + + @pytest.mark.sanity def 
test_preprocess_batch_empty_conversations(): """Test preprocessing batch with no conversations.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -403,14 +498,11 @@ def test_preprocess_batch_empty_conversations(): @pytest.mark.sanity def test_preprocess_batch_invalid_conversation(): """Test preprocessing batch with invalid conversations.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - examples = { "conversations": [ None, # Invalid @@ -432,14 +524,11 @@ def test_preprocess_batch_invalid_conversation(): @pytest.mark.sanity def test_preprocess_batch_truncation(): """Test that long sequences are truncated to max_length.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - # Create a very long message long_content = "word " * 1000 @@ -467,14 +556,11 @@ def test_preprocess_batch_truncation(): @pytest.mark.sanity def test_preprocess_batch_uses_hf_assistant_mask(): """Test that HF assistant token mask is used when supported.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, 
trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - # Skip test if assistant mask is not supported/functional for this processor if not _supports_assistant_mask(processor): pytest.skip("Processor does not support assistant token mask") @@ -510,14 +596,11 @@ def test_preprocess_batch_falls_back_to_regex(): """Test that preprocessing falls back to regex-based detection when HF mask is unavailable. """ - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - # Monkeypatch apply_chat_template to force HF mask failure original_apply_chat_template = processor.apply_chat_template @@ -558,14 +641,11 @@ def patched_apply_chat_template(*args, **kwargs): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): """Test that minimum_valid_tokens drops short samples on regex path.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - examples = { "conversations": [ [ @@ -604,14 +684,11 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): """Test that a sample is kept when valid tokens equal the threshold.""" - 
processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - examples = { "conversations": [ [ @@ -659,14 +736,11 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): @pytest.mark.sanity def test_build_eagle3_dataset_basic(): """Test building EAGLE3 dataset from a simple HuggingFace dataset.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - # Create a simple dataset data = { "conversations": [ @@ -696,14 +770,11 @@ def test_build_eagle3_dataset_basic(): @pytest.mark.sanity def test_build_eagle3_dataset_preserves_format(): """Test that build_eagle3_dataset sets the correct format.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - data = { "conversations": [ [ @@ -723,14 +794,11 @@ def test_build_eagle3_dataset_preserves_format(): @pytest.mark.sanity def test_build_eagle3_dataset_removes_original_columns(): """Test that original columns are removed after processing.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = 
_load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - data = { "conversations": [ [ @@ -753,14 +821,11 @@ def test_build_eagle3_dataset_removes_original_columns(): @pytest.mark.sanity def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): """Test that build_eagle3_dataset removes samples below the token threshold.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - short_conv = [ {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "OK"}, @@ -825,14 +890,11 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): @pytest.mark.sanity def test_preprocess_batch_with_turn_dropout(): """Test preprocessing batch with turn dropout enabled.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - examples = { "conversations": [ [ @@ -871,7 +933,7 @@ def test_detect_assistant_pattern_thinking_model(): but the pattern must still match real conversations where the think block contains substantial content. 
""" - processor = AutoProcessor.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) + processor = _load_processor("Qwen/Qwen3-8B", trust_remote_code=True) pattern = _detect_assistant_pattern(processor) # Format a multi-turn conversation with thinking content injected @@ -931,7 +993,7 @@ def test_create_loss_mask_thinking_model(thinking_content): Verifies correct masking both with and without thinking content in the block. """ - processor = AutoProcessor.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) + processor = _load_processor("Qwen/Qwen3-8B", trust_remote_code=True) pattern = _detect_assistant_pattern(processor) # Build formatted text using the real chat template @@ -981,14 +1043,11 @@ def test_create_loss_mask_thinking_model(thinking_content): @pytest.mark.sanity def test_build_eagle3_dataset_with_custom_pattern(): """Test building dataset with custom assistant pattern.""" - processor = AutoProcessor.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") - if processor.pad_token is None: - processor.pad_token = processor.eos_token - data = { "conversations": [ [ diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index b3ca1eebc..5a9500a04 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -7,10 +7,11 @@ import pytest from loguru import logger as log from PIL import Image -from transformers import AutoProcessor, ProcessorMixin +from transformers import ProcessorMixin from speculators.data_generation.preprocessing import ( _detect_assistant_pattern, + _load_processor, _preprocess_batch, ) @@ -38,7 +39,7 @@ def processor(request): model_id = request.param try: # Using trust_remote_code=True for variety of 
templates - return AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + return _load_processor(model_id, trust_remote_code=True) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: pytest.skip(f"Failed to load processor for {model_id}: {e}") From 3abacf8b31b08e26dab79af90114b5e921f976cb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:14:21 +0000 Subject: [PATCH 19/95] Reduce diff Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index de1e74fc2..6bccaa788 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -220,6 +220,7 @@ def _detect_assistant_pattern(processor: ProcessorLike) -> str: message only. """ test_conv = _get_assistant_pattern_test_conv(processor) + formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) From f2ed9c31cea401ce82c5859e82948f042ac95c27 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:20:40 +0000 Subject: [PATCH 20/95] Fix mypy Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 9 ++++---- .../data_generation/vllm_client.py | 21 ++++++++++++------- .../integration/datagen/test_preprocessing.py | 2 +- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 6bccaa788..6acd49097 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -37,8 +37,8 @@ def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: in """Visualize a single sample with color-coded trainable regions.""" # Get preprocessed sample prep_sample = preprocessed[idx] - input_ids = 
prep_sample["input_ids"] - loss_mask = prep_sample["loss_mask"] + input_ids = prep_sample["input_ids"].tolist() + loss_mask = prep_sample["loss_mask"].tolist() log.info(f"SAMPLE #{idx}") log.info("HIGHLIGHTED TEXT (BLUE = trainable, GREY = masked)") @@ -52,8 +52,9 @@ def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: in prev_state = None for i in range(len(input_ids)): - is_train = loss_mask[i].item() == 1 - token = processor.decode([input_ids[i].item()]) + is_train = loss_mask[i] == 1 + token = processor.decode([input_ids[i]]) + assert isinstance(token, str) # Switch colors when state changes if is_train != prev_state: diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 2473b7197..d2f98867b 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -2,9 +2,14 @@ import functools import logging import time +from typing import TYPE_CHECKING, Any import openai -from openai.types.chat import ChatCompletionMessageParam +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai.types.completion import Completion + +if TYPE_CHECKING: + from collections.abc import Coroutine logger = logging.getLogger(__name__) @@ -121,6 +126,7 @@ async def generate_hidden_states_async( instead of passing `token_ids` to Completions API. timeout: Timeout in seconds for each request attempt. None for no timeout. 
""" + coro: Coroutine[Any, Any, Completion | ChatCompletion] if messages is None: coro = client.completions.create( model=model, @@ -139,11 +145,11 @@ async def generate_hidden_states_async( ) if timeout is not None: - completion = await asyncio.wait_for(coro, timeout=timeout) + res = await asyncio.wait_for(coro, timeout=timeout) else: - completion = await coro + res = await coro - return extract_output(completion, token_ids) + return extract_output(res, token_ids) @with_retries @@ -159,8 +165,9 @@ def generate_hidden_states( Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ + res: Completion | ChatCompletion if messages is None: - completion = client.completions.create( + res = client.completions.create( model=model, prompt=token_ids, max_tokens=1, @@ -168,7 +175,7 @@ def generate_hidden_states( timeout=timeout, ) else: - completion = client.chat.completions.create( + res = client.chat.completions.create( model=model, messages=messages, max_tokens=1, @@ -176,4 +183,4 @@ def generate_hidden_states( timeout=timeout, ) - return extract_output(completion, token_ids) + return extract_output(res, token_ids) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index be1fdf782..0748c4f5f 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -88,7 +88,7 @@ def test_hf_to_vllm_all_content_formats(): Test converting from HF-format to vLLM-format messages with each supported content format. 
""" - conv = [ + conv: list[dict] = [ { "role": "system", "content": "You are a helpful assistant.", # Content as string From 8eab139db28d5d2ce505f6cf82f17972b86f6ef9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:23:41 +0000 Subject: [PATCH 21/95] Address AI comments Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- src/speculators/train/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 6acd49097..4956913fe 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -130,7 +130,7 @@ def _hf_to_vllm_part(part: str | dict): for modality in ("image", "video", "audio"): if part_type == modality: if local_path := part.get("path"): - file_url = f"file://{local_path}" + file_url = f"file://{Path(local_path).absolute()}" return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} if url := part.get("url"): return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 7f36449ee..77d30c49a 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -260,7 +260,7 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: item = self.data[index] input_ids = item["input_ids"].tolist() - messages = self.data.get("_vllm_messages") + messages = item.get("_vllm_messages") try: hs_filepath = generate_hidden_states( From 5088837b33db0d7bb8a1bba885ee38fdaaae9ca9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:27:10 +0000 Subject: [PATCH 22/95] Add torchaudio and torchvision to dependencies Signed-off-by: DarkLight1337 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 596c50d17..a569d43d8 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -50,6 +50,8 @@ dependencies = [ "safetensors", "setuptools", "torch>=2.9.0,<=2.11.0", + "torchaudio", + "torchvision", "tqdm>=4.66.3,<=4.67.3", "transformers>=4.56.1,<5.7.0", "typer>=0.12.0", From a37e9006c51771420c76203557e49ede869136e9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 14:53:40 +0000 Subject: [PATCH 23/95] Handle transformers v4 Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 26 ++++++++++++++----- .../datagen/test_regex_patterns.py | 8 ++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 4956913fe..caf7d8c4c 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -9,6 +9,7 @@ import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset +from packaging.version import Version from transformers import ( AutoProcessor, BatchEncoding, @@ -16,6 +17,7 @@ PreTrainedTokenizerBase, ProcessorMixin, ) +from transformers import __version__ as TRANSFORMERS_VERSION # noqa: N812 from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger @@ -371,13 +373,23 @@ def _get_input_ids_loss_mask( } if isinstance(processor, ProcessorMixin): - encoded_any = processor.apply_chat_template( - normalized_conv, - tokenize=True, - add_generation_prompt=False, - return_dict=True, - processor_kwargs=processor_kwargs, - ) + if Version(TRANSFORMERS_VERSION) >= Version("5.4.0"): + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + return_dict=True, + processor_kwargs=processor_kwargs, + ) + else: + encoded_any = processor.apply_chat_template( + normalized_conv, + tokenize=True, + add_generation_prompt=False, + return_dict=True, + **processor_kwargs, + ) 
+ encoded = cast("BatchFeature", encoded_any) # Remove batch dimension diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index 5a9500a04..4187e9dd5 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -6,8 +6,10 @@ import pytest from loguru import logger as log +from packaging.version import Version from PIL import Image from transformers import ProcessorMixin +from transformers import __version__ as TRANSFORMERS_VERSION # noqa: N812 from speculators.data_generation.preprocessing import ( _detect_assistant_pattern, @@ -29,10 +31,12 @@ "microsoft/Phi-3-mini-4k-instruct", # GPT-OSS "openai/gpt-oss-20b", - # Multimodal - "google/gemma-4-E2B-it", ] +if Version(TRANSFORMERS_VERSION) >= Version("5.5.0"): + # Multimodal + MODELS.append("google/gemma-4-E2B-it") + @pytest.fixture(scope="module", params=MODELS) def processor(request): From 1713b2285f30f8f3f940d9b0134d0159c134f438 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 May 2026 15:19:33 +0000 Subject: [PATCH 24/95] Typo Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index caf7d8c4c..c5f318b30 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -567,7 +567,7 @@ def load_raw_dataset( return raw_dataset, config.normalize_fn -def _resolve_eos_token(processor: ProcessorLike): +def _resolve_pad_token(processor: ProcessorLike): tokenizer = ( processor.tokenizer if isinstance(processor, ProcessorMixin) else processor ) @@ -580,7 +580,7 @@ def _load_processor(target_model_path: str, *, trust_remote_code: bool = False): target_model_path, trust_remote_code=trust_remote_code, ) - _resolve_eos_token(processor) + 
_resolve_pad_token(processor) return processor From 7e66a57455272536705f39682d244ef6e9d82c58 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 02:59:21 +0000 Subject: [PATCH 25/95] Clean up Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 65 +++++------ .../integration/datagen/test_preprocessing.py | 104 +++++++++++++----- 2 files changed, 112 insertions(+), 57 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index c5f318b30..cee076fcd 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -120,6 +120,26 @@ def _normalize_conversation( return normalized +def _adapt_part_for_hf(part: str | dict, processor: ProcessorLike): + if isinstance(part, str) and isinstance(processor, ProcessorMixin): + return {"type": "text", "text": part} + + return part + + +def _adapt_turn_for_hf(turn: dict, processor: ProcessorLike): + if isinstance(turn["content"], str) and isinstance(processor, ProcessorMixin): + return turn | {"content": [_adapt_part_for_hf(turn["content"], processor)]} + + return turn | { + "content": [_adapt_part_for_hf(part, processor) for part in turn["content"]] + } + + +def _adapt_conv_for_hf(normalized_conv: list[dict], processor: ProcessorLike): + return [_adapt_turn_for_hf(turn, processor) for turn in normalized_conv] + + def _hf_to_vllm_part(part: str | dict): if isinstance(part, str): return {"type": "text", "text": part} @@ -158,19 +178,15 @@ def _hf_to_vllm_conv(normalized_conv: list[dict]): return [_hf_to_vllm_turn(turn) for turn in normalized_conv] -def _get_assistant_mask_test_conv(processor: ProcessorLike): - if isinstance(processor, ProcessorMixin): - return [{"role": "assistant", "content": [{"type": "text", "text": "test"}]}] - - return [{"role": "assistant", "content": "test"}] - - def _supports_assistant_mask(processor: ProcessorLike) -> bool: """Check if processor 
truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. """ - test_conv = _get_assistant_mask_test_conv(processor) + test_conv = _adapt_conv_for_hf( + [{"role": "assistant", "content": "test"}], + processor, + ) try: res_any = processor.apply_chat_template( @@ -193,36 +209,21 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: return False -def _get_assistant_pattern_test_conv(processor: ProcessorLike): - if isinstance(processor, ProcessorMixin): - return [ - {"role": "user", "content": [{"type": "text", "text": "USER_MSG_1"}]}, - { - "role": "assistant", - "content": [{"type": "text", "text": "ASSISTANT_MSG_1"}], - }, - {"role": "user", "content": [{"type": "text", "text": "USER_MSG_2"}]}, - { - "role": "assistant", - "content": [{"type": "text", "text": "ASSISTANT_MSG_2"}], - }, - ] - - return [ - {"role": "user", "content": "USER_MSG_1"}, - {"role": "assistant", "content": "ASSISTANT_MSG_1"}, - {"role": "user", "content": "USER_MSG_2"}, - {"role": "assistant", "content": "ASSISTANT_MSG_2"}, - ] - - def _detect_assistant_pattern(processor: ProcessorLike) -> str: """Auto-detect the assistant message pattern from the processor's chat template. Uses multi-turn conversation but extracts pattern from the LAST assistant message only. 
""" - test_conv = _get_assistant_pattern_test_conv(processor) + test_conv = _adapt_conv_for_hf( + [ + {"role": "user", "content": "USER_MSG_1"}, + {"role": "assistant", "content": "ASSISTANT_MSG_1"}, + {"role": "user", "content": "USER_MSG_2"}, + {"role": "assistant", "content": "ASSISTANT_MSG_2"}, + ], + processor, + ) formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index 0748c4f5f..1e179dad0 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -10,6 +10,7 @@ from PIL import Image from speculators.data_generation.preprocessing import ( + _adapt_conv_for_hf, _create_loss_mask_from_offsets, _detect_assistant_pattern, _hf_to_vllm_conv, @@ -81,6 +82,68 @@ def test_normalize_conversation_unknown_role(): assert result[1]["role"] == "assistant" +# Tests for _adapt_conv_for_hf +@pytest.mark.sanity +def test_adapt_conv_for_hf_text_only_processor(): + """ + Test converting from normalized conversation to HF-format with + a text-only processor (i.e. tokenizer). + """ + processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + + conv: list[dict] = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": ["Hello"]}, + {"role": "assistant", "content": "Hi!"}, + ] + result = _adapt_conv_for_hf(conv, processor) + + assert result == conv + + +@pytest.mark.sanity +def test_adapt_conv_for_hf_multimodal_processor(): + """ + Test converting from normalized conversation to HF-format with + a multi-modal processor. 
+ """ + processor = _load_processor(MM_MODEL_REPO, trust_remote_code=True) + + conv: list[dict] = [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": ["Hello", {"type": "image", "path": "/path/to/img"}], + }, + { + "role": "assistant", + "content": "Hi!", + }, + ] + result = _adapt_conv_for_hf(conv, processor) + + assert result == [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "path": "/path/to/img"}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "Hi!"}], + }, + ] + + # Tests for _hf_to_vllm_conv @pytest.mark.sanity def test_hf_to_vllm_all_content_formats(): @@ -111,16 +174,18 @@ def test_hf_to_vllm_all_content_formats(): assert result[0]["content"] == "You are a helpful assistant." assert result[1]["role"] == "assistant" - assert result[1]["content"][0] == {"type": "text", "text": "Hello,"} - assert result[1]["content"][1] == {"type": "text", "text": "I am"} - assert result[1]["content"][2] == { - "type": "image_url", - "image_url": {"url": "file:///path/to/img"}, - } - assert result[1]["content"][3] == { - "type": "image_url", - "image_url": {"url": "http://path/to/img"}, - } + assert result[1]["content"] == [ + {"type": "text", "text": "Hello,"}, + {"type": "text", "text": "I am"}, + { + "type": "image_url", + "image_url": {"url": "file:///path/to/img"}, + }, + { + "type": "image_url", + "image_url": {"url": "http://path/to/img"}, + }, + ] @pytest.mark.sanity @@ -397,30 +462,19 @@ def test_preprocess_batch_multimodal(tmp_path): [ { "role": "user", - "content": [ - {"type": "text", "text": "Hello, how are you?"}, - ], + "content": "Hello, how are you?", }, { "role": "assistant", - "content": [ - {"type": "text", "text": "I am a helpful assistant."}, - ], + "content": "I am a helpful assistant.", }, { "role": "user", - 
"content": [ - {"type": "text", "text": "What is the capital of France?"}, - ], + "content": "What is the capital of France?", }, { "role": "assistant", - "content": [ - { - "type": "text", - "text": "The capital of France is Paris.", - }, - ], + "content": "The capital of France is Paris.", }, ], [ From 84d0969efc3faf5da17b5f81a8377789d8d27f8f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 03:02:48 +0000 Subject: [PATCH 26/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index cee076fcd..042c03284 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -128,8 +128,11 @@ def _adapt_part_for_hf(part: str | dict, processor: ProcessorLike): def _adapt_turn_for_hf(turn: dict, processor: ProcessorLike): - if isinstance(turn["content"], str) and isinstance(processor, ProcessorMixin): - return turn | {"content": [_adapt_part_for_hf(turn["content"], processor)]} + if isinstance(turn["content"], str): + if isinstance(processor, ProcessorMixin): + return turn | {"content": [_adapt_part_for_hf(turn["content"], processor)]} + + return turn return turn | { "content": [_adapt_part_for_hf(part, processor) for part in turn["content"]] From f0151b3759068a96ec7e268731e0459f2e4c02e4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 03:21:20 +0000 Subject: [PATCH 27/95] Clean Signed-off-by: DarkLight1337 --- tests/integration/datagen/test_preprocessing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index 1e179dad0..edb8ed519 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -504,10 
+504,7 @@ def test_preprocess_batch_multimodal(tmp_path): { "role": "assistant", "content": [ - { - "type": "text", - "text": "They are both blank.", - }, + {"type": "text", "text": "They are both blank."}, ], }, ], From 83fc95ae1b7a609c7c3b8ab31272c5c12014c36b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 03:25:40 +0000 Subject: [PATCH 28/95] Improve UX Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 12 ++++++++++++ tests/integration/datagen/test_preprocessing.py | 14 +++----------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 042c03284..f469a9210 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -159,6 +159,18 @@ def _hf_to_vllm_part(part: str | dict): return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} if url := part.get("url"): return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} + if part.get("base64"): + raise ValueError( + f"base64 content is not supported. " + f"To avoid copying the image when saving the preprocessed dataset, " + f"please express {modality} inputs using file URLs." + ) + if obj := part.get(modality): + raise ValueError( + f"{type(obj).__name__} content is not supported. " + f"To avoid copying the image when saving the preprocessed dataset, " + f"please express {modality} inputs using file URLs." + ) fields_expr = {f"part.{k}" for k in part if k != "type"} diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index edb8ed519..8c725211c 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -194,11 +194,7 @@ def test_hf_to_vllm_invalid_content_formats(): Test converting from HF-format to vLLM-format messages with unsupported content formats. 
""" - # Image object is not supported to discourage copying images - # when saving the preprocessed dataset - with pytest.raises( - NotImplementedError, match=r"No handler .* for fields: \{'part\.image'\}" - ): + with pytest.raises(ValueError, match="Image content is not supported"): _hf_to_vllm_conv( [ { @@ -210,11 +206,7 @@ def test_hf_to_vllm_invalid_content_formats(): ] ) - # Image base64 is not supported to discourage copying images - # when saving the preprocessed dataset - with pytest.raises( - NotImplementedError, match=r"No handler .* for fields: \{'part\.base64'\}" - ): + with pytest.raises(ValueError, match="base64 content is not supported"): _hf_to_vllm_conv( [ { @@ -504,7 +496,7 @@ def test_preprocess_batch_multimodal(tmp_path): { "role": "assistant", "content": [ - {"type": "text", "text": "They are both blank."}, + {"type": "text", "text": "They are both blank."}, ], }, ], From b9bdc6c4b44428e3e6a1acfbfa3f0e03a80e1ea0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 03:33:35 +0000 Subject: [PATCH 29/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index f469a9210..cc5dab26b 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -357,6 +357,8 @@ def _get_input_ids_loss_mask( max_length: int, assistant_pattern: str | Pattern[str] | None, ): + normalized_conv = _adapt_conv_for_hf(normalized_conv, processor) + if assistant_pattern is None: # HF assistant token mask encoded_any = processor.apply_chat_template( From 97086d6e2fee20e7891cb796c59f1be72f129bdf Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 03:37:56 +0000 Subject: [PATCH 30/95] Add type annotations Signed-off-by: DarkLight1337 --- src/speculators/data_generation/vllm_client.py | 12 ++++++++---- 1 file 
changed, 8 insertions(+), 4 deletions(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index d2f98867b..5cb17dfb1 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -88,8 +88,11 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): return sync_wrapper -def extract_output(completion, token_ids) -> str: - prompt_token_ids = getattr(completion.choices[0], "prompt_token_ids", None) +def extract_output( + response: Completion | ChatCompletion, + token_ids: list[int], +) -> str: + prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None) if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") @@ -99,10 +102,10 @@ def extract_output(completion, token_ids) -> str: f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" ) - if not hasattr(completion, "kv_transfer_params"): + if not hasattr(response, "kv_transfer_params"): raise InvalidResponseError("Response missing kv_transfer_params") - return completion.kv_transfer_params.get("hidden_states_path") + return response.kv_transfer_params.get("hidden_states_path") @with_retries @@ -144,6 +147,7 @@ async def generate_hidden_states_async( timeout=timeout, ) + res: Completion | ChatCompletion if timeout is not None: res = await asyncio.wait_for(coro, timeout=timeout) else: From d9f6efa0a5616b36edb9ecc814fc408e7bd7a381 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 06:32:22 +0000 Subject: [PATCH 31/95] Improve error message Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 20 +++++++++---------- .../integration/datagen/test_preprocessing.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index cc5dab26b..b0a7a7314 100644 --- 
a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -159,27 +159,27 @@ def _hf_to_vllm_part(part: str | dict): return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} if url := part.get("url"): return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} + if part.get("base64"): + expr = {"type": modality, "base64": "..."} raise ValueError( - f"base64 content is not supported. " + f"Content part {expr} is not supported. " f"To avoid copying the image when saving the preprocessed dataset, " f"please express {modality} inputs using file URLs." ) - if obj := part.get(modality): + if part.get(modality): + expr = {"type": modality, modality: "..."} raise ValueError( - f"{type(obj).__name__} content is not supported. " + f"Content part {expr} is not supported. " f"To avoid copying the image when saving the preprocessed dataset, " f"please express {modality} inputs using file URLs." ) - fields_expr = {f"part.{k}" for k in part if k != "type"} - - raise NotImplementedError( - f"No handler defined in part.type={part_type!r} " - f"for fields: {fields_expr}" - ) + expr = {"type": modality} | {k: "..." for k in part if k != "type"} + raise NotImplementedError(f"Unknown content part: {expr}") - raise NotImplementedError(f"No handler defined for part.type={part_type!r}") + expr = {k: "..." for k in part} + raise NotImplementedError(f"Unknown content part: {expr}") def _hf_to_vllm_turn(turn: dict): diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index 8c725211c..521725971 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -194,7 +194,7 @@ def test_hf_to_vllm_invalid_content_formats(): Test converting from HF-format to vLLM-format messages with unsupported content formats. 
""" - with pytest.raises(ValueError, match="Image content is not supported"): + with pytest.raises(ValueError, match=r"'image':.* is not supported"): _hf_to_vllm_conv( [ { @@ -206,7 +206,7 @@ def test_hf_to_vllm_invalid_content_formats(): ] ) - with pytest.raises(ValueError, match="base64 content is not supported"): + with pytest.raises(ValueError, match=r"'base64':.* is not supported"): _hf_to_vllm_conv( [ { From 679d8c70311ea550f3e4f11241ff4d3ce7f80c85 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 06:35:06 +0000 Subject: [PATCH 32/95] Avoid name clash Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index b0a7a7314..24def8cfd 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -357,12 +357,12 @@ def _get_input_ids_loss_mask( max_length: int, assistant_pattern: str | Pattern[str] | None, ): - normalized_conv = _adapt_conv_for_hf(normalized_conv, processor) + hf_conv = _adapt_conv_for_hf(normalized_conv, processor) if assistant_pattern is None: # HF assistant token mask encoded_any = processor.apply_chat_template( - normalized_conv, + hf_conv, tokenize=True, add_generation_prompt=False, return_assistant_tokens_mask=True, @@ -393,7 +393,7 @@ def _get_input_ids_loss_mask( if isinstance(processor, ProcessorMixin): if Version(TRANSFORMERS_VERSION) >= Version("5.4.0"): encoded_any = processor.apply_chat_template( - normalized_conv, + hf_conv, tokenize=True, add_generation_prompt=False, return_dict=True, @@ -401,7 +401,7 @@ def _get_input_ids_loss_mask( ) else: encoded_any = processor.apply_chat_template( - normalized_conv, + hf_conv, tokenize=True, add_generation_prompt=False, return_dict=True, @@ -420,7 +420,7 @@ def _get_input_ids_loss_mask( else: # More optimized flow for 
text-only processors (i.e. tokenizers) formatted_text = processor.apply_chat_template( - normalized_conv, + hf_conv, tokenize=False, add_generation_prompt=False, ) From 720ccb423191a9e939185f00205375bbb9193a09 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 06:37:21 +0000 Subject: [PATCH 33/95] Rename Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 14 ++++----- .../integration/datagen/test_preprocessing.py | 30 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 24def8cfd..940c151c9 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -143,7 +143,7 @@ def _adapt_conv_for_hf(normalized_conv: list[dict], processor: ProcessorLike): return [_adapt_turn_for_hf(turn, processor) for turn in normalized_conv] -def _hf_to_vllm_part(part: str | dict): +def _adapt_part_for_vllm(part: str | dict): if isinstance(part, str): return {"type": "text", "text": part} @@ -178,19 +178,19 @@ def _hf_to_vllm_part(part: str | dict): expr = {"type": modality} | {k: "..." for k in part if k != "type"} raise NotImplementedError(f"Unknown content part: {expr}") - expr = {k: "..." 
for k in part} + expr = dict.fromkeys(part.keys(), "...") raise NotImplementedError(f"Unknown content part: {expr}") -def _hf_to_vllm_turn(turn: dict): +def _adapt_turn_for_vllm(turn: dict): if isinstance(turn["content"], str): return turn - return turn | {"content": [_hf_to_vllm_part(part) for part in turn["content"]]} + return turn | {"content": [_adapt_part_for_vllm(part) for part in turn["content"]]} -def _hf_to_vllm_conv(normalized_conv: list[dict]): - return [_hf_to_vllm_turn(turn) for turn in normalized_conv] +def _adapt_conv_for_vllm(normalized_conv: list[dict]): + return [_adapt_turn_for_vllm(turn) for turn in normalized_conv] def _supports_assistant_mask(processor: ProcessorLike) -> bool: @@ -501,7 +501,7 @@ def _preprocess_batch( results["seq_len"].append(len(input_ids)) if "_vllm_messages" in results: - results["_vllm_messages"].append(_hf_to_vllm_conv(normalized_conv)) + results["_vllm_messages"].append(_adapt_conv_for_vllm(normalized_conv)) return results diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index 521725971..c1e677698 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -11,9 +11,9 @@ from speculators.data_generation.preprocessing import ( _adapt_conv_for_hf, + _adapt_conv_for_vllm, _create_loss_mask_from_offsets, _detect_assistant_pattern, - _hf_to_vllm_conv, _load_processor, _normalize_conversation, _preprocess_batch, @@ -86,8 +86,8 @@ def test_normalize_conversation_unknown_role(): @pytest.mark.sanity def test_adapt_conv_for_hf_text_only_processor(): """ - Test converting from normalized conversation to HF-format with - a text-only processor (i.e. tokenizer). + Test converting from normalized conversation to HF format + with a text-only processor (i.e. tokenizer). 
""" processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) @@ -104,8 +104,8 @@ def test_adapt_conv_for_hf_text_only_processor(): @pytest.mark.sanity def test_adapt_conv_for_hf_multimodal_processor(): """ - Test converting from normalized conversation to HF-format with - a multi-modal processor. + Test converting from normalized conversation to HF format + with a multi-modal processor. """ processor = _load_processor(MM_MODEL_REPO, trust_remote_code=True) @@ -144,12 +144,12 @@ def test_adapt_conv_for_hf_multimodal_processor(): ] -# Tests for _hf_to_vllm_conv +# Tests for _adapt_conv_for_vllm @pytest.mark.sanity -def test_hf_to_vllm_all_content_formats(): +def test_adapt_conv_for_vllm_all_content_formats(): """ - Test converting from HF-format to vLLM-format messages with - each supported content format. + Test converting from normalized conversation to vLLM format + with each supported content format. """ conv: list[dict] = [ { @@ -166,7 +166,7 @@ def test_hf_to_vllm_all_content_formats(): ], }, ] - result = _hf_to_vllm_conv(conv) + result = _adapt_conv_for_vllm(conv) assert len(result) == 2 @@ -189,13 +189,13 @@ def test_hf_to_vllm_all_content_formats(): @pytest.mark.sanity -def test_hf_to_vllm_invalid_content_formats(): +def test_adapt_conv_for_vllm_invalid_content_formats(): """ - Test converting from HF-format to vLLM-format messages with - unsupported content formats. + Test converting from normalized conversation to vLLM format + with unsupported content formats. 
""" with pytest.raises(ValueError, match=r"'image':.* is not supported"): - _hf_to_vllm_conv( + _adapt_conv_for_vllm( [ { "role": "assistant", @@ -207,7 +207,7 @@ def test_hf_to_vllm_invalid_content_formats(): ) with pytest.raises(ValueError, match=r"'base64':.* is not supported"): - _hf_to_vllm_conv( + _adapt_conv_for_vllm( [ { "role": "assistant", From de2e38a090e297e3b1f1e10fd92227da3e1b36c6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 06:38:35 +0000 Subject: [PATCH 34/95] Reword Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 940c151c9..391f86128 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -165,14 +165,14 @@ def _adapt_part_for_vllm(part: str | dict): raise ValueError( f"Content part {expr} is not supported. " f"To avoid copying the image when saving the preprocessed dataset, " - f"please express {modality} inputs using file URLs." + f"please express {modality} inputs using file paths or URLs." ) if part.get(modality): expr = {"type": modality, modality: "..."} raise ValueError( f"Content part {expr} is not supported. " f"To avoid copying the image when saving the preprocessed dataset, " - f"please express {modality} inputs using file URLs." + f"please express {modality} inputs using file paths or URLs." ) expr = {"type": modality} | {k: "..." 
for k in part if k != "type"} From 95c6ec6f81e16d7c44c78a4359a0e8a20ba0ecd3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 08:11:11 +0000 Subject: [PATCH 35/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 391f86128..c36d48e68 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -163,15 +163,15 @@ def _adapt_part_for_vllm(part: str | dict): if part.get("base64"): expr = {"type": modality, "base64": "..."} raise ValueError( - f"Content part {expr} is not supported. " - f"To avoid copying the image when saving the preprocessed dataset, " + f"Content part {expr} is not supported. To avoid copying " + f"the {modality} when saving the preprocessed dataset, " f"please express {modality} inputs using file paths or URLs." ) if part.get(modality): expr = {"type": modality, modality: "..."} raise ValueError( - f"Content part {expr} is not supported. " - f"To avoid copying the image when saving the preprocessed dataset, " + f"Content part {expr} is not supported. The avoid copying " + f"the {modality} when saving the preprocessed dataset, " f"please express {modality} inputs using file paths or URLs." 
) From f949b7c7526407a5d4bc0dc753ef915828411bbb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 08:11:44 +0000 Subject: [PATCH 36/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index c36d48e68..79ad3cf02 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -170,7 +170,7 @@ def _adapt_part_for_vllm(part: str | dict): if part.get(modality): expr = {"type": modality, modality: "..."} raise ValueError( - f"Content part {expr} is not supported. The avoid copying " + f"Content part {expr} is not supported. To avoid copying " f"the {modality} when saving the preprocessed dataset, " f"please express {modality} inputs using file paths or URLs." ) From 8d0c01d78f84900bf582c8d86282df0d147162fd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 08:22:42 +0000 Subject: [PATCH 37/95] Fix vllm messages not actually being passed Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 3da17a126..4df8b852b 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -327,12 +327,19 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): for i in to_process: if cancel_event.is_set(): break + item = dataset[i] + + partial_item = {"idx": i} + for k in ("input_ids", "_vllm_messages"): + if k in item: + partial_item[k] = item[k] + # Check cancel_event while waiting for queue space to avoid # deadlocking when all workers have died. 
while not cancel_event.is_set(): try: - queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) + queue.put_nowait(partial_item) break except asyncio.QueueFull: await asyncio.sleep(0.1) From 20b365f448bddff121a2f193aa9060e682c48782 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 08:54:41 +0000 Subject: [PATCH 38/95] Fix whitespacec Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 4df8b852b..4397b75da 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -66,8 +66,8 @@ def parse_args(): type=str, default=None, help=( - "HuggingFace model ID or local path for target model (default auto select)." - "For verification purposes only." + "HuggingFace model ID or local path for target model " + "(default auto select). For verification purposes only." ), ) parser.add_argument( @@ -113,7 +113,7 @@ def parse_args(): type=int, default=32, help=( - "Number of active vLLM requests at a time." + "Number of active vLLM requests at a time. 
" "Note: number of async workers set to 2*concurrency" ), ) @@ -121,8 +121,8 @@ def parse_args(): "--validate-outputs", action="store_true", help=( - "Load generated safetensor files and check output token ids match prompt" - " tokens and hidden states seq_len matches num tokens" + "Load generated safetensor files and check output token ids match " + "prompt tokens and hidden states seq_len matches num tokens" ), ) parser.add_argument( From a3d7a3528294e85d1dbd42383b96d37cef274c26 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:08:51 +0000 Subject: [PATCH 39/95] Cleanup Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 13 +++++-------- src/speculators/train/data.py | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 4397b75da..ac4fba460 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -31,6 +31,7 @@ DEFAULT_REQUEST_TIMEOUT, generate_hidden_states_async, ) +from speculators.train.data import ArrowDataset from speculators.train.logger import setup_root_logger logger = logging.getLogger(__name__) @@ -277,7 +278,7 @@ async def worker( continue input_ids = item["input_ids"].tolist() - messages = item.get("_vllm_messages") + messages = item.get("messages") target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" @@ -328,18 +329,14 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): if cancel_event.is_set(): break - item = dataset[i] - - partial_item = {"idx": i} - for k in ("input_ids", "_vllm_messages"): - if k in item: - partial_item[k] = item[k] + dataset_item = dataset[i] + openai_item = ArrowDataset.convert_to_openai(dataset_item) | {"idx": i} # Check cancel_event while waiting for queue space to avoid # deadlocking when all workers have died. 
while not cancel_event.is_set(): try: - queue.put_nowait(partial_item) + queue.put_nowait(openai_item) break except asyncio.QueueFull: await asyncio.sleep(0.1) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 77d30c49a..b886e4f52 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -168,6 +168,19 @@ def __getitem__(self, index) -> BatchType | None: class ArrowDataset(BaseDataset): + DATASET_TO_OPENAI_FIELDS = { + "input_ids": "input_ids", + "_vllm_messages": "messages", + } + + @classmethod + def convert_to_openai(cls, dataset_item: dict): + return { + openai_field: dataset_item[dataset_field] + for dataset_field, openai_field in cls.DATASET_TO_OPENAI_FIELDS.items() + if dataset_field in dataset_item + } + def __init__( self, max_len: int, @@ -258,9 +271,10 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: if not self.client: self._setup_client() - item = self.data[index] - input_ids = item["input_ids"].tolist() - messages = item.get("_vllm_messages") + dataset_item = self.data[index] + openai_item = self.convert_to_openai(dataset_item) + input_ids = openai_item["input_ids"].tolist() + messages = openai_item.get("messages") try: hs_filepath = generate_hidden_states( From b16381761f93cf17c623732ff478118008df14bb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:11:53 +0000 Subject: [PATCH 40/95] Simplify Signed-off-by: DarkLight1337 --- src/speculators/train/data.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b886e4f52..482e438db 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -168,18 +168,16 @@ def __getitem__(self, index) -> BatchType | None: class ArrowDataset(BaseDataset): - DATASET_TO_OPENAI_FIELDS = { - "input_ids": "input_ids", - "_vllm_messages": "messages", - } - @classmethod def convert_to_openai(cls, 
dataset_item: dict): - return { - openai_field: dataset_item[dataset_field] - for dataset_field, openai_field in cls.DATASET_TO_OPENAI_FIELDS.items() - if dataset_field in dataset_item - } + openai_item = {} + + openai_item["input_ids"] = dataset_item["input_ids"].tolist() + + if "_vllm_messages" in dataset_item: + openai_item["messages"] = dataset_item["_vllm_messages"] + + return openai_item def __init__( self, From 347471ef1014047b9a03eaed75917d2d64683aec Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:14:16 +0000 Subject: [PATCH 41/95] Fix Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 10 ++++------ src/speculators/data_generation/vllm_client.py | 12 ++++++++---- src/speculators/train/data.py | 5 +---- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index ac4fba460..13d17a5ee 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -277,9 +277,6 @@ async def worker( queue.task_done() continue - input_ids = item["input_ids"].tolist() - messages = item.get("messages") - target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" try: @@ -287,8 +284,7 @@ async def worker( hidden_states_path = await generate_hidden_states_async( client, model, - input_ids, - messages=messages, + item, timeout=request_timeout, max_retries=max_retries, ) @@ -298,7 +294,9 @@ async def worker( ) if validate_outputs: await asyncio.to_thread( - check_safetensors_file, target_hidden_states_path, input_ids + check_safetensors_file, + target_hidden_states_path, + item["input_ids"], ) except Exception as e: if fail_on_error: diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 5cb17dfb1..f1e517d5e 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -112,9 +112,8 @@ def extract_output( 
async def generate_hidden_states_async( client: openai.AsyncClient, model: str, - token_ids: list[int], + dataset_item: dict, *, - messages: list[ChatCompletionMessageParam] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -129,6 +128,9 @@ async def generate_hidden_states_async( instead of passing `token_ids` to Completions API. timeout: Timeout in seconds for each request attempt. None for no timeout. """ + token_ids: list[int] = dataset_item["input_ids"] + messages: list[ChatCompletionMessageParam] | None = dataset_item.get("messages") + coro: Coroutine[Any, Any, Completion | ChatCompletion] if messages is None: coro = client.completions.create( @@ -160,15 +162,17 @@ async def generate_hidden_states_async( def generate_hidden_states( client: openai.Client, model: str, - token_ids: list[int], + dataset_item: dict, *, - messages: list[ChatCompletionMessageParam] | None = None, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. 
""" + token_ids: list[int] = dataset_item["input_ids"] + messages: list[ChatCompletionMessageParam] | None = dataset_item.get("messages") + res: Completion | ChatCompletion if messages is None: res = client.completions.create( diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 482e438db..6604a0e10 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -271,15 +271,12 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: dataset_item = self.data[index] openai_item = self.convert_to_openai(dataset_item) - input_ids = openai_item["input_ids"].tolist() - messages = openai_item.get("messages") try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] - input_ids, - messages=messages, + openai_item, timeout=self.request_timeout, max_retries=self.max_retries, ) From 7af08bb78fce04bc6812c80f9e56d50060f1cc80 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:19:09 +0000 Subject: [PATCH 42/95] Simplify Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 4 ++-- .../data_generation/vllm_client.py | 20 ++++++++++++------- src/speculators/train/data.py | 18 ++++++++--------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 13d17a5ee..e9847109c 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -328,13 +328,13 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): break dataset_item = dataset[i] - openai_item = ArrowDataset.convert_to_openai(dataset_item) | {"idx": i} + client_item = ArrowDataset.build_client_item(dataset_item) | {"idx": i} # Check cancel_event while waiting for queue space to avoid # deadlocking when all workers have died. 
while not cancel_event.is_set(): try: - queue.put_nowait(openai_item) + queue.put_nowait(client_item) break except asyncio.QueueFull: await asyncio.sleep(0.1) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index f1e517d5e..a77da0c53 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -2,11 +2,12 @@ import functools import logging import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict import openai from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai.types.completion import Completion +from typing_extensions import NotRequired if TYPE_CHECKING: from collections.abc import Coroutine @@ -108,11 +109,16 @@ def extract_output( return response.kv_transfer_params.get("hidden_states_path") +class ClientItem(TypedDict): + input_ids: list[int] + messages: NotRequired[list[ChatCompletionMessageParam]] + + @with_retries async def generate_hidden_states_async( client: openai.AsyncClient, model: str, - dataset_item: dict, + client_item: ClientItem, *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: @@ -128,8 +134,8 @@ async def generate_hidden_states_async( instead of passing `token_ids` to Completions API. timeout: Timeout in seconds for each request attempt. None for no timeout. 
""" - token_ids: list[int] = dataset_item["input_ids"] - messages: list[ChatCompletionMessageParam] | None = dataset_item.get("messages") + token_ids = client_item["input_ids"] + messages = client_item.get("messages") coro: Coroutine[Any, Any, Completion | ChatCompletion] if messages is None: @@ -162,7 +168,7 @@ async def generate_hidden_states_async( def generate_hidden_states( client: openai.Client, model: str, - dataset_item: dict, + client_item: ClientItem, *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: @@ -170,8 +176,8 @@ def generate_hidden_states( Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. """ - token_ids: list[int] = dataset_item["input_ids"] - messages: list[ChatCompletionMessageParam] | None = dataset_item.get("messages") + token_ids = client_item["input_ids"] + messages = client_item.get("messages") res: Completion | ChatCompletion if messages is None: diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 6604a0e10..67b07e694 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -8,7 +8,7 @@ from collections.abc import Callable from os import PathLike from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import openai import torch @@ -20,6 +20,7 @@ from speculators.data_generation.vllm_client import ( DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT, + ClientItem, generate_hidden_states, ) from speculators.train.noise_transforms import TransformTensors @@ -169,15 +170,14 @@ def __getitem__(self, index) -> BatchType | None: class ArrowDataset(BaseDataset): @classmethod - def convert_to_openai(cls, dataset_item: dict): - openai_item = {} - - openai_item["input_ids"] = dataset_item["input_ids"].tolist() + def build_client_item(cls, dataset_item: dict) -> ClientItem: + out_dict = {} + out_dict["input_ids"] = dataset_item["input_ids"].tolist() if "_vllm_messages" in dataset_item: - 
openai_item["messages"] = dataset_item["_vllm_messages"] + out_dict["messages"] = dataset_item["_vllm_messages"] - return openai_item + return cast("ClientItem", out_dict) def __init__( self, @@ -270,13 +270,13 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: self._setup_client() dataset_item = self.data[index] - openai_item = self.convert_to_openai(dataset_item) + client_item = self.build_client_item(dataset_item) try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] - openai_item, + client_item, timeout=self.request_timeout, max_retries=self.max_retries, ) From 68897d9f0ac9840b1ef1d1cad54b9d33bdd1ad6c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:20:42 +0000 Subject: [PATCH 43/95] Move to global Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 4 ++-- src/speculators/train/data.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index e9847109c..b1962e433 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -31,7 +31,7 @@ DEFAULT_REQUEST_TIMEOUT, generate_hidden_states_async, ) -from speculators.train.data import ArrowDataset +from speculators.train.data import build_client_item from speculators.train.logger import setup_root_logger logger = logging.getLogger(__name__) @@ -328,7 +328,7 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): break dataset_item = dataset[i] - client_item = ArrowDataset.build_client_item(dataset_item) | {"idx": i} + client_item = build_client_item(dataset_item) | {"idx": i} # Check cancel_event while waiting for queue space to avoid # deadlocking when all workers have died. 
diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 67b07e694..3d0994a3e 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -107,6 +107,16 @@ def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]: } +def build_client_item(dataset_item: dict) -> ClientItem: + out_dict = {} + out_dict["input_ids"] = dataset_item["input_ids"].tolist() + + if "_vllm_messages" in dataset_item: + out_dict["messages"] = dataset_item["_vllm_messages"] + + return cast("ClientItem", out_dict) + + class BaseDataset(Dataset): def __init__( self, @@ -169,16 +179,6 @@ def __getitem__(self, index) -> BatchType | None: class ArrowDataset(BaseDataset): - @classmethod - def build_client_item(cls, dataset_item: dict) -> ClientItem: - out_dict = {} - out_dict["input_ids"] = dataset_item["input_ids"].tolist() - - if "_vllm_messages" in dataset_item: - out_dict["messages"] = dataset_item["_vllm_messages"] - - return cast("ClientItem", out_dict) - def __init__( self, max_len: int, @@ -270,7 +270,7 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: self._setup_client() dataset_item = self.data[index] - client_item = self.build_client_item(dataset_item) + client_item = build_client_item(dataset_item) try: hs_filepath = generate_hidden_states( From 99cac6b2af0ef3782127fb971eb5451397b5b5db Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:52:56 +0000 Subject: [PATCH 44/95] Fix response parsing Signed-off-by: DarkLight1337 --- src/speculators/data_generation/vllm_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index a77da0c53..b9ca88c5f 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -93,7 +93,10 @@ def extract_output( response: Completion | ChatCompletion, token_ids: list[int], ) -> str: - 
prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None) + if isinstance(response, Completion): + prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None) + else: + prompt_token_ids = getattr(response, "prompt_token_ids", None) if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") From 9199bbb58c8b5387ab97d6731b889902cabfe4b0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 09:55:31 +0000 Subject: [PATCH 45/95] Fix missing special tokens Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 79ad3cf02..fd3f79d5f 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -387,7 +387,7 @@ def _get_input_ids_loss_mask( "return_offsets_mapping": True, "max_length": max_length, "truncation": True, - "add_special_tokens": False, + "add_special_tokens": isinstance(processor, ProcessorMixin), } if isinstance(processor, ProcessorMixin): From fc66e084d00d5f976293fef4942580455992abd9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 10:01:17 +0000 Subject: [PATCH 46/95] Update Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index fd3f79d5f..79ad3cf02 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -387,7 +387,7 @@ def _get_input_ids_loss_mask( "return_offsets_mapping": True, "max_length": max_length, "truncation": True, - "add_special_tokens": isinstance(processor, ProcessorMixin), + "add_special_tokens": False, } if isinstance(processor, 
ProcessorMixin): From 2a01ac0866990b0ac6e78b6bf1d60b8448b08d90 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 12:46:01 +0000 Subject: [PATCH 47/95] Fix chat Signed-off-by: DarkLight1337 --- src/speculators/data_generation/vllm_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index b9ca88c5f..577085ee3 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -154,7 +154,7 @@ async def generate_hidden_states_async( model=model, messages=messages, max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, timeout=timeout, ) @@ -196,7 +196,7 @@ def generate_hidden_states( model=model, messages=messages, max_tokens=1, - extra_body={"return_token_ids": True}, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, timeout=timeout, ) From df0b72162dcdbbe24190dd2b646e6a92d96c4081 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 12:55:28 +0000 Subject: [PATCH 48/95] Parameterize `run_prepare_data` Signed-off-by: DarkLight1337 --- tests/e2e/smoke/test_offline_training.py | 2 +- tests/e2e/smoke/test_online_training.py | 2 +- tests/e2e/smoke/test_resume_optimizer.py | 2 +- tests/e2e/utils.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 125af8046..ebf688498 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -82,7 +82,7 @@ def run_offline_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, data_path, max_samples, seq_length) + run_prepare_data(model, "sharegpt", data_path, max_samples, seq_length) with launch_vllm_server_context( model, diff --git 
a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 7224784d8..03c47f9c3 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -55,7 +55,7 @@ def run_online_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, data_path, max_samples, seq_length) + run_prepare_data(model, "sharegpt", data_path, max_samples, seq_length) hidden_states_path = str(tmp_path / "hidden_states") with launch_vllm_server_context( diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index c887d89e9..7929bc32e 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -82,7 +82,7 @@ def test_resume_after_checkpoint_best(tmp_path: Path): save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(MODEL, data_path) + run_prepare_data(MODEL, "sharegpt", data_path) # Step 2: Generate hidden states offline with launch_vllm_server_context(MODEL, VLLM_PORT, str(tmp_path / "hidden_states")): diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index b85e896a4..2a154f1f1 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -137,6 +137,7 @@ def launch_vllm_server_context(*args, **kwargs): def run_prepare_data( model: str, + data: str, data_path: Path, max_samples: int = 50, seq_length: int = 512, @@ -149,7 +150,7 @@ def run_prepare_data( "--model", model, "--data", - "sharegpt", + data, "--output", str(data_path), "--max-samples", From f1c9a5c28070649f1b158781f65ae9207c1964d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 13:43:12 +0000 Subject: [PATCH 49/95] Add e2e test Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 1 + tests/e2e/smoke/test_offline_training.py | 24 ++++++++--- tests/e2e/smoke/test_online_training.py | 5 ++- tests/e2e/utils.py | 48 ++++++++++++++++++++-- 4 files changed, 68 insertions(+), 10 
deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 8f57cd26f..20881ab2d 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -59,6 +59,7 @@ def _normalize_sharegpt4v_coco(example: dict) -> dict: state_str = "set to" if os.getenv("COCO_DIR") else "default" raise ValueError( + f"No image found at <{image_path}>. " f"Please download COCO 2017 Train Images from " f" and place the " f"extracted folder under `COCO_DIR` ({state_str}: `{coco_dir}`)." diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index ebf688498..4f81fe72c 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -19,34 +19,47 @@ run_prepare_data, run_training, run_vllm_engine, + setup_dummy_sharegpt4v_coco, ) -MODEL = "Qwen/Qwen3-0.6B" +TEXT_MODEL = "Qwen/Qwen3-0.6B" +MM_MODEL = "Qwen/Qwen3-VL-2B-Instruct" @pytest.mark.e2e @pytest.mark.slow @pytest.mark.parametrize( - ("speculator_type", "extra_train_args", "target_layer_ids"), + ("model", "dataset", "speculator_type", "extra_train_args", "target_layer_ids"), [ - ("eagle3", [], None), # Use default EAGLE layers + (TEXT_MODEL, "sharegpt", "eagle3", [], None), # Use default EAGLE layers ( + TEXT_MODEL, + "sharegpt", "dflash", ["--block-size", "8", "--max-anchors", "256", "--num-layers", "3"], [1, 13, 25], ), # DFlash with 3 layers + verifier last layer + (MM_MODEL, "sharegpt4v_coco", "eagle3", [], None), # Multimodal ], ) def test_offline_smoke( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, + model: str, + dataset: str, prompts: list[list[dict[str, str]]], speculator_type: str, extra_train_args: list[str], target_layer_ids: list[int] | None, ): + if dataset == "sharegpt4v_coco": + monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) + setup_dummy_sharegpt4v_coco(tmp_path / "coco") + run_offline_e2e( tmp_path, - MODEL, + model, + 
dataset=dataset, prompts=prompts, vllm_gpu_util=0.9, speculator_type=speculator_type, @@ -58,6 +71,7 @@ def test_offline_smoke( def run_offline_e2e( tmp_path: Path, model: str, + dataset: str, max_samples: int = 50, seq_length: int = 512, vllm_gpu_util: float = 0.5, @@ -82,7 +96,7 @@ def run_offline_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, "sharegpt", data_path, max_samples, seq_length) + run_prepare_data(model, dataset, data_path, max_samples, seq_length) with launch_vllm_server_context( model, diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 03c47f9c3..4803f5925 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -25,12 +25,13 @@ @pytest.mark.e2e @pytest.mark.slow def test_online_smoke(tmp_path: Path, prompts: list[list[dict[str, str]]]): - run_online_e2e(tmp_path, MODEL, prompts=prompts) + run_online_e2e(tmp_path, MODEL, dataset="sharegpt", prompts=prompts) def run_online_e2e( tmp_path: Path, model: str, + dataset: str, max_samples: int = 50, seq_length: int = 512, vllm_gpu_util: float = 0.5, @@ -55,7 +56,7 @@ def run_online_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, "sharegpt", data_path, max_samples, seq_length) + run_prepare_data(model, dataset, data_path, max_samples, seq_length) hidden_states_path = str(tmp_path / "hidden_states") with launch_vllm_server_context( diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 2a154f1f1..3218025cb 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -11,6 +11,9 @@ from textwrap import indent from loguru import logger +from PIL import Image + +from speculators.data_generation.preprocessing import load_raw_dataset __all__ = [ "SCRIPTS_DIR", @@ -135,15 +138,42 @@ def launch_vllm_server_context(*args, **kwargs): stop_vllm_server(process) +def setup_dummy_sharegpt4v_coco( + coco_dir: Path, + max_samples: int = 50, 
+ seed: int = 0, +): + """Enable ShareGPT4V to be used without downloading the actual COCO dataset.""" + coco_dir.mkdir(parents=True, exist_ok=True) + + # In load_and_process_dataset, we shuffle and then + # select 3 * max_samples from the dataset + # We must ensure all sample filepaths can be loaded successfully + raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") + raw_dataset = raw_dataset.shuffle(seed=seed) + raw_dataset = raw_dataset.select(range(3 * max_samples)) + + dummy_image = Image.new("RGB", (256, 256)) + dummy_image_path = coco_dir / "dummy.png" + dummy_image.save(dummy_image_path) + + # Use symlinks to avoid copying the image + for raw_path in raw_dataset["image"]: + image_path = coco_dir / raw_path.removeprefix("coco/") + image_path.parent.mkdir(parents=True, exist_ok=True) + image_path.symlink_to(dummy_image_path) + + def run_prepare_data( model: str, data: str, data_path: Path, max_samples: int = 50, seq_length: int = 512, + seed: int = 0, timeout: float | None = None, ): - """Tokenize ShareGPT data using prepare_data.py.""" + """Tokenize data using prepare_data.py.""" cmd = [ sys.executable, str(SCRIPTS_DIR / "prepare_data.py"), @@ -155,6 +185,8 @@ def run_prepare_data( str(data_path), "--max-samples", str(max_samples), + "--seed", + str(seed), "--seq-length", str(seq_length), ] @@ -197,7 +229,12 @@ def run_data_generation_offline( logger.info("Generating hidden states offline: {}", " ".join(datagen_cmd)) result = subprocess.run( # noqa: S603 - datagen_cmd, stderr=subprocess.PIPE, text=True, check=False, timeout=timeout + datagen_cmd, + stderr=subprocess.PIPE, + text=True, + check=False, + timeout=timeout, + env=os.environ.copy(), ) assert result.returncode == 0, ( f"data_generation_offline.py failed:\n{result.stderr}" @@ -269,7 +306,12 @@ def run_training( logger.info("Running training: {}", " ".join(train_cmd)) result = subprocess.run( # noqa: S603 - train_cmd, stderr=subprocess.PIPE, text=True, check=False, timeout=timeout + 
train_cmd, + stderr=subprocess.PIPE, + text=True, + check=False, + timeout=timeout, + env=os.environ.copy(), ) assert result.returncode == 0, f"train.py failed:\n{result.stderr}" From 46e6ab361b42d9c7c835817e6f9c1c6f717ccd7a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 13:51:15 +0000 Subject: [PATCH 50/95] Enforce eager to reduce startup time Signed-off-by: DarkLight1337 --- .../test_eagle3_offline_acceptance.py | 23 +++++++++++-- .../test_eagle3_online_acceptance.py | 23 +++++++++++-- tests/e2e/smoke/test_offline_training.py | 3 ++ tests/e2e/smoke/test_online_training.py | 33 +++++++++++++++++-- tests/e2e/utils.py | 3 ++ 5 files changed, 76 insertions(+), 9 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 598e35007..0e923c0ab 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -11,20 +11,37 @@ from pathlib import Path +import pytest + from tests.e2e.smoke.test_offline_training import run_offline_e2e from tests.utils import requires_cadence @requires_cadence("nightly") -def test_offline_regression(tmp_path: Path, prompts): +@pytest.mark.parametrize( + ("model", "dataset", "acceptance_thresholds"), + [ + ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.07, 0.007]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.07, 0.007]), + ], +) +def test_offline_regression( + tmp_path: Path, + model: str, + dataset: str, + acceptance_thresholds: list[float], + prompts: list[list[dict[str, str]]], +): run_offline_e2e( tmp_path, - "Qwen/Qwen3-8B", + model, + dataset, max_samples=5000, seq_length=8192, vllm_gpu_util=0.9, + vllm_enforce_eager=dataset == "sharegpt4v_coco", epochs=3, prompts=prompts, - acceptance_thresholds=[0.4, 0.07, 0.007], + acceptance_thresholds=acceptance_thresholds, log_freq=50, ) diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py 
b/tests/e2e/regression/test_eagle3_online_acceptance.py index 8f2faecff..62de08e6d 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -10,21 +10,38 @@ from pathlib import Path +import pytest + from tests.e2e.smoke.test_online_training import run_online_e2e from tests.utils import requires_cadence @requires_cadence("nightly") -def test_online_regression(tmp_path: Path, prompts): +@pytest.mark.parametrize( + ("model", "dataset", "acceptance_thresholds"), + [ + ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.1, 0.01]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.1, 0.01]), + ], +) +def test_online_regression( + tmp_path: Path, + model: str, + dataset: str, + acceptance_thresholds: list[float], + prompts: list[list[dict[str, str]]], +): run_online_e2e( tmp_path, - "Qwen/Qwen3-8B", + model, + dataset, max_samples=5000, seq_length=8192, vllm_gpu_util=0.75, + vllm_enforce_eager=dataset == "sharegpt4v_coco", epochs=3, prompts=prompts, - acceptance_thresholds=[0.4, 0.1, 0.01], + acceptance_thresholds=acceptance_thresholds, log_freq=50, train_timeout=45 * 60, # 45 mins ) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 4f81fe72c..aafc1b151 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -62,6 +62,7 @@ def test_offline_smoke( dataset=dataset, prompts=prompts, vllm_gpu_util=0.9, + vllm_enforce_eager=dataset == "sharegpt4v_coco", speculator_type=speculator_type, extra_train_args=extra_train_args, target_layer_ids=target_layer_ids, @@ -75,6 +76,7 @@ def run_offline_e2e( max_samples: int = 50, seq_length: int = 512, vllm_gpu_util: float = 0.5, + vllm_enforce_eager: bool = False, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -105,6 +107,7 @@ def run_offline_e2e( max_model_len=seq_length + 1, gpu_memory_utilization=vllm_gpu_util, target_layer_ids=target_layer_ids, + 
enforce_eager=vllm_enforce_eager, ): # Step 2: Generate hidden states offline run_data_generation_offline( diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 4803f5925..2a0229cd5 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -17,15 +17,40 @@ run_prepare_data, run_training, run_vllm_engine, + setup_dummy_sharegpt4v_coco, ) -MODEL = "Qwen/Qwen3-0.6B" +TEXT_MODEL = "Qwen/Qwen3-0.6B" +MM_MODEL = "Qwen/Qwen3-VL-2B-Instruct" @pytest.mark.e2e @pytest.mark.slow -def test_online_smoke(tmp_path: Path, prompts: list[list[dict[str, str]]]): - run_online_e2e(tmp_path, MODEL, dataset="sharegpt", prompts=prompts) +@pytest.mark.parametrize( + ("model", "dataset"), + [ + (TEXT_MODEL, "sharegpt"), + (MM_MODEL, "sharegpt4v_coco"), + ], +) +def test_online_smoke( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + model: str, + dataset: str, + prompts: list[list[dict[str, str]]], +): + if dataset == "sharegpt4v_coco": + monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) + setup_dummy_sharegpt4v_coco(tmp_path / "coco") + + run_online_e2e( + tmp_path, + model, + dataset=dataset, + prompts=prompts, + vllm_enforce_eager=dataset == "sharegpt4v_coco", + ) def run_online_e2e( @@ -35,6 +60,7 @@ def run_online_e2e( max_samples: int = 50, seq_length: int = 512, vllm_gpu_util: float = 0.5, + vllm_enforce_eager: bool = False, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -65,6 +91,7 @@ def run_online_e2e( hidden_states_path, max_model_len=seq_length + 1, gpu_memory_utilization=vllm_gpu_util, + enforce_eager=vllm_enforce_eager, ): # Step 2: Train against live vLLM server run_training( diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 3218025cb..5769802fd 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -71,6 +71,7 @@ def launch_vllm_server( max_model_len: int = 513, gpu_memory_utilization: float = 0.5, target_layer_ids: list[int] | None = None, 
+ enforce_eager: bool = False, ) -> subprocess.Popen: """Launch a vLLM server configured for hidden-state extraction. @@ -86,6 +87,8 @@ def launch_vllm_server( ] if target_layer_ids is not None: cmd += ["--target-layer-ids"] + [str(lid) for lid in target_layer_ids] + if enforce_eager: + cmd += ["--enforce-eager"] cmd += [ "--", "--port", From 4a53d6035dbd9f55a9c1de2437d42eb48d1563bc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 13:59:23 +0000 Subject: [PATCH 51/95] Update Signed-off-by: DarkLight1337 --- .../regression/test_eagle3_offline_acceptance.py | 15 ++++++++++++++- .../regression/test_eagle3_online_acceptance.py | 15 ++++++++++++++- tests/e2e/smoke/test_offline_training.py | 11 ++++++++++- tests/e2e/smoke/test_online_training.py | 11 ++++++++++- tests/e2e/utils.py | 4 ++++ 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 0e923c0ab..5a77a2027 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -14,6 +14,7 @@ import pytest from tests.e2e.smoke.test_offline_training import run_offline_e2e +from tests.e2e.utils import setup_dummy_sharegpt4v_coco from tests.utils import requires_cadence @@ -26,12 +27,23 @@ ], ) def test_offline_regression( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, model: str, dataset: str, acceptance_thresholds: list[float], prompts: list[list[dict[str, str]]], ): + if dataset == "sharegpt4v_coco": + monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) + setup_dummy_sharegpt4v_coco(tmp_path / "coco") + + vllm_enforce_eager = True + vllm_media_path = str(tmp_path / "coco") + else: + vllm_enforce_eager = False + vllm_media_path = None + run_offline_e2e( tmp_path, model, @@ -39,7 +51,8 @@ def test_offline_regression( max_samples=5000, seq_length=8192, vllm_gpu_util=0.9, - vllm_enforce_eager=dataset == 
"sharegpt4v_coco", + vllm_enforce_eager=vllm_enforce_eager, + vllm_media_path=vllm_media_path, epochs=3, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 62de08e6d..f34093bef 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -13,6 +13,7 @@ import pytest from tests.e2e.smoke.test_online_training import run_online_e2e +from tests.e2e.utils import setup_dummy_sharegpt4v_coco from tests.utils import requires_cadence @@ -25,12 +26,23 @@ ], ) def test_online_regression( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, model: str, dataset: str, acceptance_thresholds: list[float], prompts: list[list[dict[str, str]]], ): + if dataset == "sharegpt4v_coco": + monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) + setup_dummy_sharegpt4v_coco(tmp_path / "coco") + + vllm_enforce_eager = True + vllm_media_path = str(tmp_path / "coco") + else: + vllm_enforce_eager = False + vllm_media_path = None + run_online_e2e( tmp_path, model, @@ -38,7 +50,8 @@ def test_online_regression( max_samples=5000, seq_length=8192, vllm_gpu_util=0.75, - vllm_enforce_eager=dataset == "sharegpt4v_coco", + vllm_enforce_eager=vllm_enforce_eager, + vllm_media_path=vllm_media_path, epochs=3, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index aafc1b151..5fb564d89 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -56,13 +56,20 @@ def test_offline_smoke( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") + vllm_enforce_eager = True + vllm_media_path = str(tmp_path / "coco") + else: + vllm_enforce_eager = False + vllm_media_path = None + run_offline_e2e( tmp_path, model, 
dataset=dataset, prompts=prompts, vllm_gpu_util=0.9, - vllm_enforce_eager=dataset == "sharegpt4v_coco", + vllm_enforce_eager=vllm_enforce_eager, + vllm_media_path=vllm_media_path, speculator_type=speculator_type, extra_train_args=extra_train_args, target_layer_ids=target_layer_ids, @@ -77,6 +84,7 @@ def run_offline_e2e( seq_length: int = 512, vllm_gpu_util: float = 0.5, vllm_enforce_eager: bool = False, + vllm_media_path: str | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -108,6 +116,7 @@ def run_offline_e2e( gpu_memory_utilization=vllm_gpu_util, target_layer_ids=target_layer_ids, enforce_eager=vllm_enforce_eager, + allowed_local_media_path=vllm_media_path, ): # Step 2: Generate hidden states offline run_data_generation_offline( diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 2a0229cd5..4f5c0ba94 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -44,12 +44,19 @@ def test_online_smoke( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") + vllm_enforce_eager = True + vllm_media_path = str(tmp_path / "coco") + else: + vllm_enforce_eager = False + vllm_media_path = None + run_online_e2e( tmp_path, model, dataset=dataset, prompts=prompts, - vllm_enforce_eager=dataset == "sharegpt4v_coco", + vllm_enforce_eager=vllm_enforce_eager, + vllm_media_path=vllm_media_path, ) @@ -61,6 +68,7 @@ def run_online_e2e( seq_length: int = 512, vllm_gpu_util: float = 0.5, vllm_enforce_eager: bool = False, + vllm_media_path: str | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -92,6 +100,7 @@ def run_online_e2e( max_model_len=seq_length + 1, gpu_memory_utilization=vllm_gpu_util, enforce_eager=vllm_enforce_eager, + allowed_local_media_path=vllm_media_path, ): # Step 2: Train against live vLLM server run_training( diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 
5769802fd..308b3f56e 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -68,10 +68,12 @@ def launch_vllm_server( model: str, port: int, hidden_states_path: str, + *, max_model_len: int = 513, gpu_memory_utilization: float = 0.5, target_layer_ids: list[int] | None = None, enforce_eager: bool = False, + allowed_local_media_path: str | None = None, ) -> subprocess.Popen: """Launch a vLLM server configured for hidden-state extraction. @@ -89,6 +91,8 @@ def launch_vllm_server( cmd += ["--target-layer-ids"] + [str(lid) for lid in target_layer_ids] if enforce_eager: cmd += ["--enforce-eager"] + if allowed_local_media_path: + cmd += ["--allowed-local-media-path", allowed_local_media_path] cmd += [ "--", "--port", From bf2a5c752aee46aeffca9e45352134d593fc6078 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:04:09 +0000 Subject: [PATCH 52/95] Avoid argument bloat Signed-off-by: DarkLight1337 --- .../test_eagle3_offline_acceptance.py | 8 +++++--- .../regression/test_eagle3_online_acceptance.py | 8 +++++--- tests/e2e/smoke/test_offline_training.py | 17 ++++++++--------- tests/e2e/smoke/test_online_training.py | 15 +++++++-------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 5a77a2027..775be593c 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -50,9 +50,11 @@ def test_offline_regression( dataset, max_samples=5000, seq_length=8192, - vllm_gpu_util=0.9, - vllm_enforce_eager=vllm_enforce_eager, - vllm_media_path=vllm_media_path, + vllm_kwargs={ + "gpu_memory_utilization": 0.9, + "enforce_eager": vllm_enforce_eager, + "allowed_local_media_path": vllm_media_path, + }, epochs=3, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py 
b/tests/e2e/regression/test_eagle3_online_acceptance.py index f34093bef..a90366765 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -49,9 +49,11 @@ def test_online_regression( dataset, max_samples=5000, seq_length=8192, - vllm_gpu_util=0.75, - vllm_enforce_eager=vllm_enforce_eager, - vllm_media_path=vllm_media_path, + vllm_kwargs={ + "gpu_memory_utilization": 0.75, + "enforce_eager": vllm_enforce_eager, + "allowed_local_media_path": vllm_media_path, + }, epochs=3, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 5fb564d89..25722465d 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -10,6 +10,7 @@ """ from pathlib import Path +from typing import Any import pytest @@ -67,9 +68,11 @@ def test_offline_smoke( model, dataset=dataset, prompts=prompts, - vllm_gpu_util=0.9, - vllm_enforce_eager=vllm_enforce_eager, - vllm_media_path=vllm_media_path, + vllm_kwargs={ + "gpu_memory_utilization": 0.9, + "enforce_eager": vllm_enforce_eager, + "allowed_local_media_path": vllm_media_path, + }, speculator_type=speculator_type, extra_train_args=extra_train_args, target_layer_ids=target_layer_ids, @@ -82,9 +85,7 @@ def run_offline_e2e( dataset: str, max_samples: int = 50, seq_length: int = 512, - vllm_gpu_util: float = 0.5, - vllm_enforce_eager: bool = False, - vllm_media_path: str | None = None, + vllm_kwargs: dict[str, Any] | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -113,10 +114,8 @@ def run_offline_e2e( port, str(tmp_path / "vllm_hidden_states"), max_model_len=seq_length + 1, - gpu_memory_utilization=vllm_gpu_util, target_layer_ids=target_layer_ids, - enforce_eager=vllm_enforce_eager, - allowed_local_media_path=vllm_media_path, + **(vllm_kwargs or {}), ): # Step 2: Generate hidden states offline 
run_data_generation_offline( diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 4f5c0ba94..800691bfd 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -9,6 +9,7 @@ """ from pathlib import Path +from typing import Any import pytest @@ -55,8 +56,10 @@ def test_online_smoke( model, dataset=dataset, prompts=prompts, - vllm_enforce_eager=vllm_enforce_eager, - vllm_media_path=vllm_media_path, + vllm_kwargs={ + "enforce_eager": vllm_enforce_eager, + "allowed_local_media_path": vllm_media_path, + }, ) @@ -66,9 +69,7 @@ def run_online_e2e( dataset: str, max_samples: int = 50, seq_length: int = 512, - vllm_gpu_util: float = 0.5, - vllm_enforce_eager: bool = False, - vllm_media_path: str | None = None, + vllm_kwargs: dict[str, Any] | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -98,9 +99,7 @@ def run_online_e2e( port, hidden_states_path, max_model_len=seq_length + 1, - gpu_memory_utilization=vllm_gpu_util, - enforce_eager=vllm_enforce_eager, - allowed_local_media_path=vllm_media_path, + **(vllm_kwargs or {}), ): # Step 2: Train against live vLLM server run_training( From 8a4ac6dbead051c5046528eebaa117089472992d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:14:07 +0000 Subject: [PATCH 53/95] Use enforce eager by default Signed-off-by: DarkLight1337 --- tests/e2e/regression/test_eagle3_offline_acceptance.py | 3 --- tests/e2e/regression/test_eagle3_online_acceptance.py | 3 --- tests/e2e/smoke/test_offline_training.py | 3 --- tests/e2e/smoke/test_online_training.py | 3 --- tests/e2e/utils.py | 2 +- 5 files changed, 1 insertion(+), 13 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 775be593c..558d6bc0b 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ 
b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -38,10 +38,8 @@ def test_offline_regression( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") - vllm_enforce_eager = True vllm_media_path = str(tmp_path / "coco") else: - vllm_enforce_eager = False vllm_media_path = None run_offline_e2e( @@ -52,7 +50,6 @@ def test_offline_regression( seq_length=8192, vllm_kwargs={ "gpu_memory_utilization": 0.9, - "enforce_eager": vllm_enforce_eager, "allowed_local_media_path": vllm_media_path, }, epochs=3, diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index a90366765..1673b7c06 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -37,10 +37,8 @@ def test_online_regression( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") - vllm_enforce_eager = True vllm_media_path = str(tmp_path / "coco") else: - vllm_enforce_eager = False vllm_media_path = None run_online_e2e( @@ -51,7 +49,6 @@ def test_online_regression( seq_length=8192, vllm_kwargs={ "gpu_memory_utilization": 0.75, - "enforce_eager": vllm_enforce_eager, "allowed_local_media_path": vllm_media_path, }, epochs=3, diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 25722465d..08d5a1f0b 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -57,10 +57,8 @@ def test_offline_smoke( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") - vllm_enforce_eager = True vllm_media_path = str(tmp_path / "coco") else: - vllm_enforce_eager = False vllm_media_path = None run_offline_e2e( @@ -70,7 +68,6 @@ def test_offline_smoke( prompts=prompts, vllm_kwargs={ "gpu_memory_utilization": 0.9, - "enforce_eager": vllm_enforce_eager, 
"allowed_local_media_path": vllm_media_path, }, speculator_type=speculator_type, diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 800691bfd..b2999c052 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -45,10 +45,8 @@ def test_online_smoke( monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) setup_dummy_sharegpt4v_coco(tmp_path / "coco") - vllm_enforce_eager = True vllm_media_path = str(tmp_path / "coco") else: - vllm_enforce_eager = False vllm_media_path = None run_online_e2e( @@ -57,7 +55,6 @@ def test_online_smoke( dataset=dataset, prompts=prompts, vllm_kwargs={ - "enforce_eager": vllm_enforce_eager, "allowed_local_media_path": vllm_media_path, }, ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 308b3f56e..60c61c7c7 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -72,7 +72,7 @@ def launch_vllm_server( max_model_len: int = 513, gpu_memory_utilization: float = 0.5, target_layer_ids: list[int] | None = None, - enforce_eager: bool = False, + enforce_eager: bool = True, allowed_local_media_path: str | None = None, ) -> subprocess.Popen: """Launch a vLLM server configured for hidden-state extraction. 
From 999be1f4630149e99a722a5d02d7b63e340c260b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:17:24 +0000 Subject: [PATCH 54/95] Typo Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 60c61c7c7..5251fc911 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -153,7 +153,7 @@ def setup_dummy_sharegpt4v_coco( """Enable ShareGPT4V to be used without downloading the actual COCO dataset.""" coco_dir.mkdir(parents=True, exist_ok=True) - # In load_and_process_dataset, we shuffle and then + # In load_and_preprocess_dataset, we shuffle and then # select 3 * max_samples from the dataset # We must ensure all sample filepaths can be loaded successfully raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") From 204687da83286d93aa480c749745a1937040a390 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:18:29 +0000 Subject: [PATCH 55/95] Up Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 5251fc911..897523cd5 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -158,7 +158,9 @@ def setup_dummy_sharegpt4v_coco( # We must ensure all sample filepaths can be loaded successfully raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") raw_dataset = raw_dataset.shuffle(seed=seed) - raw_dataset = raw_dataset.select(range(3 * max_samples)) + + if max_samples is not None and len(raw_dataset) > 3 * max_samples + raw_dataset = raw_dataset.select(range(3 * max_samples)) dummy_image = Image.new("RGB", (256, 256)) dummy_image_path = coco_dir / "dummy.png" From c31cb459c5463825d0ba032bf9feb577d420b71c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:18:48 +0000 Subject: [PATCH 56/95] Typo Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 897523cd5..a119186b4 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -159,7 +159,7 @@ def setup_dummy_sharegpt4v_coco( raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") raw_dataset = raw_dataset.shuffle(seed=seed) - if max_samples is not None and len(raw_dataset) > 3 * max_samples + if max_samples is not None and len(raw_dataset) > 3 * max_samples: raw_dataset = raw_dataset.select(range(3 * max_samples)) dummy_image = Image.new("RGB", (256, 256)) From 4a037d6cbc281114ad61c9f0699cf606857a7ba3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 14:59:20 +0000 Subject: [PATCH 57/95] Don't rely on sampling Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index a119186b4..ccdedd084 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -145,27 +145,16 @@ def launch_vllm_server_context(*args, **kwargs): stop_vllm_server(process) -def setup_dummy_sharegpt4v_coco( - coco_dir: Path, - max_samples: int = 50, - seed: int = 0, -): +def setup_dummy_sharegpt4v_coco(coco_dir: Path): """Enable ShareGPT4V to be used without downloading the actual COCO dataset.""" coco_dir.mkdir(parents=True, exist_ok=True) - # In load_and_preprocess_dataset, we shuffle and then - # select 3 * max_samples from the dataset - # We must ensure all sample filepaths can be loaded successfully - raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") - raw_dataset = raw_dataset.shuffle(seed=seed) - - if max_samples is not None and len(raw_dataset) > 3 * max_samples: - raw_dataset = raw_dataset.select(range(3 * max_samples)) - dummy_image = Image.new("RGB", (256, 256)) dummy_image_path = coco_dir / "dummy.png" dummy_image.save(dummy_image_path) + raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") + # Use symlinks to avoid copying 
the image for raw_path in raw_dataset["image"]: image_path = coco_dir / raw_path.removeprefix("coco/") From a1fa7b13c9c18106d111d37cadca1a40dab1a665 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 15:02:19 +0000 Subject: [PATCH 58/95] Reduce diff Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index ccdedd084..453038f28 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -227,12 +227,7 @@ def run_data_generation_offline( logger.info("Generating hidden states offline: {}", " ".join(datagen_cmd)) result = subprocess.run( # noqa: S603 - datagen_cmd, - stderr=subprocess.PIPE, - text=True, - check=False, - timeout=timeout, - env=os.environ.copy(), + datagen_cmd, stderr=subprocess.PIPE, text=True, check=False, timeout=timeout ) assert result.returncode == 0, ( f"data_generation_offline.py failed:\n{result.stderr}" @@ -304,12 +299,7 @@ def run_training( logger.info("Running training: {}", " ".join(train_cmd)) result = subprocess.run( # noqa: S603 - train_cmd, - stderr=subprocess.PIPE, - text=True, - check=False, - timeout=timeout, - env=os.environ.copy(), + train_cmd, stderr=subprocess.PIPE, text=True, check=False, timeout=timeout ) assert result.returncode == 0, f"train.py failed:\n{result.stderr}" From d9c5ca48c4bf031bc3b23ae1b56d25ae27a343b7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 15:06:57 +0000 Subject: [PATCH 59/95] Fix symlink Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 453038f28..df9cf41df 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -158,8 +158,10 @@ def setup_dummy_sharegpt4v_coco(coco_dir: Path): # Use symlinks to avoid copying the image for raw_path in raw_dataset["image"]: image_path = coco_dir / raw_path.removeprefix("coco/") - 
image_path.parent.mkdir(parents=True, exist_ok=True) - image_path.symlink_to(dummy_image_path) + + if not image_path.exists(): + image_path.parent.mkdir(parents=True, exist_ok=True) + image_path.symlink_to(dummy_image_path) def run_prepare_data( From 09a9bb7c6ab4b537b1c80bf9b469327b6bbc87ce Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 15:10:55 +0000 Subject: [PATCH 60/95] Use the real dataset for nightly Signed-off-by: DarkLight1337 --- tests/e2e/regression/test_eagle3_offline_acceptance.py | 10 ++++++---- tests/e2e/regression/test_eagle3_online_acceptance.py | 10 ++++++---- tests/e2e/smoke/test_offline_training.py | 8 +++++--- tests/e2e/smoke/test_online_training.py | 8 +++++--- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 558d6bc0b..8aa4250c0 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -13,8 +13,8 @@ import pytest +from speculators.data_generation.configs import get_coco_dir from tests.e2e.smoke.test_offline_training import run_offline_e2e -from tests.e2e.utils import setup_dummy_sharegpt4v_coco from tests.utils import requires_cadence @@ -35,10 +35,12 @@ def test_offline_regression( prompts: list[list[dict[str, str]]], ): if dataset == "sharegpt4v_coco": - monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) - setup_dummy_sharegpt4v_coco(tmp_path / "coco") + coco_dir = get_coco_dir() - vllm_media_path = str(tmp_path / "coco") + if not Path(coco_dir).exists(): + pytest.skip(f"Cannot find COCO dataset at {coco_dir}") + + vllm_media_path = coco_dir else: vllm_media_path = None diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 1673b7c06..868a183b2 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ 
b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -12,8 +12,8 @@ import pytest +from speculators.data_generation.configs import get_coco_dir from tests.e2e.smoke.test_online_training import run_online_e2e -from tests.e2e.utils import setup_dummy_sharegpt4v_coco from tests.utils import requires_cadence @@ -34,10 +34,12 @@ def test_online_regression( prompts: list[list[dict[str, str]]], ): if dataset == "sharegpt4v_coco": - monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) - setup_dummy_sharegpt4v_coco(tmp_path / "coco") + coco_dir = get_coco_dir() - vllm_media_path = str(tmp_path / "coco") + if not Path(coco_dir).exists(): + pytest.skip(f"Cannot find COCO dataset at {coco_dir}") + + vllm_media_path = coco_dir else: vllm_media_path = None diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 08d5a1f0b..a5c1b2a50 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -54,10 +54,12 @@ def test_offline_smoke( target_layer_ids: list[int] | None, ): if dataset == "sharegpt4v_coco": - monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) - setup_dummy_sharegpt4v_coco(tmp_path / "coco") + coco_dir = tmp_path / "coco" - vllm_media_path = str(tmp_path / "coco") + monkeypatch.setenv("COCO_DIR", str(coco_dir)) + setup_dummy_sharegpt4v_coco(coco_dir) + + vllm_media_path = str(coco_dir) else: vllm_media_path = None diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index b2999c052..4e20de5cd 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -42,10 +42,12 @@ def test_online_smoke( prompts: list[list[dict[str, str]]], ): if dataset == "sharegpt4v_coco": - monkeypatch.setenv("COCO_DIR", str(tmp_path / "coco")) - setup_dummy_sharegpt4v_coco(tmp_path / "coco") + coco_dir = tmp_path / "coco" - vllm_media_path = str(tmp_path / "coco") + monkeypatch.setenv("COCO_DIR", str(coco_dir)) 
+ setup_dummy_sharegpt4v_coco(coco_dir) + + vllm_media_path = str(coco_dir) else: vllm_media_path = None From 54f5bf7dffac6fccfe3efd7d12015c8950d71be4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 16:04:42 +0000 Subject: [PATCH 61/95] Use MM prompts for testing Signed-off-by: DarkLight1337 --- .../regression/test_eagle3_offline_acceptance.py | 15 ++++++++++++++- .../regression/test_eagle3_online_acceptance.py | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 8aa4250c0..69276fa36 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -14,6 +14,11 @@ import pytest from speculators.data_generation.configs import get_coco_dir +from speculators.data_generation.preprocessing import ( + _adapt_conv_for_vllm, + _normalize_conversation, + load_raw_dataset, +) from tests.e2e.smoke.test_offline_training import run_offline_e2e from tests.utils import requires_cadence @@ -27,7 +32,6 @@ ], ) def test_offline_regression( - monkeypatch: pytest.MonkeyPatch, tmp_path: Path, model: str, dataset: str, @@ -41,6 +45,15 @@ def test_offline_regression( pytest.skip(f"Cannot find COCO dataset at {coco_dir}") vllm_media_path = coco_dir + + raw_dataset, normalize_fn = load_raw_dataset(dataset) + raw_dataset = raw_dataset.skip(len(dataset) - len(prompts)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) + + raw_convs = raw_dataset["conversations"] + normalized_convs = [_normalize_conversation(conv) for conv in raw_convs] + prompts = [_adapt_conv_for_vllm(conv) for conv in normalized_convs] else: vllm_media_path = None diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 868a183b2..30d1f2b5c 100644 --- 
a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -13,6 +13,11 @@ import pytest from speculators.data_generation.configs import get_coco_dir +from speculators.data_generation.preprocessing import ( + _adapt_conv_for_vllm, + _normalize_conversation, + load_raw_dataset, +) from tests.e2e.smoke.test_online_training import run_online_e2e from tests.utils import requires_cadence @@ -26,7 +31,6 @@ ], ) def test_online_regression( - monkeypatch: pytest.MonkeyPatch, tmp_path: Path, model: str, dataset: str, @@ -40,6 +44,15 @@ def test_online_regression( pytest.skip(f"Cannot find COCO dataset at {coco_dir}") vllm_media_path = coco_dir + + raw_dataset, normalize_fn = load_raw_dataset(dataset) + raw_dataset = raw_dataset.skip(len(dataset) - len(prompts)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) + + raw_convs = raw_dataset["conversations"] + normalized_convs = [_normalize_conversation(conv) for conv in raw_convs] + prompts = [_adapt_conv_for_vllm(conv) for conv in normalized_convs] else: vllm_media_path = None From f8fa8f9b63f815964638bff4161d8014e47f6a15 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 16:10:54 +0000 Subject: [PATCH 62/95] Fix Signed-off-by: DarkLight1337 --- tests/e2e/regression/test_eagle3_conversion_acceptance.py | 1 + tests/e2e/smoke/test_offline_training.py | 1 + tests/e2e/smoke/test_online_training.py | 1 + tests/e2e/utils.py | 4 ++++ 4 files changed, 7 insertions(+) diff --git a/tests/e2e/regression/test_eagle3_conversion_acceptance.py b/tests/e2e/regression/test_eagle3_conversion_acceptance.py index 59707fdbf..6f895e9a1 100644 --- a/tests/e2e/regression/test_eagle3_conversion_acceptance.py +++ b/tests/e2e/regression/test_eagle3_conversion_acceptance.py @@ -69,6 +69,7 @@ def test_convert_run_vllm_engine_eagle3( run_vllm_engine( model_path=str(converted_path), tmp_path=tmp_path, + enforce_eager=False, 
disable_compile_cache=disable_compile_cache, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index a5c1b2a50..84f6c9aba 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -155,4 +155,5 @@ def run_offline_e2e( max_tokens=max_tokens, ignore_eos=ignore_eos, acceptance_thresholds=acceptance_thresholds, + **(vllm_kwargs or {}), ) diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 4e20de5cd..998f73715 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -125,4 +125,5 @@ def run_online_e2e( max_tokens=max_tokens, ignore_eos=ignore_eos, acceptance_thresholds=acceptance_thresholds, + **(vllm_kwargs or {}), ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index df9cf41df..8c001c521 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -310,6 +310,8 @@ def run_vllm_engine( model_path: str, tmp_path: Path, prompts: list[list[dict[str, str]]], + enforce_eager: bool = True, + allowed_local_media_path: str | None = None, disable_compile_cache: bool = False, max_tokens: int = 50, ignore_eos: bool = True, @@ -340,6 +342,8 @@ def run_vllm_engine( "model": model_path, "max_model_len": 1024, "gpu_memory_utilization": 0.8, + "enforce_eager": enforce_eager, + "allowed_local_media_path": allowed_local_media_path, } ), "--prompts", From 6cbd2c82ea8bdba7db19de6890accded15f80fd9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 May 2026 16:35:16 +0000 Subject: [PATCH 63/95] Fix Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 8c001c521..6a50321a1 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -310,6 +310,8 @@ def run_vllm_engine( model_path: str, tmp_path: Path, prompts: 
list[list[dict[str, str]]], + max_model_len: int = 1024, + gpu_memory_utilization: float = 0.8, enforce_eager: bool = True, allowed_local_media_path: str | None = None, disable_compile_cache: bool = False, @@ -340,8 +342,8 @@ def run_vllm_engine( json.dumps( { "model": model_path, - "max_model_len": 1024, - "gpu_memory_utilization": 0.8, + "max_model_len": max_model_len, + "gpu_memory_utilization": gpu_memory_utilization, "enforce_eager": enforce_eager, "allowed_local_media_path": allowed_local_media_path, } From feb939e6e763c7ec63249aaf17828455949dd340 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 May 2026 13:34:32 +0000 Subject: [PATCH 64/95] Fix Signed-off-by: DarkLight1337 --- tests/e2e/run_vllm.py | 23 +++++++++++++++++------ tests/e2e/utils.py | 6 +++++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/e2e/run_vllm.py b/tests/e2e/run_vllm.py index 58cb69cdf..a2e184cf7 100644 --- a/tests/e2e/run_vllm.py +++ b/tests/e2e/run_vllm.py @@ -63,16 +63,19 @@ def parse_args(): "--sampling-params-args", type=str, required=True, - help="JSON-serialized kwargs for SamplingParams instantiation", + help="JSON-serialized kwargs or path to JSON file for SamplingParams instantiation.", ) parser.add_argument( "--llm-args", type=str, required=True, - help="JSON-serialized kwargs for LLM instantiation", + help="JSON-serialized kwargs or path to JSON file for LLM instantiation", ) parser.add_argument( - "--prompts", type=str, required=True, help="JSON-serialized prompts" + "--prompts", + type=str, + required=True, + help="JSON-serialized prompts or path to JSON file", ) parser.add_argument( "--results-file", @@ -120,10 +123,18 @@ def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) -> return metrics_dict +def _load_json(value: str): + if value.endswith(".json") and Path(value).is_file(): + with open(value, "rb") as f: + return json.load(f) + + return json.loads(value) + + def run_vllm(args: argparse.Namespace): - 
sampling_params = SamplingParams(**json.loads(args.sampling_params_args)) - llm = LLM(**json.loads(args.llm_args), disable_log_stats=False) - outputs = llm.chat(json.loads(args.prompts), sampling_params) + sampling_params = SamplingParams(**_load_json(args.sampling_params_args)) + llm = LLM(**_load_json(args.llm_args), disable_log_stats=False) + outputs = llm.chat(_load_json(args.prompts), sampling_params) total_num_output_tokens = sum( len(output.outputs[0].token_ids) for output in outputs ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 6a50321a1..47e98cbe1 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -326,6 +326,10 @@ def run_vllm_engine( run_vllm_file = str(Path(__file__).with_name("run_vllm.py")) results_file = str(tmp_path / "results.json") + prompts_file = str(tmp_path / "prompts.json") + with open(prompts_file, "w") as f: + json.dump(prompts, f) + command = [ VLLM_PYTHON, run_vllm_file, @@ -349,7 +353,7 @@ def run_vllm_engine( } ), "--prompts", - json.dumps(prompts), + prompts_file, "--results-file", results_file, ] From 04e89e3a61846e9338249cf26e2c9e6fe0029466 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 May 2026 14:30:06 +0000 Subject: [PATCH 65/95] Quality Signed-off-by: DarkLight1337 --- tests/e2e/run_vllm.py | 9 ++++++--- tests/e2e/utils.py | 14 +++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/e2e/run_vllm.py b/tests/e2e/run_vllm.py index a2e184cf7..e0ed6abc8 100644 --- a/tests/e2e/run_vllm.py +++ b/tests/e2e/run_vllm.py @@ -63,7 +63,10 @@ def parse_args(): "--sampling-params-args", type=str, required=True, - help="JSON-serialized kwargs or path to JSON file for SamplingParams instantiation.", + help=( + "JSON-serialized kwargs or path to JSON file for " + "SamplingParams instantiation." 
+ ), ) parser.add_argument( "--llm-args", @@ -125,9 +128,9 @@ def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) -> def _load_json(value: str): if value.endswith(".json") and Path(value).is_file(): - with open(value, "rb") as f: + with Path(value).open("rb") as f: return json.load(f) - + return json.loads(value) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 47e98cbe1..436d2c232 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -323,16 +323,16 @@ def run_vllm_engine( VLLM_PYTHON = os.environ.get("VLLM_PYTHON", sys.executable) logger.info("vLLM Python executable: {}", VLLM_PYTHON) - run_vllm_file = str(Path(__file__).with_name("run_vllm.py")) - results_file = str(tmp_path / "results.json") + run_vllm_file = Path(__file__).with_name("run_vllm.py") + results_file = tmp_path / "results.json" - prompts_file = str(tmp_path / "prompts.json") - with open(prompts_file, "w") as f: + prompts_file = tmp_path / "prompts.json" + with prompts_file.open("w") as f: json.dump(prompts, f) command = [ VLLM_PYTHON, - run_vllm_file, + str(run_vllm_file), "--sampling-params-args", json.dumps( { @@ -353,9 +353,9 @@ def run_vllm_engine( } ), "--prompts", - prompts_file, + str(prompts_file), "--results-file", - results_file, + str(results_file), ] logger.info("run_vllm.py command:\n {}", command) From f8e9bd4502d2074c6030c301661119384bcf9bbd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 May 2026 16:55:16 +0000 Subject: [PATCH 66/95] Fix evaluation Signed-off-by: DarkLight1337 --- tests/e2e/regression/test_eagle3_offline_acceptance.py | 2 +- tests/e2e/regression/test_eagle3_online_acceptance.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 69276fa36..adbbf637e 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ 
-47,7 +47,7 @@ def test_offline_regression( vllm_media_path = coco_dir raw_dataset, normalize_fn = load_raw_dataset(dataset) - raw_dataset = raw_dataset.skip(len(dataset) - len(prompts)) + raw_dataset = raw_dataset.skip(len(raw_dataset) - len(prompts)) if normalize_fn is not None: raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 30d1f2b5c..0839de3c8 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -46,7 +46,7 @@ def test_online_regression( vllm_media_path = coco_dir raw_dataset, normalize_fn = load_raw_dataset(dataset) - raw_dataset = raw_dataset.skip(len(dataset) - len(prompts)) + raw_dataset = raw_dataset.skip(len(raw_dataset) - len(prompts)) if normalize_fn is not None: raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) From 2a338ba3aa22d1d5a33969662badebf9c4aaa463 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 May 2026 17:00:19 +0000 Subject: [PATCH 67/95] Revert unnecsesary changes Signed-off-by: DarkLight1337 --- tests/e2e/run_vllm.py | 26 ++++++-------------------- tests/e2e/utils.py | 14 +++++--------- 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/tests/e2e/run_vllm.py b/tests/e2e/run_vllm.py index e0ed6abc8..58cb69cdf 100644 --- a/tests/e2e/run_vllm.py +++ b/tests/e2e/run_vllm.py @@ -63,22 +63,16 @@ def parse_args(): "--sampling-params-args", type=str, required=True, - help=( - "JSON-serialized kwargs or path to JSON file for " - "SamplingParams instantiation." 
- ), + help="JSON-serialized kwargs for SamplingParams instantiation", ) parser.add_argument( "--llm-args", type=str, required=True, - help="JSON-serialized kwargs or path to JSON file for LLM instantiation", + help="JSON-serialized kwargs for LLM instantiation", ) parser.add_argument( - "--prompts", - type=str, - required=True, - help="JSON-serialized prompts or path to JSON file", + "--prompts", type=str, required=True, help="JSON-serialized prompts" ) parser.add_argument( "--results-file", @@ -126,18 +120,10 @@ def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) -> return metrics_dict -def _load_json(value: str): - if value.endswith(".json") and Path(value).is_file(): - with Path(value).open("rb") as f: - return json.load(f) - - return json.loads(value) - - def run_vllm(args: argparse.Namespace): - sampling_params = SamplingParams(**_load_json(args.sampling_params_args)) - llm = LLM(**_load_json(args.llm_args), disable_log_stats=False) - outputs = llm.chat(_load_json(args.prompts), sampling_params) + sampling_params = SamplingParams(**json.loads(args.sampling_params_args)) + llm = LLM(**json.loads(args.llm_args), disable_log_stats=False) + outputs = llm.chat(json.loads(args.prompts), sampling_params) total_num_output_tokens = sum( len(output.outputs[0].token_ids) for output in outputs ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 436d2c232..6a50321a1 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -323,16 +323,12 @@ def run_vllm_engine( VLLM_PYTHON = os.environ.get("VLLM_PYTHON", sys.executable) logger.info("vLLM Python executable: {}", VLLM_PYTHON) - run_vllm_file = Path(__file__).with_name("run_vllm.py") - results_file = tmp_path / "results.json" - - prompts_file = tmp_path / "prompts.json" - with prompts_file.open("w") as f: - json.dump(prompts, f) + run_vllm_file = str(Path(__file__).with_name("run_vllm.py")) + results_file = str(tmp_path / "results.json") command = [ VLLM_PYTHON, - str(run_vllm_file), + 
run_vllm_file, "--sampling-params-args", json.dumps( { @@ -353,9 +349,9 @@ def run_vllm_engine( } ), "--prompts", - str(prompts_file), + json.dumps(prompts), "--results-file", - str(results_file), + results_file, ] logger.info("run_vllm.py command:\n {}", command) From 427bdab27d8fb125d8cf040eeea582bc4c3a65e9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 May 2026 17:01:11 +0000 Subject: [PATCH 68/95] Up Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 6a50321a1..9c85c05b4 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -170,7 +170,6 @@ def run_prepare_data( data_path: Path, max_samples: int = 50, seq_length: int = 512, - seed: int = 0, timeout: float | None = None, ): """Tokenize data using prepare_data.py.""" @@ -185,8 +184,6 @@ def run_prepare_data( str(data_path), "--max-samples", str(max_samples), - "--seed", - str(seed), "--seq-length", str(seq_length), ] From 9799d261ea6a2014a93f48f2dc978c4fbf8211e9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 May 2026 02:54:09 +0000 Subject: [PATCH 69/95] Update thresholds Signed-off-by: DarkLight1337 --- tests/e2e/regression/test_eagle3_offline_acceptance.py | 4 ++-- tests/e2e/regression/test_eagle3_online_acceptance.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index adbbf637e..f12865d4f 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -27,8 +27,8 @@ @pytest.mark.parametrize( ("model", "dataset", "acceptance_thresholds"), [ - ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.07, 0.007]), - ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.07, 0.007]), + ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.1, 0.01]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.2, 0.04]), ], ) 
def test_offline_regression( diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 0839de3c8..044e2d87d 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -27,7 +27,7 @@ ("model", "dataset", "acceptance_thresholds"), [ ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.1, 0.01]), - ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.1, 0.01]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.2, 0.04]), ], ) def test_online_regression( From d5b5b43f0a1c89b10feb46011c046203623f90b1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 May 2026 02:58:33 +0000 Subject: [PATCH 70/95] Change the key Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 6 +++--- src/speculators/train/data.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 79ad3cf02..c3ea62a3f 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -455,7 +455,7 @@ def _preprocess_batch( # MM inputs must use Chat Completion API if isinstance(processor, ProcessorMixin): - results["_vllm_messages"] = [] + results["messages"] = [] if not conversations: log.warning(f"No conversations key found. 
Keys: {list(examples.keys())}") @@ -500,8 +500,8 @@ def _preprocess_batch( results["loss_mask"].append(loss_mask) results["seq_len"].append(len(input_ids)) - if "_vllm_messages" in results: - results["_vllm_messages"].append(_adapt_conv_for_vllm(normalized_conv)) + if "messages" in results: + results["messages"].append(_adapt_conv_for_vllm(normalized_conv)) return results diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 3d0994a3e..d094ec8a6 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -111,8 +111,8 @@ def build_client_item(dataset_item: dict) -> ClientItem: out_dict = {} out_dict["input_ids"] = dataset_item["input_ids"].tolist() - if "_vllm_messages" in dataset_item: - out_dict["messages"] = dataset_item["_vllm_messages"] + if "messages" in dataset_item: + out_dict["messages"] = dataset_item["messages"] return cast("ClientItem", out_dict) From cf0bd0dfb970c0541c4c4a67c51a22b9059c4d7e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 May 2026 03:02:09 +0000 Subject: [PATCH 71/95] Typo Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index c3ea62a3f..5d6f62e11 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -453,7 +453,7 @@ def _preprocess_batch( results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} conversations: list[dict] = examples.get("conversations", []) - # MM inputs must use Chat Completion API + # MM inputs must use Chat Completions API if isinstance(processor, ProcessorMixin): results["messages"] = [] From 2ac74abc255b0ac644b28bec07819212aa94770b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 May 2026 06:35:48 +0000 Subject: [PATCH 72/95] Avoid misleading warning for Qwen3.5 
Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 7 ++++++- tests/integration/datagen/test_preprocessing.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 5d6f62e11..6301c7402 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -198,8 +198,13 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: Must return a non-zero mask for a conversation containing an assistant message. """ + # NOTE: Some models (e.g. Qwen3.5) require a user message in the conversation, + # even though this check only looks at the assistant turn test_conv = _adapt_conv_for_hf( - [{"role": "assistant", "content": "test"}], + [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + ], processor, ) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index c1e677698..e9afd401f 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -26,7 +26,7 @@ # chat template support TEXT_MODEL_REPO = "Qwen/Qwen2-0.5B-Instruct" # For testing multi-modal support -MM_MODEL_REPO = "Qwen/Qwen3-VL-2B-Instruct" +MM_MODEL_REPO = "Qwen/Qwen3.5-0.8B" # Tests for _normalize_conversation From ab936fce30316c1987d2e106d151f7d515145356 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 May 2026 14:12:24 +0000 Subject: [PATCH 73/95] Avoid CPU contention for MM processing Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 75 +++++++++++-------- .../data_generation/torch_utils.py | 42 +++++++++++ 2 files changed, 84 insertions(+), 33 deletions(-) create mode 100644 src/speculators/data_generation/torch_utils.py diff --git a/src/speculators/data_generation/preprocessing.py 
b/src/speculators/data_generation/preprocessing.py index 6301c7402..594483d8f 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -2,6 +2,7 @@ import random import re from collections.abc import Callable +from contextlib import nullcontext from pathlib import Path from re import Pattern from typing import cast @@ -21,6 +22,7 @@ from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger +from speculators.data_generation.torch_utils import set_default_torch_num_threads from speculators.train.vocab_mapping import save_token_frequency_distribution __all__ = [ @@ -664,42 +666,49 @@ def load_and_preprocess_dataset( "Please use a model with a pre-configured chat template." ) - processed_datasets = [] - for train_data_path in train_data_paths: - log.subsection(f"Processing {train_data_path}") - raw_dataset, normalize_fn = load_raw_dataset(train_data_path) - raw_dataset = raw_dataset.shuffle(seed=seed) - - if max_samples is not None and len(raw_dataset) > 3 * max_samples: - # Reduce size to 3 * max_samples to reduce processing - # This will then be reduced further to max_samples - # after combining datasets and shuffling - raw_dataset = raw_dataset.select(range(3 * max_samples)) - - if normalize_fn is not None: - raw_dataset = raw_dataset.map( - normalize_fn, - num_proc=build_dataset_num_proc, - keep_in_memory=True, # skip caching - ) + # Avoid CPU contention for MM processing: + # https://github.com/vllm-project/vllm/pull/31879 + with ( + set_default_torch_num_threads() + if isinstance(processor, ProcessorMixin) + else nullcontext() + ): + processed_datasets = [] + for train_data_path in train_data_paths: + log.subsection(f"Processing {train_data_path}") + raw_dataset, normalize_fn = load_raw_dataset(train_data_path) + raw_dataset = raw_dataset.shuffle(seed=seed) + + if max_samples is not None and len(raw_dataset) > 3 * max_samples: 
+ # Reduce size to 3 * max_samples to reduce processing + # This will then be reduced further to max_samples + # after combining datasets and shuffling + raw_dataset = raw_dataset.select(range(3 * max_samples)) + + if normalize_fn is not None: + raw_dataset = raw_dataset.map( + normalize_fn, + num_proc=build_dataset_num_proc, + keep_in_memory=True, # skip caching + ) - log.info(f"Loaded {len(raw_dataset)} samples") + log.info(f"Loaded {len(raw_dataset)} samples") - if turn_dropout: - log.info("Turn dropout enabled: randomly keeping N consecutive turns") + if turn_dropout: + log.info("Turn dropout enabled: randomly keeping N consecutive turns") - preprocessed_dataset = build_eagle3_dataset( - dataset=raw_dataset, - processor=processor, - max_length=seq_length, - num_proc=build_dataset_num_proc, - assistant_pattern=assistant_pattern, - turn_dropout=turn_dropout, - minimum_valid_tokens=minimum_valid_tokens, - ) - if minimum_valid_tokens is not None: - log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") - processed_datasets.append(preprocessed_dataset) + preprocessed_dataset = build_eagle3_dataset( + dataset=raw_dataset, + processor=processor, + max_length=seq_length, + num_proc=build_dataset_num_proc, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + if minimum_valid_tokens is not None: + log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") + processed_datasets.append(preprocessed_dataset) combined_dataset = concatenate_datasets(processed_datasets) combined_dataset.shuffle(seed=seed) diff --git a/src/speculators/data_generation/torch_utils.py b/src/speculators/data_generation/torch_utils.py new file mode 100644 index 000000000..6c2db5af2 --- /dev/null +++ b/src/speculators/data_generation/torch_utils.py @@ -0,0 +1,42 @@ +import contextlib +import os + +import torch + +from speculators.data_generation.logging_utils import PipelineLogger + +log = 
PipelineLogger(__name__) + + +# Based on vLLM's util with the same name +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int | None = None): + """ + Sets the default number of threads for PyTorch to the given value. + + `None` means using the value of the environment variable `OMP_NUM_THREADS` + (or `1` if that is not available). + """ + if num_threads is None: + num_threads = 1 + + try: + num_threads = int(os.environ["OMP_NUM_THREADS"]) + except KeyError: + log.debug( + f"OMP_NUM_THREADS is not set; " + f"defaulting Torch threads to {num_threads}.", + ) + except ValueError: + log.warning( + f"OMP_NUM_THREADS is invalid; " + f"defaulting Torch threads to {num_threads}.", + ) + + old_num_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + + try: + yield + finally: + torch.set_num_threads(old_num_threads) From 9348a58825f68ec1c1f28b34f190b196fbb48149 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 May 2026 14:20:03 +0000 Subject: [PATCH 74/95] Format Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index 1867a5a0d..b7076ef1b 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -80,6 +80,8 @@ def _normalize_sharegpt4v_coco(example: dict) -> dict: ] return {"conversations": messages} + + def _normalize_gsm8k(example: dict) -> dict: return { "conversations": [ From 2ff9bbb1fad05246cd36edc449f6ccf7c809c59f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 May 2026 14:20:59 +0000 Subject: [PATCH 75/95] mv Signed-off-by: DarkLight1337 --- src/speculators/data_generation/configs.py | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index b7076ef1b..fcd74b0ae 100644 
--- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -22,16 +22,25 @@ class DatasetConfig: normalize_fn: Callable[[dict], dict] | None = None -def get_coco_dir(): - return os.getenv("COCO_DIR") or "coco/" - - def _normalize_ultrachat(example: dict) -> dict: if "messages" in example: return {"conversations": example["messages"]} return example +def _normalize_gsm8k(example: dict) -> dict: + return { + "conversations": [ + {"role": "user", "content": example["question"]}, + {"role": "assistant", "content": example["answer"]}, + ] + } + + +def get_coco_dir(): + return os.getenv("COCO_DIR") or "coco/" + + def _parse_sharegpt4v_part(part: str, image_path: str): if part == "": return {"type": "image", "path": image_path} @@ -82,15 +91,6 @@ def _normalize_sharegpt4v_coco(example: dict) -> dict: return {"conversations": messages} -def _normalize_gsm8k(example: dict) -> dict: - return { - "conversations": [ - {"role": "user", "content": example["question"]}, - {"role": "assistant", "content": example["answer"]}, - ] - } - - DATASET_CONFIGS: dict[str, DatasetConfig] = { "sharegpt": DatasetConfig( name="sharegpt", From f72d76b19f8a8d505d603b1e57ebc435f4350b96 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 May 2026 14:40:40 +0000 Subject: [PATCH 76/95] Reduce scope Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 110 +++++++++--------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 678008109..30847f021 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -550,21 +550,28 @@ def build_eagle3_dataset( original_cols = dataset.column_names - dataset = dataset.map( - lambda examples: _preprocess_batch( - examples, - processor, - max_length, - assistant_pattern, - turn_dropout, - minimum_valid_tokens, - ), - 
batched=True, - num_proc=num_proc, - batch_size=1000, - remove_columns=original_cols, - keep_in_memory=True, # skip caching - ) + # Avoid CPU contention for MM processing: + # https://github.com/vllm-project/vllm/pull/31879 + with ( + set_default_torch_num_threads() + if isinstance(processor, ProcessorMixin) + else nullcontext() + ): + dataset = dataset.map( + lambda examples: _preprocess_batch( + examples, + processor, + max_length, + assistant_pattern, + turn_dropout, + minimum_valid_tokens, + ), + batched=True, + num_proc=num_proc, + batch_size=1000, + remove_columns=original_cols, + keep_in_memory=True, # skip caching + ) dataset.set_format(type="torch") return dataset @@ -666,49 +673,42 @@ def load_and_preprocess_dataset( "Please use a model with a pre-configured chat template." ) - # Avoid CPU contention for MM processing: - # https://github.com/vllm-project/vllm/pull/31879 - with ( - set_default_torch_num_threads() - if isinstance(processor, ProcessorMixin) - else nullcontext() - ): - processed_datasets = [] - for train_data_path in train_data_paths: - log.subsection(f"Processing {train_data_path}") - raw_dataset, normalize_fn = load_raw_dataset(train_data_path) - raw_dataset = raw_dataset.shuffle(seed=seed) - - if max_samples is not None and len(raw_dataset) > 3 * max_samples: - # Reduce size to 3 * max_samples to reduce processing - # This will then be reduced further to max_samples - # after combining datasets and shuffling - raw_dataset = raw_dataset.select(range(3 * max_samples)) - - if normalize_fn is not None: - raw_dataset = raw_dataset.map( - normalize_fn, - num_proc=build_dataset_num_proc, - keep_in_memory=True, # skip caching - ) + processed_datasets = [] + for train_data_path in train_data_paths: + log.subsection(f"Processing {train_data_path}") + raw_dataset, normalize_fn = load_raw_dataset(train_data_path) + raw_dataset = raw_dataset.shuffle(seed=seed) + + if max_samples is not None and len(raw_dataset) > 3 * max_samples: + # Reduce size to 3 * 
max_samples to reduce processing + # This will then be reduced further to max_samples + # after combining datasets and shuffling + raw_dataset = raw_dataset.select(range(3 * max_samples)) + + if normalize_fn is not None: + raw_dataset = raw_dataset.map( + normalize_fn, + num_proc=build_dataset_num_proc, + keep_in_memory=True, # skip caching + ) - log.info(f"Loaded {len(raw_dataset)} samples") + log.info(f"Loaded {len(raw_dataset)} samples") - if turn_dropout: - log.info("Turn dropout enabled: randomly keeping N consecutive turns") + if turn_dropout: + log.info("Turn dropout enabled: randomly keeping N consecutive turns") - preprocessed_dataset = build_eagle3_dataset( - dataset=raw_dataset, - processor=processor, - max_length=seq_length, - num_proc=build_dataset_num_proc, - assistant_pattern=assistant_pattern, - turn_dropout=turn_dropout, - minimum_valid_tokens=minimum_valid_tokens, - ) - if minimum_valid_tokens is not None: - log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") - processed_datasets.append(preprocessed_dataset) + preprocessed_dataset = build_eagle3_dataset( + dataset=raw_dataset, + processor=processor, + max_length=seq_length, + num_proc=build_dataset_num_proc, + assistant_pattern=assistant_pattern, + turn_dropout=turn_dropout, + minimum_valid_tokens=minimum_valid_tokens, + ) + if minimum_valid_tokens is not None: + log.info(f"Kept {len(preprocessed_dataset)} samples after filtering") + processed_datasets.append(preprocessed_dataset) combined_dataset = concatenate_datasets(processed_datasets) combined_dataset.shuffle(seed=seed) From 8c8c24294345e3418b765b5ca7e25019ba6f6952 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:15:43 +0000 Subject: [PATCH 77/95] Improve messaging Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 2 +- src/speculators/data_generation/preprocessing.py | 7 +++++-- src/speculators/train/data.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git 
a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index b1962e433..c70c68171 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -401,7 +401,7 @@ async def generate_and_save_hidden_states(args, dataset): if args.model and args.model != model_id: raise ValueError( f"An explicit model name was passed ({args.model}) which doesn't match" - "found model_id {model_id}." + f"found model_id {model_id}." "Please make sure --endpoint is set to the correct vllm instance." ) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 30847f021..8c0de2bbb 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -327,6 +327,7 @@ def _create_loss_mask_from_offsets( text: str, offsets: list[tuple[int, int]], assistant_pattern: str | Pattern[str], + conv_idx: int, ) -> torch.Tensor: """Create loss mask by finding assistant response spans in formatted text.""" loss_mask = torch.zeros(len(offsets), dtype=torch.bool) @@ -353,7 +354,7 @@ def _create_loss_mask_from_offsets( loss_mask[idx] = 1 if matches_found == 0: - log.warning("No assistant response spans found in conversation") + log.warning(f"No assistant response spans found in conversation {conv_idx}") return loss_mask @@ -363,6 +364,7 @@ def _get_input_ids_loss_mask( processor: ProcessorLike, max_length: int, assistant_pattern: str | Pattern[str] | None, + conv_idx: int, ): hf_conv = _adapt_conv_for_hf(normalized_conv, processor) @@ -441,7 +443,7 @@ def _get_input_ids_loss_mask( offsets = encoded["offset_mapping"] loss_mask = _create_loss_mask_from_offsets( - formatted_text, offsets, assistant_pattern + formatted_text, offsets, assistant_pattern, conv_idx ) return input_ids, loss_mask @@ -483,6 +485,7 @@ def _preprocess_batch( processor, max_length=max_length, assistant_pattern=assistant_pattern, + conv_idx=idx, ) except (TypeError, ValueError, 
KeyError, AttributeError, RuntimeError) as e: log.error( diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 2feb4b494..15fa2e379 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -244,7 +244,7 @@ def _setup_client(self): if self.model and self.model != model_id: raise ValueError( f"An explicit model name was passed ({self.model}) which doesn't match" - "found model_id {model_id}." + f"found model_id {model_id}." "Please make sure --endpoint is set to the correct vllm instance." ) self.model = model_id From 6b643847c57eb7e8bac0b173802719f0f84a4245 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:22:45 +0000 Subject: [PATCH 78/95] Fix typo Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 8c0de2bbb..8506e14cc 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -213,7 +213,7 @@ def _supports_assistant_mask(processor: ProcessorLike) -> bool: try: res_any = processor.apply_chat_template( test_conv, - tokenizer=True, + tokenize=True, return_assistant_tokens_mask=True, return_dict=True, ) From 434ce83709e3d660c82d3c8170b008a42a3cb0e3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:25:51 +0000 Subject: [PATCH 79/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 8506e14cc..8de66e48d 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -327,7 +327,8 @@ def _create_loss_mask_from_offsets( text: str, offsets: 
list[tuple[int, int]], assistant_pattern: str | Pattern[str], - conv_idx: int, + *, + conv_idx: int | None = None, # For logging ) -> torch.Tensor: """Create loss mask by finding assistant response spans in formatted text.""" loss_mask = torch.zeros(len(offsets), dtype=torch.bool) @@ -354,7 +355,10 @@ def _create_loss_mask_from_offsets( loss_mask[idx] = 1 if matches_found == 0: - log.warning(f"No assistant response spans found in conversation {conv_idx}") + if conv_idx is None: + log.warning(f"No assistant response spans found in conversation") + else: + log.warning(f"No assistant response spans found in conversation {conv_idx}") return loss_mask @@ -364,7 +368,8 @@ def _get_input_ids_loss_mask( processor: ProcessorLike, max_length: int, assistant_pattern: str | Pattern[str] | None, - conv_idx: int, + *, + conv_idx: int | None = None, # For logging ): hf_conv = _adapt_conv_for_hf(normalized_conv, processor) From 31160243eaadec7512fa0bd36dd695b10552afb0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:26:01 +0000 Subject: [PATCH 80/95] Fix Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 8de66e48d..03f031b0d 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -356,7 +356,7 @@ def _create_loss_mask_from_offsets( if matches_found == 0: if conv_idx is None: - log.warning(f"No assistant response spans found in conversation") + log.warning("No assistant response spans found in conversation") else: log.warning(f"No assistant response spans found in conversation {conv_idx}") From a8ce7fb23cad251ce9cde8d7d3b829c6b470ddcb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:32:11 +0000 Subject: [PATCH 81/95] Fix Signed-off-by: DarkLight1337 --- 
src/speculators/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 03f031b0d..ba0e3ae78 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -448,7 +448,7 @@ def _get_input_ids_loss_mask( offsets = encoded["offset_mapping"] loss_mask = _create_loss_mask_from_offsets( - formatted_text, offsets, assistant_pattern, conv_idx + formatted_text, offsets, assistant_pattern, conv_idx=conv_idx ) return input_ids, loss_mask From dd4c62889c48bbddd422997edec0f94e0fd0c67f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 May 2026 04:48:10 +0000 Subject: [PATCH 82/95] Improve Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index ba0e3ae78..6ed48a055 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -328,7 +328,9 @@ def _create_loss_mask_from_offsets( offsets: list[tuple[int, int]], assistant_pattern: str | Pattern[str], *, - conv_idx: int | None = None, # For logging + # For logging + conv_idx: int | None = None, + max_length: int | None = None, ) -> torch.Tensor: """Create loss mask by finding assistant response spans in formatted text.""" loss_mask = torch.zeros(len(offsets), dtype=torch.bool) @@ -355,10 +357,18 @@ def _create_loss_mask_from_offsets( loss_mask[idx] = 1 if matches_found == 0: - if conv_idx is None: - log.warning("No assistant response spans found in conversation") - else: - log.warning(f"No assistant response spans found in conversation {conv_idx}") + warning_msg = "No assistant response spans found in conversation" + if conv_idx is not None: + warning_msg += f" 
{conv_idx}" + + suggestion_msg = "" + if max_length is not None and len(offsets) == max_length: + suggestion_msg += ( + "Consider increasing --seq-length to avoid truncating " + "the assistant response." + ) + + log.warning(f"{warning_msg}. {suggestion_msg}") return loss_mask @@ -369,7 +379,8 @@ def _get_input_ids_loss_mask( max_length: int, assistant_pattern: str | Pattern[str] | None, *, - conv_idx: int | None = None, # For logging + # For logging + conv_idx: int | None = None, ): hf_conv = _adapt_conv_for_hf(normalized_conv, processor) @@ -448,7 +459,11 @@ def _get_input_ids_loss_mask( offsets = encoded["offset_mapping"] loss_mask = _create_loss_mask_from_offsets( - formatted_text, offsets, assistant_pattern, conv_idx=conv_idx + formatted_text, + offsets, + assistant_pattern, + conv_idx=conv_idx, + max_length=max_length, ) return input_ids, loss_mask From b9017fdc797461dbc07cda8739e018dc9e4c1588 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 18 May 2026 05:49:24 +0000 Subject: [PATCH 83/95] Improve messaging Signed-off-by: DarkLight1337 --- src/speculators/data_generation/vllm_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 577085ee3..2b51574af 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -106,7 +106,7 @@ def extract_output( f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" ) - if not hasattr(response, "kv_transfer_params"): + if getattr(response, "kv_transfer_params", None) is None: raise InvalidResponseError("Response missing kv_transfer_params") return response.kv_transfer_params.get("hidden_states_path") From 94ab45864640ec36383524cf4221b9b3559c4a80 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 18 May 2026 05:54:42 +0000 Subject: [PATCH 84/95] mypy Signed-off-by: DarkLight1337 --- 
src/speculators/data_generation/vllm_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 2b51574af..44629c432 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -106,10 +106,11 @@ def extract_output( f"Prompt token IDs mismatch: expected {token_ids}, got {prompt_token_ids}" ) - if getattr(response, "kv_transfer_params", None) is None: + kv_transfer_params = getattr(response, "kv_transfer_params", None) + if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") - return response.kv_transfer_params.get("hidden_states_path") + return kv_transfer_params.get("hidden_states_path") class ClientItem(TypedDict): From 327282953e302fc16c06f4d68ec9291f36549005 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 21 May 2026 05:32:34 +0000 Subject: [PATCH 85/95] Fix Signed-off-by: DarkLight1337 --- tests/e2e/smoke/test_offline_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 20f984841..2af63a6df 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -42,6 +42,7 @@ [1, 13, 25], ), # DFlash with 3 layers + verifier last layer ( + TEXT_MODEL, "peagle", [ "--num-layers", From 07f7d7c3a3426352b1e7bc3877adaaca17067663 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 21 May 2026 05:33:07 +0000 Subject: [PATCH 86/95] Fix 2 Signed-off-by: DarkLight1337 --- tests/e2e/smoke/test_offline_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 2af63a6df..73d5afe6b 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -43,6 +43,7 @@ ), # DFlash with 3 layers 
+ verifier last layer ( TEXT_MODEL, + "sharegpt", "peagle", [ "--num-layers", From 8a10f1b95623a5a4206b04a17f073c1a81d91af2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:52:07 +0000 Subject: [PATCH 87/95] Use load_processor everywhere Signed-off-by: DarkLight1337 --- .../data_generation/preprocessing.py | 12 +++-- src/speculators/train/utils.py | 5 +- .../integration/datagen/test_preprocessing.py | 46 +++++++++---------- .../datagen/test_regex_patterns.py | 9 ++-- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 6ed48a055..417547514 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -622,15 +622,17 @@ def load_raw_dataset( return raw_dataset, config.normalize_fn +def get_tokenizer(processor: ProcessorLike): + return processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + + def _resolve_pad_token(processor: ProcessorLike): - tokenizer = ( - processor.tokenizer if isinstance(processor, ProcessorMixin) else processor - ) + tokenizer = get_tokenizer(processor) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token -def _load_processor(target_model_path: str, *, trust_remote_code: bool = False): +def load_processor(target_model_path: str, *, trust_remote_code: bool = False): processor = AutoProcessor.from_pretrained( target_model_path, trust_remote_code=trust_remote_code, @@ -688,7 +690,7 @@ def load_and_preprocess_dataset( ) log.subsection("Loading processor") - processor = _load_processor(target_model_path, trust_remote_code=trust_remote_code) + processor = load_processor(target_model_path, trust_remote_code=trust_remote_code) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: raise ValueError( diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index 
8897d3ae7..19947032e 100644 --- a/src/speculators/train/utils.py +++ b/src/speculators/train/utils.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard -from transformers import AutoTokenizer +from speculators.data_generation.preprocessing import get_tokenizer, load_processor local_rank = int(os.environ.get("LOCAL_RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -75,10 +75,11 @@ def resolve_mask_token_id( logger.info(f"Using explicit mask_token_id={mask_token_id}") return mask_token_id - tokenizer = AutoTokenizer.from_pretrained( + processor = load_processor( verifier_name_or_path, trust_remote_code=trust_remote_code, ) + tokenizer = get_tokenizer(processor) if tokenizer.mask_token_id is not None: logger.info(f"Using tokenizer mask_token_id={tokenizer.mask_token_id}") diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index e9afd401f..bee137ff6 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -14,7 +14,7 @@ _adapt_conv_for_vllm, _create_loss_mask_from_offsets, _detect_assistant_pattern, - _load_processor, + load_processor, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, @@ -89,7 +89,7 @@ def test_adapt_conv_for_hf_text_only_processor(): Test converting from normalized conversation to HF format with a text-only processor (i.e. tokenizer). """ - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) conv: list[dict] = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -107,7 +107,7 @@ def test_adapt_conv_for_hf_multimodal_processor(): Test converting from normalized conversation to HF format with a multi-modal processor. 
""" - processor = _load_processor(MM_MODEL_REPO, trust_remote_code=True) + processor = load_processor(MM_MODEL_REPO, trust_remote_code=True) conv: list[dict] = [ { @@ -223,7 +223,7 @@ def test_adapt_conv_for_vllm_invalid_content_formats(): @pytest.mark.sanity def test_detect_assistant_pattern_structure(): """Test that the detected pattern has the correct regex structure.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -247,7 +247,7 @@ def test_detect_assistant_pattern_structure(): @pytest.mark.sanity def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): """Test that pattern correctly distinguishes assistant from user content.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -283,7 +283,7 @@ def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): @pytest.mark.sanity def test_detect_assistant_pattern_extracts_correct_content(): """Test that the pattern's capture group extracts only assistant message content.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -403,7 +403,7 @@ def test_create_loss_mask_empty_offsets(): @pytest.mark.sanity def test_preprocess_batch_basic(): """Test preprocessing a basic batch of conversations.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = 
load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -441,7 +441,7 @@ def test_preprocess_batch_basic(): @pytest.mark.sanity def test_preprocess_batch_multimodal(tmp_path): """Test preprocessing a batch of multimodal conversations.""" - processor = _load_processor(MM_MODEL_REPO, trust_remote_code=True) + processor = load_processor(MM_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -523,7 +523,7 @@ def test_preprocess_batch_multimodal(tmp_path): @pytest.mark.sanity def test_preprocess_batch_empty_conversations(): """Test preprocessing batch with no conversations.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -541,7 +541,7 @@ def test_preprocess_batch_empty_conversations(): @pytest.mark.sanity def test_preprocess_batch_invalid_conversation(): """Test preprocessing batch with invalid conversations.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -567,7 +567,7 @@ def test_preprocess_batch_invalid_conversation(): @pytest.mark.sanity def test_preprocess_batch_truncation(): """Test that long sequences are truncated to max_length.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or 
processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -599,7 +599,7 @@ def test_preprocess_batch_truncation(): @pytest.mark.sanity def test_preprocess_batch_uses_hf_assistant_mask(): """Test that HF assistant token mask is used when supported.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -639,7 +639,7 @@ def test_preprocess_batch_falls_back_to_regex(): """Test that preprocessing falls back to regex-based detection when HF mask is unavailable. """ - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -684,7 +684,7 @@ def patched_apply_chat_template(*args, **kwargs): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): """Test that minimum_valid_tokens drops short samples on regex path.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -727,7 +727,7 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): """Test that a sample is kept when valid tokens equal the threshold.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template 
is None: pytest.skip("Processor does not support chat templates") @@ -779,7 +779,7 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): @pytest.mark.sanity def test_build_eagle3_dataset_basic(): """Test building EAGLE3 dataset from a simple HuggingFace dataset.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -813,7 +813,7 @@ def test_build_eagle3_dataset_basic(): @pytest.mark.sanity def test_build_eagle3_dataset_preserves_format(): """Test that build_eagle3_dataset sets the correct format.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -837,7 +837,7 @@ def test_build_eagle3_dataset_preserves_format(): @pytest.mark.sanity def test_build_eagle3_dataset_removes_original_columns(): """Test that original columns are removed after processing.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -864,7 +864,7 @@ def test_build_eagle3_dataset_removes_original_columns(): @pytest.mark.sanity def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): """Test that build_eagle3_dataset removes samples below the token threshold.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or 
processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -933,7 +933,7 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): @pytest.mark.sanity def test_preprocess_batch_with_turn_dropout(): """Test preprocessing batch with turn dropout enabled.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") @@ -976,7 +976,7 @@ def test_detect_assistant_pattern_thinking_model(): but the pattern must still match real conversations where the think block contains substantial content. """ - processor = _load_processor("Qwen/Qwen3-8B", trust_remote_code=True) + processor = load_processor("Qwen/Qwen3-8B", trust_remote_code=True) pattern = _detect_assistant_pattern(processor) # Format a multi-turn conversation with thinking content injected @@ -1036,7 +1036,7 @@ def test_create_loss_mask_thinking_model(thinking_content): Verifies correct masking both with and without thinking content in the block. 
""" - processor = _load_processor("Qwen/Qwen3-8B", trust_remote_code=True) + processor = load_processor("Qwen/Qwen3-8B", trust_remote_code=True) pattern = _detect_assistant_pattern(processor) # Build formatted text using the real chat template @@ -1086,7 +1086,7 @@ def test_create_loss_mask_thinking_model(thinking_content): @pytest.mark.sanity def test_build_eagle3_dataset_with_custom_pattern(): """Test building dataset with custom assistant pattern.""" - processor = _load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: pytest.skip("Processor does not support chat templates") diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index 4187e9dd5..759d95907 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -13,7 +13,8 @@ from speculators.data_generation.preprocessing import ( _detect_assistant_pattern, - _load_processor, + get_tokenizer, + load_processor, _preprocess_batch, ) @@ -43,7 +44,7 @@ def processor(request): model_id = request.param try: # Using trust_remote_code=True for variety of templates - return _load_processor(model_id, trust_remote_code=True) + return load_processor(model_id, trust_remote_code=True) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: pytest.skip(f"Failed to load processor for {model_id}: {e}") @@ -53,9 +54,7 @@ def test_regex_detection_across_models(tmp_path, processor): Verify that _detect_assistant_pattern and _preprocess_batch (regex path) work correctly for a variety of model families. 
""" - tokenizer = ( - processor.tokenizer if isinstance(processor, ProcessorMixin) else processor - ) + tokenizer = get_tokenizer(processor) model_name = tokenizer.name_or_path log.info(f"Testing family: {model_name}") From 7f0c3176b6cbd7dd3c2fdfd8edf3eefdb0833312 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:55:10 +0000 Subject: [PATCH 88/95] Use graph by default for server Signed-off-by: DarkLight1337 --- tests/e2e/smoke/test_offline_training.py | 1 + tests/e2e/smoke/test_online_training.py | 1 + tests/e2e/smoke/test_resume_optimizer.py | 7 ++++++- tests/e2e/utils.py | 2 +- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 73d5afe6b..0e38f7140 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -86,6 +86,7 @@ def test_offline_smoke( dataset=dataset, prompts=prompts, vllm_kwargs={ + "enforce_eager": True, "gpu_memory_utilization": 0.9, "allowed_local_media_path": vllm_media_path, }, diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 998f73715..1c927f05a 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -57,6 +57,7 @@ def test_online_smoke( dataset=dataset, prompts=prompts, vllm_kwargs={ + "enforce_eager": True, "allowed_local_media_path": vllm_media_path, }, ) diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index 7929bc32e..b793e2cd8 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -85,7 +85,12 @@ def test_resume_after_checkpoint_best(tmp_path: Path): run_prepare_data(MODEL, "sharegpt", data_path) # Step 2: Generate hidden states offline - with launch_vllm_server_context(MODEL, VLLM_PORT, str(tmp_path / "hidden_states")): + with launch_vllm_server_context( + MODEL, + VLLM_PORT, + 
str(tmp_path / "hidden_states"), + enforce_eager=True, + ): run_data_generation_offline(data_path, hidden_states_path, port=VLLM_PORT) # Step 3: Train 1 epoch with --save-best diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 9c85c05b4..259b48738 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -72,7 +72,7 @@ def launch_vllm_server( max_model_len: int = 513, gpu_memory_utilization: float = 0.5, target_layer_ids: list[int] | None = None, - enforce_eager: bool = True, + enforce_eager: bool = False, allowed_local_media_path: str | None = None, ) -> subprocess.Popen: """Launch a vLLM server configured for hidden-state extraction. From be339a81176d0330034f66117925bb197360dcdf Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:55:52 +0000 Subject: [PATCH 89/95] Change default Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 259b48738..58673ff69 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -309,7 +309,7 @@ def run_vllm_engine( prompts: list[list[dict[str, str]]], max_model_len: int = 1024, gpu_memory_utilization: float = 0.8, - enforce_eager: bool = True, + enforce_eager: bool = False, allowed_local_media_path: str | None = None, disable_compile_cache: bool = False, max_tokens: int = 50, From b6336ed0a90be5e6502c896986a7764997d6758c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:57:11 +0000 Subject: [PATCH 90/95] Fix llm args Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 58673ff69..5b153b050 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -323,28 +323,29 @@ def run_vllm_engine( run_vllm_file = str(Path(__file__).with_name("run_vllm.py")) results_file = str(tmp_path / "results.json") + sampling_params_dict = { + 
"temperature": 0, + "top_p": 0.9, + "max_tokens": max_tokens, + "ignore_eos": ignore_eos, + } + + llm_args_dict = { + "model": model_path, + "max_model_len": max_model_len, + "gpu_memory_utilization": gpu_memory_utilization, + "enforce_eager": enforce_eager, + } + if allowed_local_media_path is not None: + llm_args_dict["allowed_local_media_path"] = allowed_local_media_path + command = [ VLLM_PYTHON, run_vllm_file, "--sampling-params-args", - json.dumps( - { - "temperature": 0, - "top_p": 0.9, - "max_tokens": max_tokens, - "ignore_eos": ignore_eos, - } - ), + json.dumps(sampling_params_dict), "--llm-args", - json.dumps( - { - "model": model_path, - "max_model_len": max_model_len, - "gpu_memory_utilization": gpu_memory_utilization, - "enforce_eager": enforce_eager, - "allowed_local_media_path": allowed_local_media_path, - } - ), + json.dumps(llm_args_dict), "--prompts", json.dumps(prompts), "--results-file", From 1d318189d7520ac68c78e7a7f2d0b1507f62aae3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:58:15 +0000 Subject: [PATCH 91/95] Be more robust Signed-off-by: DarkLight1337 --- tests/e2e/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 5b153b050..ee4e90ac9 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -91,7 +91,7 @@ def launch_vllm_server( cmd += ["--target-layer-ids"] + [str(lid) for lid in target_layer_ids] if enforce_eager: cmd += ["--enforce-eager"] - if allowed_local_media_path: + if allowed_local_media_path is not None: cmd += ["--allowed-local-media-path", allowed_local_media_path] cmd += [ "--", From 5362c7a0b5c78ec718348b5a57fc1f58cf339f18 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 03:59:43 +0000 Subject: [PATCH 92/95] Fix whitespace Signed-off-by: DarkLight1337 --- scripts/data_generation_offline.py | 2 +- src/speculators/train/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index c70c68171..ca1be6bc0 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -401,7 +401,7 @@ async def generate_and_save_hidden_states(args, dataset): if args.model and args.model != model_id: raise ValueError( f"An explicit model name was passed ({args.model}) which doesn't match" - f"found model_id {model_id}." + f" found model_id {model_id}." "Please make sure --endpoint is set to the correct vllm instance." ) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 15fa2e379..fc4a4af9a 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -244,7 +244,7 @@ def _setup_client(self): if self.model and self.model != model_id: raise ValueError( f"An explicit model name was passed ({self.model}) which doesn't match" - f"found model_id {model_id}." + f" found model_id {model_id}." "Please make sure --endpoint is set to the correct vllm instance." ) self.model = model_id From e509c6d9aeea15318ca5585a2b5a95af8f8b8524 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 04:01:18 +0000 Subject: [PATCH 93/95] Doc Signed-off-by: DarkLight1337 --- src/speculators/data_generation/vllm_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 44629c432..eccf915e3 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -115,7 +115,11 @@ def extract_output( class ClientItem(TypedDict): input_ids: list[int] + """The input token IDs.""" + messages: NotRequired[list[ChatCompletionMessageParam]] + """If provided, pass `messages` to Chat Completions API + instead of passing `token_ids` to Completions API.""" @with_retries @@ -133,9 +137,7 @@ async def generate_hidden_states_async( Args: client: The async OpenAI client. 
model: The model ID. - token_ids: The input token IDs. - messages: If provided, pass `messages` to Chat Completions API - instead of passing `token_ids` to Completions API. + client_item: Inputs to send via the client. timeout: Timeout in seconds for each request attempt. None for no timeout. """ token_ids = client_item["input_ids"] From 2d39e73cfa255c340e1267d1a128d28e4737dd82 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 04:02:37 +0000 Subject: [PATCH 94/95] Format Signed-off-by: DarkLight1337 --- src/speculators/train/utils.py | 1 + tests/integration/datagen/test_preprocessing.py | 2 +- tests/integration/datagen/test_regex_patterns.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index 19947032e..5c2a87f66 100644 --- a/src/speculators/train/utils.py +++ b/src/speculators/train/utils.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard + from speculators.data_generation.preprocessing import get_tokenizer, load_processor local_rank = int(os.environ.get("LOCAL_RANK", "0")) diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index bee137ff6..44f385628 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -14,11 +14,11 @@ _adapt_conv_for_vllm, _create_loss_mask_from_offsets, _detect_assistant_pattern, - load_processor, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, build_eagle3_dataset, + load_processor, ) # Test model from HuggingFace with chat template diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index 759d95907..3c68026ff 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -13,9 +13,9 @@ from 
speculators.data_generation.preprocessing import ( _detect_assistant_pattern, + _preprocess_batch, get_tokenizer, load_processor, - _preprocess_batch, ) # Test models covering major template families From ee95bd60da6d3c27a45c25d066f6946bdaf6741c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 22 May 2026 04:05:05 +0000 Subject: [PATCH 95/95] mypy Signed-off-by: DarkLight1337 --- src/speculators/data_generation/preprocessing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 417547514..e7791e8cc 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -623,7 +623,10 @@ def load_raw_dataset( def get_tokenizer(processor: ProcessorLike): - return processor.tokenizer if isinstance(processor, ProcessorMixin) else processor + if isinstance(processor, ProcessorMixin): + return processor.tokenizer # type: ignore[attr-defined] + + return processor def _resolve_pad_token(processor: ProcessorLike):