diff --git a/docs/cli/prepare_data.md b/docs/cli/prepare_data.md index f4fc40066..fab7cb788 100644 --- a/docs/cli/prepare_data.md +++ b/docs/cli/prepare_data.md @@ -26,6 +26,8 @@ python scripts/prepare_data.py \ Example: `meta-llama/Llama-3.1-8B-Instruct` +- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the target model's processor. + ### Data Arguments - **`--data`** (str, required, repeatable) Path to training data. Can be a HuggingFace dataset name or local path. Use multiple times to specify multiple datasets. diff --git a/docs/cli/train.md b/docs/cli/train.md index 570fc4aaf..dee322121 100644 --- a/docs/cli/train.md +++ b/docs/cli/train.md @@ -32,6 +32,8 @@ torchrun --standalone --nproc_per_node=4 scripts/train.py \ - **`--verifier-name-or-path`** (str, required) HuggingFace model ID or local path for the verifier/target model. +- **`--trust-remote-code`** (flag) Allow executing code from HF Hub when loading the verifier's tokenizer. + - **`--speculator-type`** (str, default: `"eagle3"`) Type of speculator model to train. Options: `eagle3`, `dflash` - **`--from-pretrained`** (str, default: `""`) Path to a pretrained draft model to finetune. 
diff --git a/pyproject.toml b/pyproject.toml index 3428159cc..dd5d3a65f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ dependencies = [ "safetensors", "setuptools", "torch>=2.9.0,<=2.11.0", + "torchaudio", + "torchvision", "tqdm>=4.66.3,<=4.67.3", "transformers>=4.56.1,<5.9.0", "typer>=0.12.0", @@ -249,6 +251,7 @@ select = [ "PTH", # os.path is acceptable in scripts "T201", # print statements are acceptable in scripts "SLF001", # allow private member access for model configuration + "PLR0915", # allow long parse_args functions ] "examples/**/*.py" = [ diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 04882c09b..ca1be6bc0 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -31,6 +31,7 @@ DEFAULT_REQUEST_TIMEOUT, generate_hidden_states_async, ) +from speculators.train.data import build_client_item from speculators.train.logger import setup_root_logger logger = logging.getLogger(__name__) @@ -66,8 +67,8 @@ def parse_args(): type=str, default=None, help=( - "HuggingFace model ID or local path for target model (default auto select)." - "For verification purposes only." + "HuggingFace model ID or local path for target model " + "(default auto select). For verification purposes only." ), ) parser.add_argument( @@ -113,7 +114,7 @@ def parse_args(): type=int, default=32, help=( - "Number of active vLLM requests at a time." + "Number of active vLLM requests at a time. 
" "Note: number of async workers set to 2*concurrency" ), ) @@ -121,8 +122,8 @@ def parse_args(): "--validate-outputs", action="store_true", help=( - "Load generated safetensor files and check output token ids match prompt" - " tokens and hidden states seq_len matches num tokens" + "Load generated safetensor files and check output token ids match " + "prompt tokens and hidden states seq_len matches num tokens" ), ) parser.add_argument( @@ -276,8 +277,6 @@ async def worker( queue.task_done() continue - input_ids = item["input_ids"].tolist() - target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" try: @@ -285,7 +284,7 @@ async def worker( hidden_states_path = await generate_hidden_states_async( client, model, - input_ids, + item, timeout=request_timeout, max_retries=max_retries, ) @@ -295,7 +294,9 @@ async def worker( ) if validate_outputs: await asyncio.to_thread( - check_safetensors_file, target_hidden_states_path, input_ids + check_safetensors_file, + target_hidden_states_path, + item["input_ids"], ) except Exception as e: if fail_on_error: @@ -325,12 +326,15 @@ async def _feed_queue(to_process, dataset, queue, cancel_event): for i in to_process: if cancel_event.is_set(): break - item = dataset[i] + + dataset_item = dataset[i] + client_item = build_client_item(dataset_item) | {"idx": i} + # Check cancel_event while waiting for queue space to avoid # deadlocking when all workers have died. while not cancel_event.is_set(): try: - queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) + queue.put_nowait(client_item) break except asyncio.QueueFull: await asyncio.sleep(0.1) @@ -397,7 +401,7 @@ async def generate_and_save_hidden_states(args, dataset): if args.model and args.model != model_id: raise ValueError( f"An explicit model name was passed ({args.model}) which doesn't match" - "found model_id {model_id}." + f" found model_id {model_id}. " "Please make sure --endpoint is set to the correct vllm instance." 
) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index dff5597df..27bde9073 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -49,6 +49,14 @@ def parse_args(): required=True, help="HuggingFace model ID or local path for target model", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help=( + "Allow executing code from HF Hub when loading the target model's " + "processor." + ), + ) # Data arguments parser.add_argument( @@ -75,7 +83,7 @@ def parse_args(): type=str, default=None, help=( - "Path to save token frequency distribution" + "Path to save token frequency distribution " "(default: args.output / 'token_freq.pt')" ), ) @@ -177,6 +185,7 @@ def main(): assistant_pattern=args.assistant_pattern, turn_dropout=args.turn_dropout, minimum_valid_tokens=args.minimum_valid_tokens, + trust_remote_code=args.trust_remote_code, ) log.info("Done preparing data") diff --git a/scripts/train.py b/scripts/train.py index 0bb4cf403..cbc6e5574 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -259,6 +259,7 @@ def main(args: argparse.Namespace): args.verifier_name_or_path, transformer_layer_config.vocab_size, args.mask_token_id, + trust_remote_code=args.trust_remote_code, ) registry = SpeculatorModel.registry @@ -398,6 +399,11 @@ def _checkpoint_freq(value: str) -> float: def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--verifier-name-or-path", type=str, required=True) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow executing code from HF Hub when loading the verifier's tokenizer.", + ) parser.add_argument( "--speculator-type", type=str, diff --git a/src/speculators/data_generation/configs.py b/src/speculators/data_generation/configs.py index a27015ccd..fcd74b0ae 100644 --- a/src/speculators/data_generation/configs.py +++ b/src/speculators/data_generation/configs.py @@ -1,5 +1,6 @@ """Configuration registries for data generation pipeline.""" +import os 
from collections.abc import Callable from dataclasses import dataclass @@ -9,14 +10,15 @@ ] -@dataclass +@dataclass(kw_only=True) class DatasetConfig: """Configuration for loading a dataset""" name: str hf_path: str - split: str subset: str | None = None + split: str + filter_fn: Callable[[dict], bool] | None = None normalize_fn: Callable[[dict], dict] | None = None @@ -35,6 +37,60 @@ def _normalize_gsm8k(example: dict) -> dict: } +def get_coco_dir(): + return os.getenv("COCO_DIR") or "coco/" + + +def _parse_sharegpt4v_part(part: str, image_path: str): + if part == "<image>": + return {"type": "image", "path": image_path} + + return {"type": "text", "text": part} + + +def _parse_sharegpt4v_user_content(content: str, image_path: str): + return [_parse_sharegpt4v_part(part, image_path) for part in content.split("\n")] + + +def _parse_sharegpt4v_assistant_content(content: str): + return [{"type": "text", "text": content}] + + +def _filter_sharegpt4v_coco(example: dict) -> bool: + return example["image"].startswith("coco/") + + +def _normalize_sharegpt4v_coco(example: dict) -> dict: + coco_dir = get_coco_dir() + image_path = os.path.join(coco_dir, example["image"].removeprefix("coco/")) + + if not os.path.exists(image_path): + state_str = "set to" if os.getenv("COCO_DIR") else "default" + + raise ValueError( + f"No image found at <{image_path}>. " + f"Please download COCO 2017 Train Images from " + f"<http://images.cocodataset.org/zips/train2017.zip> and place the " + f"extracted folder under `COCO_DIR` ({state_str}: `{coco_dir}`)." 
+ ) + + messages = [ + ( + turn + | { + "value": ( + _parse_sharegpt4v_user_content(turn["value"], image_path) + if turn["from"] in ("human", "user") + else _parse_sharegpt4v_assistant_content(turn["value"]) + ) + } + ) + for turn in example["conversations"] + ] + + return {"conversations": messages} + + DATASET_CONFIGS: dict[str, DatasetConfig] = { "sharegpt": DatasetConfig( name="sharegpt", @@ -50,8 +106,17 @@ def _normalize_gsm8k(example: dict) -> dict: "gsm8k": DatasetConfig( name="gsm8k", hf_path="openai/gsm8k", - split="train", subset="main", + split="train", normalize_fn=_normalize_gsm8k, ), + # NOTE: You need to serve vLLM with `--allowed-local-media-path /path/to/coco` + "sharegpt4v_coco": DatasetConfig( + name="sharegpt4v_coco", + hf_path="Lin-Chen/ShareGPT4V", + subset="ShareGPT4V", + split="train", + filter_fn=_filter_sharegpt4v_coco, + normalize_fn=_normalize_sharegpt4v_coco, + ), } diff --git a/src/speculators/data_generation/preprocessing.py b/src/speculators/data_generation/preprocessing.py index 9d9abcffd..e7791e8cc 100644 --- a/src/speculators/data_generation/preprocessing.py +++ b/src/speculators/data_generation/preprocessing.py @@ -1,17 +1,28 @@ import bisect import random import re +from collections.abc import Callable +from contextlib import nullcontext from pathlib import Path from re import Pattern -from typing import Any, cast +from typing import cast import torch from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from packaging.version import Version +from transformers import ( + AutoProcessor, + BatchEncoding, + BatchFeature, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from transformers import __version__ as TRANSFORMERS_VERSION # noqa: N812 from speculators.data_generation.configs import DATASET_CONFIGS from speculators.data_generation.logging_utils import PipelineLogger +from speculators.data_generation.torch_utils 
import set_default_torch_num_threads from speculators.train.vocab_mapping import save_token_frequency_distribution __all__ = [ @@ -23,12 +34,15 @@ log = PipelineLogger(__name__) -def _visualize_sample(preprocessed, tokenizer, idx: int = 0): +ProcessorLike = PreTrainedTokenizerBase | ProcessorMixin + + +def _visualize_sample(preprocessed: HFDataset, processor: ProcessorLike, idx: int = 0): """Visualize a single sample with color-coded trainable regions.""" # Get preprocessed sample prep_sample = preprocessed[idx] - input_ids = prep_sample["input_ids"] - loss_mask = prep_sample["loss_mask"] + input_ids = prep_sample["input_ids"].tolist() + loss_mask = prep_sample["loss_mask"].tolist() log.info(f"SAMPLE #{idx}") log.info("HIGHLIGHTED TEXT (BLUE = trainable, GREY = masked)") @@ -42,8 +56,9 @@ def _visualize_sample(preprocessed, tokenizer, idx: int = 0): prev_state = None for i in range(len(input_ids)): - is_train = loss_mask[i].item() == 1 - token = tokenizer.decode([input_ids[i].item()]) + is_train = loss_mask[i] == 1 + token = processor.decode([input_ids[i]]) + assert isinstance(token, str) # Switch colors when state changes if is_train != prev_state: @@ -107,19 +122,103 @@ def _normalize_conversation( return normalized -def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: - """Check if tokenizer truly supports HF assistant token mask. 
+def _adapt_part_for_hf(part: str | dict, processor: ProcessorLike): + if isinstance(part, str) and isinstance(processor, ProcessorMixin): + return {"type": "text", "text": part} + + return part + + +def _adapt_turn_for_hf(turn: dict, processor: ProcessorLike): + if isinstance(turn["content"], str): + if isinstance(processor, ProcessorMixin): + return turn | {"content": [_adapt_part_for_hf(turn["content"], processor)]} + + return turn + + return turn | { + "content": [_adapt_part_for_hf(part, processor) for part in turn["content"]] + } + + +def _adapt_conv_for_hf(normalized_conv: list[dict], processor: ProcessorLike): + return [_adapt_turn_for_hf(turn, processor) for turn in normalized_conv] + + +def _adapt_part_for_vllm(part: str | dict): + if isinstance(part, str): + return {"type": "text", "text": part} + + part_type = part["type"] + + if part_type == "text": + return {"type": "text", "text": part["text"]} + + for modality in ("image", "video", "audio"): + if part_type == modality: + if local_path := part.get("path"): + file_url = f"file://{Path(local_path).absolute()}" + return {"type": f"{modality}_url", f"{modality}_url": {"url": file_url}} + if url := part.get("url"): + return {"type": f"{modality}_url", f"{modality}_url": {"url": url}} + + if part.get("base64"): + expr = {"type": modality, "base64": "..."} + raise ValueError( + f"Content part {expr} is not supported. To avoid copying " + f"the {modality} when saving the preprocessed dataset, " + f"please express {modality} inputs using file paths or URLs." + ) + if part.get(modality): + expr = {"type": modality, modality: "..."} + raise ValueError( + f"Content part {expr} is not supported. To avoid copying " + f"the {modality} when saving the preprocessed dataset, " + f"please express {modality} inputs using file paths or URLs." + ) + + expr = {"type": modality} | {k: "..." 
for k in part if k != "type"} + raise NotImplementedError(f"Unknown content part: {expr}") + + expr = dict.fromkeys(part.keys(), "...") + raise NotImplementedError(f"Unknown content part: {expr}") + + +def _adapt_turn_for_vllm(turn: dict): + if isinstance(turn["content"], str): + return turn + + return turn | {"content": [_adapt_part_for_vllm(part) for part in turn["content"]]} + + +def _adapt_conv_for_vllm(normalized_conv: list[dict]): + return [_adapt_turn_for_vllm(turn) for turn in normalized_conv] + + +def _supports_assistant_mask(processor: ProcessorLike) -> bool: + """Check if processor truly supports HF assistant token mask. Must return a non-zero mask for a conversation containing an assistant message. """ + # NOTE: Some models (e.g. Qwen3.5) require a user message in the conversation, + # even though this check only looks at the assistant turn + test_conv = _adapt_conv_for_hf( + [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + ], + processor, + ) + try: - res_any = tokenizer.apply_chat_template( - [{"role": "assistant", "content": "test"}], + res_any = processor.apply_chat_template( + test_conv, tokenize=True, return_assistant_tokens_mask=True, return_dict=True, ) - res = cast("dict[str, Any]", res_any) + res = cast("BatchEncoding | BatchFeature", res_any) + # Check both singular and plural key names mask = res.get("assistant_masks", res.get("assistant_mask")) if mask is None: @@ -127,24 +226,28 @@ def _supports_assistant_mask(tokenizer: PreTrainedTokenizerBase) -> bool: # Verify the mask is not all zeros return any(m == 1 for m in mask) - except (TypeError, ValueError, KeyError, AttributeError): + except (TypeError, ValueError, KeyError, AttributeError) as e: + log.warning(f"An error occurred when trying to return assistant mask: {e}") return False -def _detect_assistant_pattern(tokenizer: PreTrainedTokenizerBase) -> str: - """Auto-detect the assistant message pattern from the tokenizer's chat template. 
+def _detect_assistant_pattern(processor: ProcessorLike) -> str: + """Auto-detect the assistant message pattern from the processor's chat template. Uses multi-turn conversation but extracts pattern from the LAST assistant message only. """ - test_conv = [ - {"role": "user", "content": "USER_MSG_1"}, - {"role": "assistant", "content": "ASSISTANT_MSG_1"}, - {"role": "user", "content": "USER_MSG_2"}, - {"role": "assistant", "content": "ASSISTANT_MSG_2"}, - ] - - formatted = tokenizer.apply_chat_template( + test_conv = _adapt_conv_for_hf( + [ + {"role": "user", "content": "USER_MSG_1"}, + {"role": "assistant", "content": "ASSISTANT_MSG_1"}, + {"role": "user", "content": "USER_MSG_2"}, + {"role": "assistant", "content": "ASSISTANT_MSG_2"}, + ], + processor, + ) + + formatted = processor.apply_chat_template( test_conv, tokenize=False, add_generation_prompt=False ) assert isinstance(formatted, str), "Expected string from apply_chat_template" @@ -224,6 +327,10 @@ def _create_loss_mask_from_offsets( text: str, offsets: list[tuple[int, int]], assistant_pattern: str | Pattern[str], + *, + # For logging + conv_idx: int | None = None, + max_length: int | None = None, ) -> torch.Tensor: """Create loss mask by finding assistant response spans in formatted text.""" loss_mask = torch.zeros(len(offsets), dtype=torch.bool) @@ -250,14 +357,121 @@ def _create_loss_mask_from_offsets( loss_mask[idx] = 1 if matches_found == 0: - log.warning("No assistant response spans found in conversation") + warning_msg = "No assistant response spans found in conversation" + if conv_idx is not None: + warning_msg += f" {conv_idx}" + + suggestion_msg = "" + if max_length is not None and len(offsets) == max_length: + suggestion_msg += ( + "Consider increasing --seq-length to avoid truncating " + "the assistant response." + ) + + log.warning(f"{warning_msg}. 
{suggestion_msg}") return loss_mask +def _get_input_ids_loss_mask( + normalized_conv: list[dict], + processor: ProcessorLike, + max_length: int, + assistant_pattern: str | Pattern[str] | None, + *, + # For logging + conv_idx: int | None = None, +): + hf_conv = _adapt_conv_for_hf(normalized_conv, processor) + + if assistant_pattern is None: + # HF assistant token mask + encoded_any = processor.apply_chat_template( + hf_conv, + tokenize=True, + add_generation_prompt=False, + return_assistant_tokens_mask=True, + return_dict=True, + ) + encoded = cast("BatchEncoding | BatchFeature", encoded_any) + + # input IDs and loss mask + input_ids = encoded["input_ids"] + # HF uses 'assistant_masks' in recent versions + mask_key = ( + "assistant_masks" if "assistant_masks" in encoded else "assistant_mask" + ) + loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) + + return input_ids, loss_mask + + # Fallback: regex-based detection + assert assistant_pattern is not None, "Assistant pattern required for fallback" + + processor_kwargs: dict = { + "return_offsets_mapping": True, + "max_length": max_length, + "truncation": True, + "add_special_tokens": False, + } + + if isinstance(processor, ProcessorMixin): + if Version(TRANSFORMERS_VERSION) >= Version("5.4.0"): + encoded_any = processor.apply_chat_template( + hf_conv, + tokenize=True, + add_generation_prompt=False, + return_dict=True, + processor_kwargs=processor_kwargs, + ) + else: + encoded_any = processor.apply_chat_template( + hf_conv, + tokenize=True, + add_generation_prompt=False, + return_dict=True, + **processor_kwargs, + ) + + encoded = cast("BatchFeature", encoded_any) + + # Remove batch dimension + (input_ids,) = encoded["input_ids"] + (offsets,) = encoded["offset_mapping"] + + # MM placeholder tokens are inserted separate from chat template + formatted_text = processor.decode(input_ids) + assert isinstance(formatted_text, str) + else: + # More optimized flow for text-only processors (i.e. 
tokenizers) + formatted_text = processor.apply_chat_template( + hf_conv, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(formatted_text, str) + + # Tokenize and get offsets + encoded_any = processor(formatted_text, **processor_kwargs) + encoded = cast("BatchEncoding", encoded_any) + + input_ids = encoded["input_ids"] + offsets = encoded["offset_mapping"] + + loss_mask = _create_loss_mask_from_offsets( + formatted_text, + offsets, + assistant_pattern, + conv_idx=conv_idx, + max_length=max_length, + ) + + return input_ids, loss_mask + + def _preprocess_batch( examples: dict, - tokenizer: PreTrainedTokenizerBase, + processor: ProcessorLike, max_length: int, assistant_pattern: str | Pattern[str] | None, turn_dropout: bool = False, @@ -266,7 +480,11 @@ def _preprocess_batch( """Process a batch of conversations into tokenized format with loss masks.""" results: dict[str, list] = {"input_ids": [], "loss_mask": [], "seq_len": []} - conversations = examples.get("conversations", []) + conversations: list[dict] = examples.get("conversations", []) + + # MM inputs must use Chat Completions API + if isinstance(processor, ProcessorMixin): + results["messages"] = [] if not conversations: log.warning(f"No conversations key found. 
Keys: {list(examples.keys())}") @@ -282,73 +500,13 @@ def _preprocess_batch( continue try: - if assistant_pattern is None: - # HF assistant token mask - encoded_any = tokenizer.apply_chat_template( - normalized_conv, - tokenize=True, - add_generation_prompt=False, - return_assistant_tokens_mask=True, - return_dict=True, - ) - encoded = cast("dict[str, Any]", encoded_any) - - # input IDs and loss mask - input_ids = encoded["input_ids"] - # HF uses 'assistant_masks' in recent versions - mask_key = ( - "assistant_masks" - if "assistant_masks" in encoded - else "assistant_mask" - ) - loss_mask = torch.tensor(encoded[mask_key], dtype=torch.long) - - else: - # Fallback: regex-based detection - assert assistant_pattern is not None, ( - "Assistant pattern required for fallback" - ) - formatted_raw = tokenizer.apply_chat_template( - normalized_conv, - tokenize=False, - add_generation_prompt=False, - ) - assert isinstance(formatted_raw, str) - - # Tokenize and get offsets - encoding = tokenizer( - formatted_raw, - return_offsets_mapping=True, - max_length=max_length, - truncation=True, - add_special_tokens=False, - ) - - # input IDs and loss mask - input_ids = encoding["input_ids"] - offsets = encoding["offset_mapping"] - - loss_mask = _create_loss_mask_from_offsets( - formatted_raw, offsets, assistant_pattern - ) - - # Assert shapes match - assert len(input_ids) == len(loss_mask), ( - f"Shape mismatch: input_ids={len(input_ids)}, " - f"loss_mask={len(loss_mask)}" + input_ids, loss_mask = _get_input_ids_loss_mask( + normalized_conv, + processor, + max_length=max_length, + assistant_pattern=assistant_pattern, + conv_idx=idx, ) - - # Filtering samples out with too few valid tokens - if minimum_valid_tokens is not None: - num_valid_tokens = int(loss_mask.sum().item()) - if num_valid_tokens < minimum_valid_tokens: - continue - - # Append to results - results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) - results["loss_mask"].append(loss_mask) - 
results["seq_len"].append(len(input_ids)) - except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: log.error( f"Failed to process conversation {idx} " @@ -356,12 +514,31 @@ def _preprocess_batch( ) continue + # Assert shapes match + assert len(input_ids) == len(loss_mask), ( + f"Shape mismatch: input_ids={len(input_ids)}, loss_mask={len(loss_mask)}" + ) + + # Filtering samples out with too few valid tokens + if minimum_valid_tokens is not None: + num_valid_tokens = int(loss_mask.sum().item()) + if num_valid_tokens < minimum_valid_tokens: + continue + + # Append to results + results["input_ids"].append(torch.tensor(input_ids, dtype=torch.long)) + results["loss_mask"].append(loss_mask) + results["seq_len"].append(len(input_ids)) + + if "messages" in results: + results["messages"].append(_adapt_conv_for_vllm(normalized_conv)) + return results def build_eagle3_dataset( dataset: HFDataset, - tokenizer: PreTrainedTokenizerBase, + processor: ProcessorLike, max_length: int = 2048, num_proc: int = 8, assistant_pattern: str | Pattern[str] | None = None, @@ -370,11 +547,11 @@ def build_eagle3_dataset( ) -> HFDataset: """Build EAGLE3 dataset by tokenizing conversations and creating loss masks. - Uses the tokenizer's built-in chat template via apply_chat_template. + Uses the processor's built-in chat template via apply_chat_template. 
Args: dataset: Raw dataset with conversations - tokenizer: Tokenizer with chat template support + processor: Processor with chat template support max_length: Maximum sequence length num_proc: Number of processes for parallel processing assistant_pattern: Optional custom regex pattern for matching assistant @@ -387,39 +564,48 @@ def build_eagle3_dataset( # Detect and use provided assistant message pattern if assistant_pattern is not None: log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...") - elif _supports_assistant_mask(tokenizer): + elif _supports_assistant_mask(processor): assistant_pattern = None # Signal to use HF mask in _preprocess_batch log.info("Using HF assistant token mask for loss masking") else: - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...") original_cols = dataset.column_names - dataset = dataset.map( - lambda examples: _preprocess_batch( - examples, - tokenizer, - max_length, - assistant_pattern, - turn_dropout, - minimum_valid_tokens, - ), - batched=True, - num_proc=num_proc, - batch_size=1000, - remove_columns=original_cols, - keep_in_memory=True, # skip caching - ) + # Avoid CPU contention for MM processing: + # https://github.com/vllm-project/vllm/pull/31879 + with ( + set_default_torch_num_threads() + if isinstance(processor, ProcessorMixin) + else nullcontext() + ): + dataset = dataset.map( + lambda examples: _preprocess_batch( + examples, + processor, + max_length, + assistant_pattern, + turn_dropout, + minimum_valid_tokens, + ), + batched=True, + num_proc=num_proc, + batch_size=1000, + remove_columns=original_cols, + keep_in_memory=True, # skip caching + ) dataset.set_format(type="torch") return dataset -def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: +def load_raw_dataset( + train_data_path: str, +) -> tuple[HFDataset, Callable[[dict], dict] | 
None]: """Load raw dataset from local file or HuggingFace.""" if train_data_path.endswith((".jsonl", ".json")): - return load_dataset("json", data_files=train_data_path, split="train") + return load_dataset("json", data_files=train_data_path, split="train"), None if train_data_path not in DATASET_CONFIGS: raise ValueError( @@ -430,15 +616,39 @@ def load_raw_dataset(train_data_path: str, num_proc: int = 8) -> HFDataset: config = DATASET_CONFIGS[train_data_path] raw_dataset = load_dataset(config.hf_path, name=config.subset, split=config.split) - if config.normalize_fn is not None: - raw_dataset = raw_dataset.map(config.normalize_fn, num_proc=num_proc) + if config.filter_fn is not None: + raw_dataset = raw_dataset.filter(config.filter_fn) + + return raw_dataset, config.normalize_fn - return raw_dataset + +def get_tokenizer(processor: ProcessorLike): + if isinstance(processor, ProcessorMixin): + return processor.tokenizer # type: ignore[attr-defined] + + return processor + + +def _resolve_pad_token(processor: ProcessorLike): + tokenizer = get_tokenizer(processor) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + +def load_processor(target_model_path: str, *, trust_remote_code: bool = False): + processor = AutoProcessor.from_pretrained( + target_model_path, + trust_remote_code=trust_remote_code, + ) + _resolve_pad_token(processor) + + return processor def load_and_preprocess_dataset( target_model_path: str, train_data_paths: list[str], + *, seq_length: int, build_dataset_num_proc: int = 8, seed: int = 0, @@ -447,10 +657,11 @@ def load_and_preprocess_dataset( assistant_pattern: str | None = None, turn_dropout: bool = False, minimum_valid_tokens: int | None = None, -) -> tuple[HFDataset, PreTrainedTokenizerBase]: + trust_remote_code: bool = False, +) -> tuple[HFDataset, ProcessorLike]: """Load, tokenize, and preprocess a dataset for EAGLE3 training. - Uses the tokenizer's built-in chat template via apply_chat_template. 
+ Uses the processor's built-in chat template via apply_chat_template. Caching is handled automatically by HuggingFace datasets. Args: @@ -468,9 +679,10 @@ def load_and_preprocess_dataset( turn_dropout: If True, randomly keeps first N consecutive turns per conversation minimum_valid_tokens: Number of tokens to consider for a valid sample + trust_remote_code: If True, allows executing code from HF Hub. Returns: - Tuple of (preprocessed_dataset, tokenizer) + Tuple of (preprocessed_dataset, processor) """ if minimum_valid_tokens is not None and minimum_valid_tokens < 0: raise ValueError("minimum_valid_tokens must be >= 0") @@ -480,21 +692,19 @@ def load_and_preprocess_dataset( f"Filtering samples with fewer than {minimum_valid_tokens} valid tokens" ) - log.subsection("Loading tokenizer") - tokenizer = AutoTokenizer.from_pretrained(target_model_path, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + log.subsection("Loading processor") + processor = load_processor(target_model_path, trust_remote_code=trust_remote_code) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: raise ValueError( - f"Tokenizer for {target_model_path} does not support chat templates. " + f"Processor for {target_model_path} does not support chat templates. " "Please use a model with a pre-configured chat template." 
) processed_datasets = [] for train_data_path in train_data_paths: log.subsection(f"Processing {train_data_path}") - raw_dataset = load_raw_dataset(train_data_path, num_proc=build_dataset_num_proc) + raw_dataset, normalize_fn = load_raw_dataset(train_data_path) raw_dataset = raw_dataset.shuffle(seed=seed) if max_samples is not None and len(raw_dataset) > 3 * max_samples: @@ -503,6 +713,13 @@ def load_and_preprocess_dataset( # after combining datasets and shuffling raw_dataset = raw_dataset.select(range(3 * max_samples)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map( + normalize_fn, + num_proc=build_dataset_num_proc, + keep_in_memory=True, # skip caching + ) + log.info(f"Loaded {len(raw_dataset)} samples") if turn_dropout: @@ -510,7 +727,7 @@ def load_and_preprocess_dataset( preprocessed_dataset = build_eagle3_dataset( dataset=raw_dataset, - tokenizer=tokenizer, + processor=processor, max_length=seq_length, num_proc=build_dataset_num_proc, assistant_pattern=assistant_pattern, @@ -533,8 +750,8 @@ def load_and_preprocess_dataset( ) log.subsection("Visualizing sample") - _visualize_sample(combined_dataset, tokenizer, idx=0) + _visualize_sample(combined_dataset, processor, idx=0) log.section("Dataset preprocessing complete") - return combined_dataset, tokenizer + return combined_dataset, processor diff --git a/src/speculators/data_generation/torch_utils.py b/src/speculators/data_generation/torch_utils.py new file mode 100644 index 000000000..6c2db5af2 --- /dev/null +++ b/src/speculators/data_generation/torch_utils.py @@ -0,0 +1,42 @@ +import contextlib +import os + +import torch + +from speculators.data_generation.logging_utils import PipelineLogger + +log = PipelineLogger(__name__) + + +# Based on vLLM's util with the same name +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int | None = None): + """ + Sets the default number of threads for PyTorch to the given value. 
+ + `None` means using the value of the environment variable `OMP_NUM_THREADS` + (or `1` if that is not available). + """ + if num_threads is None: + num_threads = 1 + + try: + num_threads = int(os.environ["OMP_NUM_THREADS"]) + except KeyError: + log.debug( + f"OMP_NUM_THREADS is not set; " + f"defaulting Torch threads to {num_threads}.", + ) + except ValueError: + log.warning( + f"OMP_NUM_THREADS is invalid; " + f"defaulting Torch threads to {num_threads}.", + ) + + old_num_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + + try: + yield + finally: + torch.set_num_threads(old_num_threads) diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index aa2cfc385..2b1a1f97b 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -2,8 +2,15 @@ import functools import logging import time +from typing import TYPE_CHECKING, Any, TypedDict import openai +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai.types.completion import Completion +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from collections.abc import Coroutine logger = logging.getLogger(__name__) @@ -82,8 +89,14 @@ def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs): return sync_wrapper -def extract_output(completion, token_ids) -> str: - prompt_token_ids = getattr(completion.choices[0], "prompt_token_ids", None) +def extract_output( + response: Completion | ChatCompletion, + token_ids: list[int], +) -> str: + if isinstance(response, Completion): + prompt_token_ids = getattr(response.choices[0], "prompt_token_ids", None) + else: + prompt_token_ids = getattr(response, "prompt_token_ids", None) if prompt_token_ids is None: raise InvalidResponseError("Response missing prompt_token_ids") @@ -93,17 +106,28 @@ def extract_output(completion, token_ids) -> str: f"Prompt token IDs mismatch: expected {token_ids}, got 
{prompt_token_ids}" ) - if not hasattr(completion, "kv_transfer_params"): + kv_transfer_params = getattr(response, "kv_transfer_params", None) + if kv_transfer_params is None: raise InvalidResponseError("Response missing kv_transfer_params") - return completion.kv_transfer_params.get("hidden_states_path") + return kv_transfer_params.get("hidden_states_path") + + +class ClientItem(TypedDict): + input_ids: list[int] + """The input token IDs.""" + + messages: NotRequired[list[ChatCompletionMessageParam]] + """If provided, pass `messages` to Chat Completions API + instead of passing `token_ids` to Completions API.""" @with_retries async def generate_hidden_states_async( client: openai.AsyncClient, model: str, - token_ids: list[int], + client_item: ClientItem, + *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ @@ -113,40 +137,70 @@ async def generate_hidden_states_async( Args: client: The async OpenAI client. model: The model ID. - token_ids: The input token IDs. + client_item: Inputs to send via the client. timeout: Timeout in seconds for each request attempt. None for no timeout. 
""" - coro = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, - ) + token_ids = client_item["input_ids"] + messages = client_item.get("messages") + + coro: Coroutine[Any, Any, Completion | ChatCompletion] + if messages is None: + coro = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + else: + coro = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=1, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, + timeout=timeout, + ) + + res: Completion | ChatCompletion if timeout is not None: - completion = await asyncio.wait_for(coro, timeout=timeout) + res = await asyncio.wait_for(coro, timeout=timeout) else: - completion = await coro + res = await coro - return extract_output(completion, token_ids) + return extract_output(res, token_ids) @with_retries def generate_hidden_states( client: openai.Client, model: str, - token_ids: list[int], + client_item: ClientItem, + *, timeout: float | None = DEFAULT_REQUEST_TIMEOUT, ) -> str: """ Runs decode w/ max_tokens 1 to generate hidden states and returns path to hidden states file. 
""" - completion = client.completions.create( - model=model, - prompt=token_ids, - max_tokens=1, - extra_body={"return_token_ids": True}, - timeout=timeout, - ) - return extract_output(completion, token_ids) + token_ids = client_item["input_ids"] + messages = client_item.get("messages") + + res: Completion | ChatCompletion + if messages is None: + res = client.completions.create( + model=model, + prompt=token_ids, + max_tokens=1, + extra_body={"return_token_ids": True}, + timeout=timeout, + ) + else: + res = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=1, + extra_body={"add_generation_prompt": False, "return_token_ids": True}, + timeout=timeout, + ) + + return extract_output(res, token_ids) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b1e05c91e..fc4a4af9a 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -7,7 +7,7 @@ from collections.abc import Callable from os import PathLike from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import openai import torch @@ -19,6 +19,7 @@ from speculators.data_generation.vllm_client import ( DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT, + ClientItem, generate_hidden_states, ) from speculators.train.noise_transforms import TransformTensors @@ -105,6 +106,16 @@ def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]: } +def build_client_item(dataset_item: dict) -> ClientItem: + out_dict = {} + out_dict["input_ids"] = dataset_item["input_ids"].tolist() + + if "messages" in dataset_item: + out_dict["messages"] = dataset_item["messages"] + + return cast("ClientItem", out_dict) + + class BaseDataset(Dataset): def __init__( self, @@ -233,7 +244,7 @@ def _setup_client(self): if self.model and self.model != model_id: raise ValueError( f"An explicit model name was passed ({self.model}) which doesn't match" - "found model_id {model_id}." + f" found model_id {model_id}." 
"Please make sure --endpoint is set to the correct vllm instance." ) self.model = model_id @@ -257,12 +268,14 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: if not self.client: self._setup_client() - input_ids = self.data[index]["input_ids"].tolist() + dataset_item = self.data[index] + client_item = build_client_item(dataset_item) + try: hs_filepath = generate_hidden_states( self.client, # type:ignore[arg-type] self.model, # type:ignore[arg-type] - input_ids, + client_item, timeout=self.request_timeout, max_retries=self.max_retries, ) diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py index 82872796f..5c2a87f66 100644 --- a/src/speculators/train/utils.py +++ b/src/speculators/train/utils.py @@ -5,7 +5,8 @@ import torch import torch.distributed as dist from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard -from transformers import AutoTokenizer + +from speculators.data_generation.preprocessing import get_tokenizer, load_processor local_rank = int(os.environ.get("LOCAL_RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -60,6 +61,8 @@ def resolve_mask_token_id( verifier_name_or_path: str, vocab_size: int, mask_token_id: int | None = None, + *, + trust_remote_code: bool = False, ) -> int: """Resolve mask_token_id from explicit value, tokenizer, or fallback. 
@@ -73,7 +76,11 @@ def resolve_mask_token_id( logger.info(f"Using explicit mask_token_id={mask_token_id}") return mask_token_id - tokenizer = AutoTokenizer.from_pretrained(verifier_name_or_path) + processor = load_processor( + verifier_name_or_path, + trust_remote_code=trust_remote_code, + ) + tokenizer = get_tokenizer(processor) if tokenizer.mask_token_id is not None: logger.info(f"Using tokenizer mask_token_id={tokenizer.mask_token_id}") diff --git a/tests/e2e/regression/test_eagle3_conversion_acceptance.py b/tests/e2e/regression/test_eagle3_conversion_acceptance.py index 59707fdbf..6f895e9a1 100644 --- a/tests/e2e/regression/test_eagle3_conversion_acceptance.py +++ b/tests/e2e/regression/test_eagle3_conversion_acceptance.py @@ -69,6 +69,7 @@ def test_convert_run_vllm_engine_eagle3( run_vllm_engine( model_path=str(converted_path), tmp_path=tmp_path, + enforce_eager=False, disable_compile_cache=disable_compile_cache, prompts=prompts, acceptance_thresholds=acceptance_thresholds, diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index 598e35007..f12865d4f 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -11,20 +11,64 @@ from pathlib import Path +import pytest + +from speculators.data_generation.configs import get_coco_dir +from speculators.data_generation.preprocessing import ( + _adapt_conv_for_vllm, + _normalize_conversation, + load_raw_dataset, +) from tests.e2e.smoke.test_offline_training import run_offline_e2e from tests.utils import requires_cadence @requires_cadence("nightly") -def test_offline_regression(tmp_path: Path, prompts): +@pytest.mark.parametrize( + ("model", "dataset", "acceptance_thresholds"), + [ + ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.1, 0.01]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.2, 0.04]), + ], +) +def test_offline_regression( + tmp_path: Path, + model: str, + 
dataset: str, + acceptance_thresholds: list[float], + prompts: list[list[dict[str, str]]], +): + if dataset == "sharegpt4v_coco": + coco_dir = get_coco_dir() + + if not Path(coco_dir).exists(): + pytest.skip(f"Cannot find COCO dataset at {coco_dir}") + + vllm_media_path = coco_dir + + raw_dataset, normalize_fn = load_raw_dataset(dataset) + raw_dataset = raw_dataset.skip(len(raw_dataset) - len(prompts)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) + + raw_convs = raw_dataset["conversations"] + normalized_convs = [_normalize_conversation(conv) for conv in raw_convs] + prompts = [_adapt_conv_for_vllm(conv) for conv in normalized_convs] + else: + vllm_media_path = None + run_offline_e2e( tmp_path, - "Qwen/Qwen3-8B", + model, + dataset, max_samples=5000, seq_length=8192, - vllm_gpu_util=0.9, + vllm_kwargs={ + "gpu_memory_utilization": 0.9, + "allowed_local_media_path": vllm_media_path, + }, epochs=3, prompts=prompts, - acceptance_thresholds=[0.4, 0.07, 0.007], + acceptance_thresholds=acceptance_thresholds, log_freq=50, ) diff --git a/tests/e2e/regression/test_eagle3_online_acceptance.py b/tests/e2e/regression/test_eagle3_online_acceptance.py index 8f2faecff..044e2d87d 100644 --- a/tests/e2e/regression/test_eagle3_online_acceptance.py +++ b/tests/e2e/regression/test_eagle3_online_acceptance.py @@ -10,21 +10,65 @@ from pathlib import Path +import pytest + +from speculators.data_generation.configs import get_coco_dir +from speculators.data_generation.preprocessing import ( + _adapt_conv_for_vllm, + _normalize_conversation, + load_raw_dataset, +) from tests.e2e.smoke.test_online_training import run_online_e2e from tests.utils import requires_cadence @requires_cadence("nightly") -def test_online_regression(tmp_path: Path, prompts): +@pytest.mark.parametrize( + ("model", "dataset", "acceptance_thresholds"), + [ + ("Qwen/Qwen3-8B", "sharegpt", [0.4, 0.1, 0.01]), + ("Qwen/Qwen3-VL-2B-Instruct", "sharegpt4v_coco", [0.4, 0.2, 
0.04]), + ], +) +def test_online_regression( + tmp_path: Path, + model: str, + dataset: str, + acceptance_thresholds: list[float], + prompts: list[list[dict[str, str]]], +): + if dataset == "sharegpt4v_coco": + coco_dir = get_coco_dir() + + if not Path(coco_dir).exists(): + pytest.skip(f"Cannot find COCO dataset at {coco_dir}") + + vllm_media_path = coco_dir + + raw_dataset, normalize_fn = load_raw_dataset(dataset) + raw_dataset = raw_dataset.skip(len(raw_dataset) - len(prompts)) + if normalize_fn is not None: + raw_dataset = raw_dataset.map(normalize_fn, keep_in_memory=True) + + raw_convs = raw_dataset["conversations"] + normalized_convs = [_normalize_conversation(conv) for conv in raw_convs] + prompts = [_adapt_conv_for_vllm(conv) for conv in normalized_convs] + else: + vllm_media_path = None + run_online_e2e( tmp_path, - "Qwen/Qwen3-8B", + model, + dataset, max_samples=5000, seq_length=8192, - vllm_gpu_util=0.75, + vllm_kwargs={ + "gpu_memory_utilization": 0.75, + "allowed_local_media_path": vllm_media_path, + }, epochs=3, prompts=prompts, - acceptance_thresholds=[0.4, 0.1, 0.01], + acceptance_thresholds=acceptance_thresholds, log_freq=50, train_timeout=45 * 60, # 45 mins ) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index aa99b4aa9..0e38f7140 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -10,6 +10,7 @@ """ from pathlib import Path +from typing import Any import pytest @@ -19,23 +20,30 @@ run_prepare_data, run_training, run_vllm_engine, + setup_dummy_sharegpt4v_coco, ) -MODEL = "Qwen/Qwen3-0.6B" +TEXT_MODEL = "Qwen/Qwen3-0.6B" +MM_MODEL = "Qwen/Qwen3-VL-2B-Instruct" @pytest.mark.e2e @pytest.mark.slow @pytest.mark.parametrize( - ("speculator_type", "extra_train_args", "target_layer_ids"), + ("model", "dataset", "speculator_type", "extra_train_args", "target_layer_ids"), [ - ("eagle3", [], None), # Use default EAGLE layers + (TEXT_MODEL, "sharegpt", 
"eagle3", [], None), # Use default EAGLE layers + (MM_MODEL, "sharegpt4v_coco", "eagle3", [], None), # Multimodal ( + TEXT_MODEL, + "sharegpt", "dflash", ["--block-size", "8", "--max-anchors", "256", "--num-layers", "3"], [1, 13, 25], ), # DFlash with 3 layers + verifier last layer ( + TEXT_MODEL, + "sharegpt", "peagle", [ "--num-layers", @@ -53,17 +61,35 @@ ], ) def test_offline_smoke( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, + model: str, + dataset: str, prompts: list[list[dict[str, str]]], speculator_type: str, extra_train_args: list[str], target_layer_ids: list[int] | None, ): + if dataset == "sharegpt4v_coco": + coco_dir = tmp_path / "coco" + + monkeypatch.setenv("COCO_DIR", str(coco_dir)) + setup_dummy_sharegpt4v_coco(coco_dir) + + vllm_media_path = str(coco_dir) + else: + vllm_media_path = None + run_offline_e2e( tmp_path, - MODEL, + model, + dataset=dataset, prompts=prompts, - vllm_gpu_util=0.9, + vllm_kwargs={ + "enforce_eager": True, + "gpu_memory_utilization": 0.9, + "allowed_local_media_path": vllm_media_path, + }, speculator_type=speculator_type, extra_train_args=extra_train_args, target_layer_ids=target_layer_ids, @@ -73,9 +99,10 @@ def test_offline_smoke( def run_offline_e2e( tmp_path: Path, model: str, + dataset: str, max_samples: int = 50, seq_length: int = 512, - vllm_gpu_util: float = 0.5, + vllm_kwargs: dict[str, Any] | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -97,15 +124,15 @@ def run_offline_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, data_path, max_samples, seq_length) + run_prepare_data(model, dataset, data_path, max_samples, seq_length) with launch_vllm_server_context( model, port, str(tmp_path / "vllm_hidden_states"), max_model_len=seq_length + 1, - gpu_memory_utilization=vllm_gpu_util, target_layer_ids=target_layer_ids, + **(vllm_kwargs or {}), ): # Step 2: Generate hidden states offline run_data_generation_offline( @@ -146,4 +173,5 @@ def 
run_offline_e2e( max_tokens=max_tokens, ignore_eos=ignore_eos, acceptance_thresholds=acceptance_thresholds, + **(vllm_kwargs or {}), ) diff --git a/tests/e2e/smoke/test_online_training.py b/tests/e2e/smoke/test_online_training.py index 7224784d8..1c927f05a 100644 --- a/tests/e2e/smoke/test_online_training.py +++ b/tests/e2e/smoke/test_online_training.py @@ -9,6 +9,7 @@ """ from pathlib import Path +from typing import Any import pytest @@ -17,23 +18,58 @@ run_prepare_data, run_training, run_vllm_engine, + setup_dummy_sharegpt4v_coco, ) -MODEL = "Qwen/Qwen3-0.6B" +TEXT_MODEL = "Qwen/Qwen3-0.6B" +MM_MODEL = "Qwen/Qwen3-VL-2B-Instruct" @pytest.mark.e2e @pytest.mark.slow -def test_online_smoke(tmp_path: Path, prompts: list[list[dict[str, str]]]): - run_online_e2e(tmp_path, MODEL, prompts=prompts) +@pytest.mark.parametrize( + ("model", "dataset"), + [ + (TEXT_MODEL, "sharegpt"), + (MM_MODEL, "sharegpt4v_coco"), + ], +) +def test_online_smoke( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + model: str, + dataset: str, + prompts: list[list[dict[str, str]]], +): + if dataset == "sharegpt4v_coco": + coco_dir = tmp_path / "coco" + + monkeypatch.setenv("COCO_DIR", str(coco_dir)) + setup_dummy_sharegpt4v_coco(coco_dir) + + vllm_media_path = str(coco_dir) + else: + vllm_media_path = None + + run_online_e2e( + tmp_path, + model, + dataset=dataset, + prompts=prompts, + vllm_kwargs={ + "enforce_eager": True, + "allowed_local_media_path": vllm_media_path, + }, + ) def run_online_e2e( tmp_path: Path, model: str, + dataset: str, max_samples: int = 50, seq_length: int = 512, - vllm_gpu_util: float = 0.5, + vllm_kwargs: dict[str, Any] | None = None, port: int = 8321, draft_vocab_size: int = 8192, epochs: int = 1, @@ -55,7 +91,7 @@ def run_online_e2e( save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(model, data_path, max_samples, seq_length) + run_prepare_data(model, dataset, data_path, max_samples, seq_length) hidden_states_path = str(tmp_path / 
"hidden_states") with launch_vllm_server_context( @@ -63,7 +99,7 @@ def run_online_e2e( port, hidden_states_path, max_model_len=seq_length + 1, - gpu_memory_utilization=vllm_gpu_util, + **(vllm_kwargs or {}), ): # Step 2: Train against live vLLM server run_training( @@ -90,4 +126,5 @@ def run_online_e2e( max_tokens=max_tokens, ignore_eos=ignore_eos, acceptance_thresholds=acceptance_thresholds, + **(vllm_kwargs or {}), ) diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index c887d89e9..b793e2cd8 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -82,10 +82,15 @@ def test_resume_after_checkpoint_best(tmp_path: Path): save_path = tmp_path / "checkpoints" # Step 1: Prepare data - run_prepare_data(MODEL, data_path) + run_prepare_data(MODEL, "sharegpt", data_path) # Step 2: Generate hidden states offline - with launch_vllm_server_context(MODEL, VLLM_PORT, str(tmp_path / "hidden_states")): + with launch_vllm_server_context( + MODEL, + VLLM_PORT, + str(tmp_path / "hidden_states"), + enforce_eager=True, + ): run_data_generation_offline(data_path, hidden_states_path, port=VLLM_PORT) # Step 3: Train 1 epoch with --save-best diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 1e9053924..ee4e90ac9 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -11,6 +11,9 @@ from textwrap import indent from loguru import logger +from PIL import Image + +from speculators.data_generation.preprocessing import load_raw_dataset __all__ = [ "SCRIPTS_DIR", @@ -65,9 +68,12 @@ def launch_vllm_server( model: str, port: int, hidden_states_path: str, + *, max_model_len: int = 513, gpu_memory_utilization: float = 0.5, target_layer_ids: list[int] | None = None, + enforce_eager: bool = False, + allowed_local_media_path: str | None = None, ) -> subprocess.Popen: """Launch a vLLM server configured for hidden-state extraction. 
@@ -83,6 +89,10 @@ def launch_vllm_server( ] if target_layer_ids is not None: cmd += ["--target-layer-ids"] + [str(lid) for lid in target_layer_ids] + if enforce_eager: + cmd += ["--enforce-eager"] + if allowed_local_media_path is not None: + cmd += ["--allowed-local-media-path", allowed_local_media_path] cmd += [ "--", "--port", @@ -135,21 +145,41 @@ def launch_vllm_server_context(*args, **kwargs): stop_vllm_server(process) +def setup_dummy_sharegpt4v_coco(coco_dir: Path): + """Enable ShareGPT4V to be used without downloading the actual COCO dataset.""" + coco_dir.mkdir(parents=True, exist_ok=True) + + dummy_image = Image.new("RGB", (256, 256)) + dummy_image_path = coco_dir / "dummy.png" + dummy_image.save(dummy_image_path) + + raw_dataset, normalize_fn = load_raw_dataset("sharegpt4v_coco") + + # Use symlinks to avoid copying the image + for raw_path in raw_dataset["image"]: + image_path = coco_dir / raw_path.removeprefix("coco/") + + if not image_path.exists(): + image_path.parent.mkdir(parents=True, exist_ok=True) + image_path.symlink_to(dummy_image_path) + + def run_prepare_data( model: str, + data: str, data_path: Path, max_samples: int = 50, seq_length: int = 512, timeout: float | None = None, ): - """Tokenize ShareGPT data using prepare_data.py.""" + """Tokenize data using prepare_data.py.""" cmd = [ sys.executable, str(SCRIPTS_DIR / "prepare_data.py"), "--model", model, "--data", - "sharegpt", + data, "--output", str(data_path), "--max-samples", @@ -277,6 +307,10 @@ def run_vllm_engine( model_path: str, tmp_path: Path, prompts: list[list[dict[str, str]]], + max_model_len: int = 1024, + gpu_memory_utilization: float = 0.8, + enforce_eager: bool = False, + allowed_local_media_path: str | None = None, disable_compile_cache: bool = False, max_tokens: int = 50, ignore_eos: bool = True, @@ -289,27 +323,29 @@ def run_vllm_engine( run_vllm_file = str(Path(__file__).with_name("run_vllm.py")) results_file = str(tmp_path / "results.json") + sampling_params_dict = { + 
"temperature": 0, + "top_p": 0.9, + "max_tokens": max_tokens, + "ignore_eos": ignore_eos, + } + + llm_args_dict = { + "model": model_path, + "max_model_len": max_model_len, + "gpu_memory_utilization": gpu_memory_utilization, + "enforce_eager": enforce_eager, + } + if allowed_local_media_path is not None: + llm_args_dict["allowed_local_media_path"] = allowed_local_media_path + command = [ VLLM_PYTHON, run_vllm_file, "--sampling-params-args", - json.dumps( - { - "temperature": 0, - "top_p": 0.9, - "max_tokens": max_tokens, - "ignore_eos": ignore_eos, - } - ), + json.dumps(sampling_params_dict), "--llm-args", - json.dumps( - { - "model": model_path, - "max_model_len": 1024, - "gpu_memory_utilization": 0.8, - "enforce_eager": True, - } - ), + json.dumps(llm_args_dict), "--prompts", json.dumps(prompts), "--results-file", diff --git a/tests/integration/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py index a9184c18f..44f385628 100644 --- a/tests/integration/datagen/test_preprocessing.py +++ b/tests/integration/datagen/test_preprocessing.py @@ -7,21 +7,26 @@ import pytest import torch from datasets import Dataset as HFDataset -from transformers import AutoTokenizer +from PIL import Image from speculators.data_generation.preprocessing import ( + _adapt_conv_for_hf, + _adapt_conv_for_vllm, _create_loss_mask_from_offsets, _detect_assistant_pattern, _normalize_conversation, _preprocess_batch, _supports_assistant_mask, build_eagle3_dataset, + load_processor, ) # Test model from HuggingFace with chat template # Using Qwen2-0.5B-Instruct: small (0.5B params), fast model with proper # chat template support -TEST_MODEL_REPO = "Qwen/Qwen2-0.5B-Instruct" +TEXT_MODEL_REPO = "Qwen/Qwen2-0.5B-Instruct" +# For testing multi-modal support +MM_MODEL_REPO = "Qwen/Qwen3.5-0.8B" # Tests for _normalize_conversation @@ -77,16 +82,153 @@ def test_normalize_conversation_unknown_role(): assert result[1]["role"] == "assistant" +# Tests for _adapt_conv_for_hf 
+@pytest.mark.sanity +def test_adapt_conv_for_hf_text_only_processor(): + """ + Test converting from normalized conversation to HF format + with a text-only processor (i.e. tokenizer). + """ + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) + + conv: list[dict] = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": ["Hello"]}, + {"role": "assistant", "content": "Hi!"}, + ] + result = _adapt_conv_for_hf(conv, processor) + + assert result == conv + + +@pytest.mark.sanity +def test_adapt_conv_for_hf_multimodal_processor(): + """ + Test converting from normalized conversation to HF format + with a multi-modal processor. + """ + processor = load_processor(MM_MODEL_REPO, trust_remote_code=True) + + conv: list[dict] = [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": ["Hello", {"type": "image", "path": "/path/to/img"}], + }, + { + "role": "assistant", + "content": "Hi!", + }, + ] + result = _adapt_conv_for_hf(conv, processor) + + assert result == [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "path": "/path/to/img"}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "Hi!"}], + }, + ] + + +# Tests for _adapt_conv_for_vllm +@pytest.mark.sanity +def test_adapt_conv_for_vllm_all_content_formats(): + """ + Test converting from normalized conversation to vLLM format + with each supported content format. 
+ """ + conv: list[dict] = [ + { + "role": "system", + "content": "You are a helpful assistant.", # Content as string + }, + { + "role": "assistant", + "content": [ # Content as list + "Hello,", # Text as string + {"type": "text", "text": "I am"}, # Text as dictionary + {"type": "image", "path": "/path/to/img"}, # Image file path + {"type": "image", "url": "http://path/to/img"}, # Image URL + ], + }, + ] + result = _adapt_conv_for_vllm(conv) + + assert len(result) == 2 + + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are a helpful assistant." + + assert result[1]["role"] == "assistant" + assert result[1]["content"] == [ + {"type": "text", "text": "Hello,"}, + {"type": "text", "text": "I am"}, + { + "type": "image_url", + "image_url": {"url": "file:///path/to/img"}, + }, + { + "type": "image_url", + "image_url": {"url": "http://path/to/img"}, + }, + ] + + +@pytest.mark.sanity +def test_adapt_conv_for_vllm_invalid_content_formats(): + """ + Test converting from normalized conversation to vLLM format + with unsupported content formats. 
+ """ + with pytest.raises(ValueError, match=r"'image':.* is not supported"): + _adapt_conv_for_vllm( + [ + { + "role": "assistant", + "content": [ + {"type": "image", "image": Image.new("RGB", (256, 256))}, + ], + }, + ] + ) + + with pytest.raises(ValueError, match=r"'base64':.* is not supported"): + _adapt_conv_for_vllm( + [ + { + "role": "assistant", + "content": [ + {"type": "image", "base64": "abcdef"}, + ], + }, + ] + ) + + # Tests for _detect_assistant_pattern @pytest.mark.sanity def test_detect_assistant_pattern_structure(): """Test that the detected pattern has the correct regex structure.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) # Pattern should be a valid regex string assert isinstance(pattern, str) @@ -105,20 +247,20 @@ def test_detect_assistant_pattern_structure(): @pytest.mark.sanity def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): """Test that pattern correctly distinguishes assistant from user content.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") # Get the pattern - pattern = _detect_assistant_pattern(tokenizer) + pattern = 
_detect_assistant_pattern(processor) # Format a conversation manually to test the pattern test_conv = [ {"role": "user", "content": "USER_MSG"}, {"role": "assistant", "content": "ASSISTANT_MSG"}, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] test_conv, tokenize=False, add_generation_prompt=False ) @@ -141,12 +283,12 @@ def test_detect_assistant_pattern_correctly_identifies_assistant_vs_user(): @pytest.mark.sanity def test_detect_assistant_pattern_extracts_correct_content(): """Test that the pattern's capture group extracts only assistant message content.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) # Test with a multi-turn conversation test_conv = [ @@ -156,7 +298,7 @@ def test_detect_assistant_pattern_extracts_correct_content(): {"role": "assistant", "content": "Second answer"}, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore [assignment] + formatted: str = processor.apply_chat_template( # type: ignore [assignment] test_conv, tokenize=False, add_generation_prompt=False ) @@ -261,13 +403,10 @@ def test_create_loss_mask_empty_offsets(): @pytest.mark.sanity def test_preprocess_batch_basic(): """Test preprocessing a basic batch of conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does 
not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples = { "conversations": [ @@ -282,9 +421,91 @@ def test_preprocess_batch_basic(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern + ) + + assert "input_ids" in results + assert "loss_mask" in results + assert len(results["input_ids"]) == 2 + assert len(results["loss_mask"]) == 2 + + # Check that input_ids and loss_mask have same length for each example + for i in range(2): + assert len(results["input_ids"][i]) == len(results["loss_mask"][i]) + assert isinstance(results["input_ids"][i], torch.Tensor) + assert isinstance(results["loss_mask"][i], torch.Tensor) + + +@pytest.mark.sanity +def test_preprocess_batch_multimodal(tmp_path): + """Test preprocessing a batch of multimodal conversations.""" + processor = load_processor(MM_MODEL_REPO, trust_remote_code=True) + + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") + + img_path = str(tmp_path / "blank.png") + Image.new("RGB", (256, 256)).save(img_path) + + examples = { + "conversations": [ + [ + { + "role": "user", + "content": "Hello, how are you?", + }, + { + "role": "assistant", + "content": "I am a helpful assistant.", + }, + { + "role": "user", + "content": "What is the capital of France?", + }, + { + "role": "assistant", + "content": "The capital of France is Paris.", + }, + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": 
"What is the difference between these two images?", + }, + {"type": "image", "path": img_path}, + {"type": "image", "path": img_path}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "They are the exact same image."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Why?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "They are both blank."}, + ], + }, + ], + ] + } + + assistant_pattern = _detect_assistant_pattern(processor) + results = _preprocess_batch( + examples, processor, max_length=2048, assistant_pattern=assistant_pattern ) assert "input_ids" in results @@ -302,15 +523,15 @@ def test_preprocess_batch_basic(): @pytest.mark.sanity def test_preprocess_batch_empty_conversations(): """Test preprocessing batch with no conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples: dict[str, list] = {"conversations": []} - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern ) assert results["input_ids"] == [] @@ -320,13 +541,10 @@ def test_preprocess_batch_empty_conversations(): @pytest.mark.sanity def test_preprocess_batch_invalid_conversation(): """Test preprocessing batch with invalid conversations.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, 
trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples = { "conversations": [ @@ -336,9 +554,9 @@ def test_preprocess_batch_invalid_conversation(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=assistant_pattern + examples, processor, max_length=512, assistant_pattern=assistant_pattern ) # Should only process the valid conversation @@ -349,13 +567,10 @@ def test_preprocess_batch_invalid_conversation(): @pytest.mark.sanity def test_preprocess_batch_truncation(): """Test that long sequences are truncated to max_length.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") # Create a very long message long_content = "word " * 1000 @@ -370,9 +585,9 @@ def test_preprocess_batch_truncation(): } max_length = 100 - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( - examples, tokenizer, max_length=max_length, assistant_pattern=assistant_pattern + examples, processor, max_length=max_length, 
assistant_pattern=assistant_pattern ) if len(results["input_ids"]) > 0: @@ -384,17 +599,14 @@ def test_preprocess_batch_truncation(): @pytest.mark.sanity def test_preprocess_batch_uses_hf_assistant_mask(): """Test that HF assistant token mask is used when supported.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # Skip test if assistant mask is not supported/functional for this tokenizer - if not _supports_assistant_mask(tokenizer): - pytest.skip("Tokenizer does not support assistant token mask") + # Skip test if assistant mask is not supported/functional for this processor + if not _supports_assistant_mask(processor): + pytest.skip("Processor does not support assistant token mask") examples = { "conversations": [ @@ -408,7 +620,7 @@ def test_preprocess_batch_uses_hf_assistant_mask(): # Pass None to trigger masking path results = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=None, ) @@ -427,23 +639,20 @@ def test_preprocess_batch_falls_back_to_regex(): """Test that preprocessing falls back to regex-based detection when HF mask is unavailable. 
""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") # Monkeypatch apply_chat_template to force HF mask failure - original_apply_chat_template = tokenizer.apply_chat_template + original_apply_chat_template = processor.apply_chat_template def patched_apply_chat_template(*args, **kwargs): if kwargs.get("return_assistant_tokens_mask", False): raise ValueError("Forcing fallback to regex path") return original_apply_chat_template(*args, **kwargs) - tokenizer.apply_chat_template = patched_apply_chat_template # type: ignore [method-assign] + processor.apply_chat_template = patched_apply_chat_template # type: ignore [method-assign] examples = { "conversations": [ @@ -454,11 +663,11 @@ def patched_apply_chat_template(*args, **kwargs): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, ) @@ -475,13 +684,10 @@ def patched_apply_chat_template(*args, **kwargs): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): """Test that minimum_valid_tokens drops short samples on regex path.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") - - if 
tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples = { "conversations": [ @@ -492,11 +698,11 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) baseline = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, ) @@ -507,7 +713,7 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): filtered = _preprocess_batch( examples, - tokenizer, + processor, max_length=128, assistant_pattern=assistant_pattern, minimum_valid_tokens=valid_count + 1, @@ -521,13 +727,10 @@ def test_preprocess_batch_minimum_valid_tokens_filters_regex_path(): @pytest.mark.sanity def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): """Test that a sample is kept when valid tokens equal the threshold.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples = { "conversations": [ @@ -544,11 +747,11 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) baseline = _preprocess_batch( examples, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) @@ -559,7 +762,7 @@ 
def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): kept = _preprocess_batch( examples, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, minimum_valid_tokens=valid_count, @@ -576,13 +779,10 @@ def test_preprocess_batch_minimum_valid_tokens_keeps_boundary_case(): @pytest.mark.sanity def test_build_eagle3_dataset_basic(): """Test building EAGLE3 dataset from a simple HuggingFace dataset.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") # Create a simple dataset data = { @@ -599,7 +799,7 @@ def test_build_eagle3_dataset_basic(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) assert isinstance(result, HFDataset) assert len(result) <= len(dataset) @@ -613,13 +813,10 @@ def test_build_eagle3_dataset_basic(): @pytest.mark.sanity def test_build_eagle3_dataset_preserves_format(): """Test that build_eagle3_dataset sets the correct format.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + 
pytest.skip("Processor does not support chat templates") data = { "conversations": [ @@ -631,7 +828,7 @@ def test_build_eagle3_dataset_preserves_format(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) # Dataset should be in torch format assert result.format["type"] == "torch" @@ -640,13 +837,10 @@ def test_build_eagle3_dataset_preserves_format(): @pytest.mark.sanity def test_build_eagle3_dataset_removes_original_columns(): """Test that original columns are removed after processing.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") data = { "conversations": [ @@ -659,7 +853,7 @@ def test_build_eagle3_dataset_removes_original_columns(): } dataset = HFDataset.from_dict(data) - result = build_eagle3_dataset(dataset, tokenizer, max_length=512, num_proc=1) + result = build_eagle3_dataset(dataset, processor, max_length=512, num_proc=1) # Original columns should be removed if len(result) > 0: @@ -670,13 +864,10 @@ def test_build_eagle3_dataset_removes_original_columns(): @pytest.mark.sanity def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): """Test that build_eagle3_dataset removes samples below the token threshold.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not 
support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") short_conv = [ {"role": "user", "content": "Hi"}, @@ -693,17 +884,17 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): }, ] - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) short_baseline = _preprocess_batch( {"conversations": [short_conv]}, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) long_baseline = _preprocess_batch( {"conversations": [long_conv]}, - tokenizer, + processor, max_length=256, assistant_pattern=assistant_pattern, ) @@ -722,7 +913,7 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): dataset = HFDataset.from_dict({"conversations": [short_conv, long_conv]}) result = build_eagle3_dataset( dataset, - tokenizer, + processor, max_length=256, num_proc=1, assistant_pattern=assistant_pattern, @@ -742,13 +933,10 @@ def test_build_eagle3_dataset_minimum_valid_tokens_filters_short_samples(): @pytest.mark.sanity def test_preprocess_batch_with_turn_dropout(): """Test preprocessing batch with turn dropout enabled.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") examples = { "conversations": [ @@ -761,10 +949,10 @@ def 
test_preprocess_batch_with_turn_dropout(): ] } - assistant_pattern = _detect_assistant_pattern(tokenizer) + assistant_pattern = _detect_assistant_pattern(processor) results = _preprocess_batch( examples, - tokenizer, + processor, max_length=512, assistant_pattern=assistant_pattern, turn_dropout=True, @@ -788,8 +976,8 @@ def test_detect_assistant_pattern_thinking_model(): but the pattern must still match real conversations where the think block contains substantial content. """ - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) - pattern = _detect_assistant_pattern(tokenizer) + processor = load_processor("Qwen/Qwen3-8B", trust_remote_code=True) + pattern = _detect_assistant_pattern(processor) # Format a multi-turn conversation with thinking content injected # directly into the formatted string (as it would appear in real data) @@ -807,7 +995,7 @@ def test_detect_assistant_pattern_thinking_model(): "reasoning_content": "We are adding 3 and 3.", }, ] - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] test_conv, tokenize=False, add_generation_prompt=False, enable_thinking=True ) @@ -848,8 +1036,8 @@ def test_create_loss_mask_thinking_model(thinking_content): Verifies correct masking both with and without thinking content in the block. 
""" - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) - pattern = _detect_assistant_pattern(tokenizer) + processor = load_processor("Qwen/Qwen3-8B", trust_remote_code=True) + pattern = _detect_assistant_pattern(processor) # Build formatted text using the real chat template conv = [ @@ -858,7 +1046,7 @@ def test_create_loss_mask_thinking_model(thinking_content): ] if thinking_content: conv[-1]["reasoning_content"] = thinking_content - formatted: str = tokenizer.apply_chat_template( # type: ignore[assignment] + formatted: str = processor.apply_chat_template( # type: ignore[assignment] conv, tokenize=False, add_generation_prompt=False, @@ -866,7 +1054,7 @@ def test_create_loss_mask_thinking_model(thinking_content): ) # Tokenize with offsets - encoding = tokenizer( + encoding = processor( formatted, return_offsets_mapping=True, add_special_tokens=False, @@ -880,8 +1068,8 @@ def test_create_loss_mask_thinking_model(thinking_content): # Decode masked vs unmasked regions input_ids = torch.tensor(encoding["input_ids"]) - trainable_text = tokenizer.decode(input_ids[mask == 1]) - masked_text = tokenizer.decode(input_ids[mask == 0]) + trainable_text = processor.decode(input_ids[mask == 1]) + masked_text = processor.decode(input_ids[mask == 0]) # Assistant response must be in the trainable region assert "Paris is the capital" in trainable_text @@ -898,13 +1086,10 @@ def test_create_loss_mask_thinking_model(thinking_content): @pytest.mark.sanity def test_build_eagle3_dataset_with_custom_pattern(): """Test building dataset with custom assistant pattern.""" - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_REPO, trust_remote_code=True) - - if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None: - pytest.skip("Tokenizer does not support chat templates") + processor = load_processor(TEXT_MODEL_REPO, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if not 
hasattr(processor, "apply_chat_template") or processor.chat_template is None: + pytest.skip("Processor does not support chat templates") data = { "conversations": [ @@ -920,7 +1105,7 @@ def test_build_eagle3_dataset_with_custom_pattern(): dataset = HFDataset.from_dict(data) result = build_eagle3_dataset( - dataset, tokenizer, max_length=512, num_proc=1, assistant_pattern=custom_pattern + dataset, processor, max_length=512, num_proc=1, assistant_pattern=custom_pattern ) # Should successfully build dataset with custom pattern diff --git a/tests/integration/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py index 4a671f12a..3c68026ff 100644 --- a/tests/integration/datagen/test_regex_patterns.py +++ b/tests/integration/datagen/test_regex_patterns.py @@ -6,11 +6,16 @@ import pytest from loguru import logger as log -from transformers import AutoTokenizer +from packaging.version import Version +from PIL import Image +from transformers import ProcessorMixin +from transformers import __version__ as TRANSFORMERS_VERSION # noqa: N812 from speculators.data_generation.preprocessing import ( _detect_assistant_pattern, _preprocess_batch, + get_tokenizer, + load_processor, ) # Test models covering major template families @@ -29,28 +34,33 @@ "openai/gpt-oss-20b", ] +if Version(TRANSFORMERS_VERSION) >= Version("5.5.0"): + # Multimodal + MODELS.append("google/gemma-4-E2B-it") + @pytest.fixture(scope="module", params=MODELS) -def tokenizer(request): +def processor(request): model_id = request.param try: # Using trust_remote_code=True for variety of templates - return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + return load_processor(model_id, trust_remote_code=True) except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: - pytest.skip(f"Failed to load tokenizer for {model_id}: {e}") + pytest.skip(f"Failed to load processor for {model_id}: {e}") -def test_regex_detection_across_models(tokenizer): +def 
test_regex_detection_across_models(tmp_path, processor): """ Verify that _detect_assistant_pattern and _preprocess_batch (regex path) work correctly for a variety of model families. """ + tokenizer = get_tokenizer(processor) model_name = tokenizer.name_or_path log.info(f"Testing family: {model_name}") # 1. Detect pattern try: - pattern = _detect_assistant_pattern(tokenizer) + pattern = _detect_assistant_pattern(processor) except (ValueError, RuntimeError) as e: pytest.fail(f"Failed to detect assistant pattern for {model_name}: {e}") @@ -58,20 +68,55 @@ def test_regex_detection_across_models(tokenizer): assert isinstance(pattern, (str, Pattern)), "Pattern must be str or regex object" # 2. Preprocess a simple multi-turn conversation using REGEX path - examples = { - "conversations": [ - [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "The capital of France is Paris."}, - ] + if isinstance(processor, ProcessorMixin): + img_path = str(tmp_path / "blank.png") + Image.new("RGB", (256, 256)).save(img_path) + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello, how are you?"}, + {"type": "image", "path": img_path}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I am a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the capital"}, + {"type": "image", "path": img_path}, + {"type": "text", "text": "of France?"}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The capital of France is Paris.", + }, + ], + }, + ] + else: + conversation = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", 
"content": "The capital of France is Paris."}, ] - } + + examples = {"conversations": [conversation]} # Regex path by passing the explicit pattern results = _preprocess_batch( - examples, tokenizer, max_length=512, assistant_pattern=pattern + examples, processor, max_length=2048, assistant_pattern=pattern ) assert len(results["input_ids"]) == 1 @@ -86,7 +131,7 @@ def test_regex_detection_across_models(tokenizer): # 3. Qualitative check: Assistant content should be masked as 1 trainable_tokens = input_ids[loss_mask == 1] - decoded_assistant = tokenizer.decode(trainable_tokens) + decoded_assistant = processor.decode(trainable_tokens) log.info(f"Decoded trainable regions: {decoded_assistant}") @@ -97,7 +142,3 @@ def test_regex_detection_across_models(tokenizer): # It should NOT contain user message content assert "Hello" not in decoded_assistant assert "France?" not in decoded_assistant - - -if __name__ == "__main__": - pass