From 99f1665b918f19aa2d8433b40d4d84bf0be8c638 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Fri, 17 Apr 2026 16:11:17 +0000 Subject: [PATCH 1/7] fully deprecate old data generation system Signed-off-by: shanjiaz --- .coderabbit.yaml | 2 +- README.md | 8 - docs/examples/index.md | 8 - docs/index.md | 2 +- docs/scripts/gen_files.py | 6 - .../data_generation_and_training/README.md | 45 -- .../gpt_oss_20b_ultrachat_5k.py | 75 --- .../llama3_8b_sharegpt_5k.py | 68 -- .../qwen3_8b_sharegpt_ultrachat.py | 77 --- pyproject.toml | 3 - scripts/README.md | 328 --------- scripts/data_generation_offline.py | 631 ++++++++++-------- scripts/data_generation_offline2.py | 475 ------------- scripts/gen_and_train.py | 354 ---------- .../data_generation/config_generator.py | 281 -------- .../data_generation/custom_worker.py | 195 ------ .../vllm_hidden_states_generator.py | 374 ----------- tests/datagen/test_config_generator.py | 127 ---- tests/datagen/test_vllm_hidden_states.py | 349 ---------- .../test_eagle3_offline_acceptance.py | 2 +- tests/e2e/smoke/test_offline_training.py | 6 +- tests/e2e/smoke/test_resume_optimizer.py | 4 +- tests/e2e/utils.py | 8 +- 23 files changed, 380 insertions(+), 3048 deletions(-) delete mode 100644 examples/data_generation_and_training/README.md delete mode 100644 examples/data_generation_and_training/gpt_oss_20b_ultrachat_5k.py delete mode 100755 examples/data_generation_and_training/llama3_8b_sharegpt_5k.py delete mode 100755 examples/data_generation_and_training/qwen3_8b_sharegpt_ultrachat.py delete mode 100644 scripts/README.md delete mode 100644 scripts/data_generation_offline2.py delete mode 100644 scripts/gen_and_train.py delete mode 100644 src/speculators/data_generation/config_generator.py delete mode 100644 src/speculators/data_generation/custom_worker.py delete mode 100644 src/speculators/data_generation/vllm_hidden_states_generator.py delete mode 100644 tests/datagen/test_config_generator.py delete mode 100644 tests/datagen/test_vllm_hidden_states.py diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 708883dd5..e72681bd2 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -34,7 +34,7 @@ reviews: - path: "examples/**/*.py" instructions: "Review for clarity, correctness, and educational value. Ensure examples are end-to-end runnable, configurations match current API, and comments explain speculative decoding-specific concepts for users new to the algorithm." - path: "scripts/**/*.py" - instructions: "Check that scripts handle argument parsing robustly, log progress clearly, and are safe to run in multi-GPU environments. Verify that gen_and_train.py pipeline orchestration correctly sequences data generation before training." + instructions: "Check that scripts handle argument parsing robustly, log progress clearly, and are safe to run in multi-GPU environments." - path: "docs/**/*.md" instructions: "Check for clarity, accuracy, and completeness. Ensure code examples match the current API and that speculative decoding concepts are explained correctly." - path: "**/README.md" diff --git a/README.md b/README.md index d187010e8..1dfc705aa 100644 --- a/README.md +++ b/README.md @@ -167,14 +167,6 @@ The following table summarizes the models that have been trained end-to-end by o ✅ = Supported, ⏳ = In Progress, ❌ = Not Yet Supported -## Examples - -End-To-End Training Examples: - -- [Train Llama3 Draft Model](https://github.com/vllm-project/speculators/blob/main/examples/data_generation_and_training/llama3_8b_sharegpt_5k.py) -- [Train Qwen3 (Non-MoE) Draft Model](https://github.com/vllm-project/speculators/blob/main/examples/data_generation_and_training/qwen3_8b_sharegpt_ultrachat.py) -- [Train GPT-OSS Draft Model](https://github.com/vllm-project/speculators/blob/main/examples/data_generation_and_training/gpt_oss_20b_ultrachat_5k.py) - ## vLLM Inference Models trained through Speculators can run seamlessly in vLLM using a simple `vllm serve ` command. This will run the model in vLLM using default arguments, defined in the `speculator_config` of the model's config.json. diff --git a/docs/examples/index.md b/docs/examples/index.md index e91641e62..fa965d8fa 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -8,14 +8,6 @@ Welcome to the Examples section of Speculators! This area provides end-to-end ex
-- :octicons-ai-model-16:{ .lg .middle } Train - - ______________________________________________________________________ - - End-to-end example of using Speculators to train a speculative decoding model. - - [:octicons-arrow-right-24: Train](data_generation_and_training.md) - - :octicons-arrow-switch-16:{ .lg .middle } Convert ______________________________________________________________________ diff --git a/docs/index.md b/docs/index.md index 9e35b5f65..c01f6a2ff 100644 --- a/docs/index.md +++ b/docs/index.md @@ -39,5 +39,5 @@ Behind the scenes, this is reading the model from Hugging Face, parsing the `spe To create a speculative decoding model for a different verifier model there are two approaches you can choose: -1. Train a new speculative decoding model ([instructions](train.md))([examples](examples/data_generation_and_training.md)). +1. Train a new speculative decoding model ([instructions](train.md)). 2. Convert an existing model from a third-party library to the Speculators format for easy deployment with vLLM ([instructions](convert.md)) ([examples](examples/convert.md)). diff --git a/docs/scripts/gen_files.py b/docs/scripts/gen_files.py index ea45d3e81..499081586 100644 --- a/docs/scripts/gen_files.py +++ b/docs/scripts/gen_files.py @@ -139,12 +139,6 @@ def migrate_developer_docs(): weight=-12, ), # Examples - ProcessFile( - root_path=Path("examples/data_generation_and_training/README.md"), - docs_path=Path("examples/data_generation_and_training.md"), - title="Train", - weight=1, - ), ProcessFile( root_path=Path("examples/convert/README.md"), docs_path=Path("examples/convert.md"), diff --git a/examples/data_generation_and_training/README.md b/examples/data_generation_and_training/README.md deleted file mode 100644 index 4b2c59698..000000000 --- a/examples/data_generation_and_training/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Data Generation and Training - -Speculators currently supports training of Eagle3 speculative decoders. For full details on all the steps described below, see [README.md](/scripts/README.md) - -This process is currently broken down into three key steps: - -1. Data Generation -2. Vocab Mapping -3. Training - -## Data Generation - -Generate hidden states for training using vLLM. Dataset values are passed through the target or verifier model and generated hidden states are saved to disk for further use. `scripts/data_generation_offline.py` provides the main entry point for generating training data for Eagle3 models. - -Once completed, the following files will be generated on disk: - -1. `token_freq.pt` (the token frequency distribution file) -2. `data_config.json` (data metadata) -3. data pt files containing the hidden state values - -Note: this process uses vLLM and requires the `datagen` optional install. - -## Vocab Mapping - -Build `d2t` and `t2d` files from the token frequency distribution file. `scripts/build_vocab_mapping.py` is the main entrypoint for this step. - -Once completed, the following files will be generated from this step on disk: - -1. `d2t.npy` -2. `t2d.npy` - -## Training - -Train an Eagle3 draft model or `speculator`. Currently, training is supported for: - -1. Single-Layer and Multi-Layer Draft Models for Non-MoE models -2. Single-Layer and Multi-Layer Draft Models of certain Non-Vision MoEs - -For a full list of models with support, see: https://github.com/vllm-project/speculators/blob/main/README.md - -`scripts/train.py` provides the main entry point for training Eagle3 models with support for single and multi GPU training using FSDP. - -# Examples - -The files in this folder provide end-to-end examples which run the three steps listed above for GPT-OSS, Llama3 and Qwen3 draft models. If at any point a step fails, you can rerun the script and continue from the last step. Seprate steps may also run using the individual scripts listed above. diff --git a/examples/data_generation_and_training/gpt_oss_20b_ultrachat_5k.py b/examples/data_generation_and_training/gpt_oss_20b_ultrachat_5k.py deleted file mode 100644 index 261b1f437..000000000 --- a/examples/data_generation_and_training/gpt_oss_20b_ultrachat_5k.py +++ /dev/null @@ -1,75 +0,0 @@ -import sys -from pathlib import Path - -# Add scripts directory to path so we can import the run_e2e function. -scripts_path = Path(__file__).absolute().parent.parent.parent / "scripts" -sys.path.append(str(scripts_path)) - -from gen_and_train import ( # noqa: E402 - DataGenArgs, - TrainArgs, - VocabMappingArgs, - run_e2e, -) - -### Example E2E run for GPT-OSS 20B on 5k samples from UltraChat ### - -# Note: With just 5k samples, the model performance will not be very good, however there -# are enough samples to verify that the pipeline is working correctly and that the model -# is learning something. This is a good sanity check when creating a drafter for a new -# target model. - -# Because this is a thinking model, we use "turn dropout" which randomly truncates -# training conversations. This is because thinking models only use the last response -# when training (via loss masking). By randomly truncating the conversations, the model -# learns to generalize to both short and long conversations. - -# Timing (on 4x NVIDIA H100 80GB GPUs) -# Data Generation: 484.58 seconds -# Vocab Mapping: 4.41 seconds -# Training: 542.34 seconds seconds -# Total: 1031.33 seconds (17 mins) - -# Results on MT-Bench: -# first token accuracy: 0.14 -# second token accuracy: 0.01 -# third token accuracy: 0.00 -# average acceptance length: 1.34 - - -if __name__ == "__main__": - VERIFIER_NAME_OR_PATH = "openai/gpt-oss-20b" - OUTPUT_PATH = "./output/gpt_oss_20b_ultrachat_5k" - TOTAL_SEQ_LEN = 8192 - - # Data Generation - data_gen_args_ultrachat = DataGenArgs( - train_data_path="ultrachat", - seq_length=TOTAL_SEQ_LEN, - max_samples=5000, # Only use 5000 samples from UltraChat - turn_dropout=True, # Turn dropout enabled here - ) - - # Vocab Mapping - vocab_mapping_args = VocabMappingArgs( - draft_vocab_size=32000, # Use a 32k draft vocabulary - target_vocab_size=201088, # From https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json - ) - - # Training (norm_before_fc=True for gpt-oss to stabilize draft path) - train_args = TrainArgs( - logger="tensorboard", - lr=3e-5, - total_seq_len=TOTAL_SEQ_LEN, - run_name="gpt_oss_20b_ultrachat_5k", - epochs=10, - norm_before_fc=True, - ) - - run_e2e( - verifier_name_or_path=VERIFIER_NAME_OR_PATH, - output_path=OUTPUT_PATH, - data_gen_args=data_gen_args_ultrachat, - vocab_mapping_args=vocab_mapping_args, - train_args=train_args, - ) diff --git a/examples/data_generation_and_training/llama3_8b_sharegpt_5k.py b/examples/data_generation_and_training/llama3_8b_sharegpt_5k.py deleted file mode 100755 index b9e8b84de..000000000 --- a/examples/data_generation_and_training/llama3_8b_sharegpt_5k.py +++ /dev/null @@ -1,68 +0,0 @@ -import sys -from pathlib import Path - -# Add scripts directory to path so we can import the run_e2e function. -scripts_path = Path(__file__).absolute().parent.parent.parent / "scripts" -sys.path.append(str(scripts_path)) - -from gen_and_train import ( # noqa: E402 - DataGenArgs, - TrainArgs, - VocabMappingArgs, - run_e2e, -) - -### Example E2E run for Llama 3.1 8B on 5k samples from ShareGPT ### - -# Note: With just 5k samples, the model performance will not be very good, however there -# are enough samples to verify that the pipeline is working correctly and that the model -# is learning something. This is a good sanity check when creating a drafter for a new -# target model. - -# Timing (on 2x NVIDIA H100 80GB GPUs) -# Data Generation: 839 seconds -# Vocab Mapping: 6 seconds -# Training: 1254 seconds -# Total: 2099 seconds (35 mins) - -# Results on MT-Bench: -# first token accuracy: 0.40 -# second token accuracy: 0.13 -# third token accuracy: 0.04 -# average acceptance length: 1.57 - - -if __name__ == "__main__": - VERIFIER_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct" - OUTPUT_PATH = "./output/llama3_8b_sharegpt_5k" - TOTAL_SEQ_LEN = 8192 - - # Data Generation - data_gen_args_sharegpt = DataGenArgs( - train_data_path="sharegpt", - seq_length=TOTAL_SEQ_LEN, - max_samples=5000, # Only use 5000 samples from ShareGPT - ) - - # Vocab Mapping - vocab_mapping_args = VocabMappingArgs( - draft_vocab_size=8192, # Use a very small draft vocabulary for this example - target_vocab_size=128256, # From https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/config.json#L37 - ) - - # Training - train_args = TrainArgs( - logger="trackio", - lr=3e-5, - total_seq_len=TOTAL_SEQ_LEN, - run_name="llama3_8b_sharegpt_5k", - epochs=10, - ) - - run_e2e( - verifier_name_or_path=VERIFIER_NAME_OR_PATH, - output_path=OUTPUT_PATH, - data_gen_args=data_gen_args_sharegpt, - vocab_mapping_args=vocab_mapping_args, - train_args=train_args, - ) diff --git a/examples/data_generation_and_training/qwen3_8b_sharegpt_ultrachat.py b/examples/data_generation_and_training/qwen3_8b_sharegpt_ultrachat.py deleted file mode 100755 index 32ad64c69..000000000 --- a/examples/data_generation_and_training/qwen3_8b_sharegpt_ultrachat.py +++ /dev/null @@ -1,77 +0,0 @@ -import sys -from pathlib import Path - -# Add scripts directory to path so we can import the run_e2e function. -scripts_path = Path(__file__).absolute().parent.parent.parent / "scripts" -sys.path.append(str(scripts_path)) - -from gen_and_train import ( # noqa: E402 - DataGenArgs, - TrainArgs, - VocabMappingArgs, - run_e2e, -) - -### Example E2E full run for Qwen 3 8B on ShareGPT and UltraChat ### - -# Note: This is a full training run using all the data from ShareGPT (~130k samples) and -# UltraChat (~200k samples). - -# Because this is a thinking model, we use "turn dropout" which randomly truncates -# training conversations. This is because thinking models only use the last response -# when training (via loss masking). By randomly truncating the conversations, the model -# learns to generalize to both short and long conversations. - -# Timing (on 4x NVIDIA H100 80GB GPUs) -# Data Generation: ~15 hours -# Vocab Mapping: 6 seconds -# Training: ~8 hours -# Total: ~23 hours - -# Results on MT-Bench: -# first token accuracy: 0.58 -# second token accuracy: 0.28 -# third token accuracy: 0.13 -# average acceptance length: 1.98 - - -if __name__ == "__main__": - VERIFIER_NAME_OR_PATH = "Qwen/Qwen3-8B" - OUTPUT_PATH = "./output/qwen3_8b_sharegpt_ultrachat" - TOTAL_SEQ_LEN = 8192 - - # Data Generation - data_gen_args_sharegpt = DataGenArgs( - train_data_path="sharegpt", - seq_length=TOTAL_SEQ_LEN, - turn_dropout=True, # Turn dropout enabled here - ) - - data_gen_args_ultrachat = DataGenArgs( - train_data_path="ultrachat", - seq_length=TOTAL_SEQ_LEN, - turn_dropout=True, # Turn dropout enabled here - ) - - # Vocab Mapping - vocab_mapping_args = VocabMappingArgs( - draft_vocab_size=32000, # Use an 32k draft vocabulary - target_vocab_size=151936, # From https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json#L29 - ) - - # Training - train_args = TrainArgs( - logger="trackio", - lr=3e-5, - total_seq_len=TOTAL_SEQ_LEN, - run_name="qwen3_8b_sharegpt_ultrachat", - epochs=10, - ) - - run_e2e( - verifier_name_or_path=VERIFIER_NAME_OR_PATH, - output_path=OUTPUT_PATH, - data_gen_args=[data_gen_args_sharegpt, data_gen_args_ultrachat], - vocab_mapping_args=vocab_mapping_args, - train_args=train_args, - ) diff --git a/pyproject.toml b/pyproject.toml index b4e03fc6e..3426391a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,9 +258,6 @@ select = [ "INP001", # allow implicit namespace packages in examples ] -"scripts/gen_and_train.py" = [ - "T201", # allow print in scripts -] [tool.ruff.lint.isort] known-first-party = ["speculators", "tests"] diff --git a/scripts/README.md b/scripts/README.md deleted file mode 100644 index 7a3fb9acf..000000000 --- a/scripts/README.md +++ /dev/null @@ -1,328 +0,0 @@ -# Scripts - -## Eagle3 Model Production - -Speculators currently supports training of Eagle3 models. This functionality is available via the scripts in this directory. - -1. [data_generation_offline.py](/scripts/data_generation_offline.py): Generate training data (verifier hidden states) using vLLM. Note: this script will also preprocess the data if it hasn't been already. -2. [build_vocab_mapping.py](/scripts/build_vocab_mapping.py): Uses the token frequency distribution file to build `d2t` (draft to target) and `t2d` (target to draft) vocabulary mappings. -3. [train.py](/scripts/train.py): Trains an Eagle3 model using the training data and vocabulary mappings. -4. (Optional) [gen_and_train.py](/scripts/gen_and_train.py): A convenience wrapper around the above scripts that runs the full pipeline in one command. - -## Table of Contents - -- **[Data Generation](#data-generation)**
- - **[Quick Start](#quick-start)**
- - **[Response Regeneration](#response-regeneration)**
- - **[Advanced Usage](#advanced-usage)**
- - **[Troubleshooting](#troubleshooting)**
-- **[Vocab Mapping](#vocab-mapping)**
- - **[Quick Start](#quick-start-1)**
-- **[Training](#training)**
- - - **[Quick Start](#quick-start-2)**
- - **[Arguments](#arguments)**
- - **[Example Command](#example-command)**
-- **[E2E Pipeline](#e2e-pipeline)**
- - **[Overview](#overview)**
- - **[Prerequisites](#prerequisites)**
- - **[Usage](#usage)**
- -## Data Generation - -`scripts/data_generation_offline.py` provides the main entry point for generating training data for Eagle3 models. Data generation uses vLLM and requires the optional `datagen` install. - -### Quick Start - -Generate training data from ShareGPT using Llama 3.1 8B: - -```bash -python scripts/data_generation_offline.py \ - --target-model-path meta-llama/Llama-3.1-8B-Instruct \ - --train-data-path sharegpt \ - --output-dir ./training_data \ - --max-samples 5000 -``` - -The script automatically uses the tokenizer's built-in chat template via `apply_chat_template`. It will use vllm to generate target model hidden states for the training data, and save them to disk alongside the input_ids and loss_mask tensors, as .pt files. - -For sample generated data, see: https://huggingface.co/datasets/nm-testing/sharegpt_llama3_8b_hidden_states - -### Response Regeneration - -The [response_regeneration/](/scripts/response_regeneration/) directory contains scripts for regenerating assistant responses in existing datasets using a vLLM-served model. Given a dataset containing user prompts (e.g., Magpie, UltraChat), the pipeline extracts the prompts, sends them to a vLLM server, and produces a new dataset with freshly generated responses from the target model. Regenerating responses with the target model can improve draft model performance, since the training data distribution better matches the target model's own outputs. - -See the [response_regeneration/README.md](/scripts/response_regeneration/README.md) for full usage details. - -### Advanced Usage - -With custom settings and multi-GPU: - -```bash -python scripts/data_generation_offline.py \ - --target-model-path meta-llama/Llama-3.1-70B-Instruct \ - --train-data-path ./my_data.jsonl \ - --seq-length 4096 \ - --cache-dir ./cache \ - --output-dir ./training_data \ - --layer-ids 2 28 54 \ - --tensor-parallel-size 4 \ - --batch-size 16 \ - --max-samples 10000 -``` - -### Data Config File - -The script will produce a `data_config.json` file in the output directory, which contains the configuration used to generate the data, as well as other metadata about the data generation process. - -Example file: - -```json -{ - "version": "2.0", - "generated_at": "2025-12-03T16:03:02.471808+00:00", - "speculators_version": "0.3.0", - "reproducibility": { - "command": "data_generation_offline.py --target-model-path meta-llama/Llama-3.1-8B-Instruct --train-data-path sharegpt --output-dir ./training_data --max-samples 5000", - "package_versions": { - "torch": "2.8.0+cu128", - "vllm": "0.11.0", - "transformers": "4.57.3", - "speculators": "0.3.0" - }, - "gpu": "NVIDIA H100 80GB HBM3" - }, - "model": { - "target_model_path": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "max_model_len": 2048, - "gpu_memory_utilization": 0.8, - "hidden_size": 4096 - }, - "data": { - "train_data_path": "sharegpt", - "seq_length": 2048, - "max_samples": 5000, - "num_samples": 5000, - "seed": 0, - "chat_template_note": "Uses tokenizer's built-in chat template" - }, - "hidden_states": { - "layer_ids": [ - 2, - 16, - 29, - 31 - ], - "description": "Layers selected for EAGLE3 fusion and target logits" - }, - "generation": { - "cache_dir": "/home/***/.cache/huggingface/datasets" - }, - "format": { - "file_pattern": "data_{idx}.pt", - "schema": { - "input_ids": { - "dtype": "torch.long", - "shape": "[seq_len]", - "description": "Tokenized input sequence" - }, - "hidden_states": { - "dtype": "list[torch.bfloat16]", - "shape": "list of [seq_len, 4096]", - "num_tensors": 4, - "description": "Hidden states from 4 layers" - }, - "loss_mask": { - "dtype": "torch.long", - "shape": "[seq_len]", - "description": "1 for assistant tokens to train on, 0 elsewhere" - } - } - } -} -``` - -### Token Frequency File - -Along with the `data_config.json`, the data generation step will also generate a `token_freq.pt` file containing the token frequencies. If not specified, the default location for the token frequency file is `./token_freq.pt` i.e in the same directory where the script runs. This frequencies will be used to `d2t` i.e `draft-to-target` and `t2d` i.e `target-to-draft` vocabulary mappings. - -#### Datasets - -Built-in datasets (can be used directly by name in the `--train-data-path` argument): - -- `sharegpt` - ShareGPT Vicuna unfiltered -- `ultrachat` - HuggingFace UltraChat 200k - -Alternatively, you can use a different dataset by passing the HuggingFace dataset path or local JSON/JSONL file path in the `--train-data-path` argument. - -#### Caching - -Preprocessing is automatically cached by HuggingFace datasets using fingerprint-based cache invalidation. The cache automatically updates when: - -- Tokenizer changes -- Preprocessing parameters change (seq_length, etc.) -- Dataset changes - -**Cache Location:** - -Default: `~/.cache/huggingface/datasets` (Optional) Use a custom cache directory by setting the `HF_HUB_CACHE` environment variable - -```bash -# Example: Use custom cache directory -export HF_HUB_CACHE=/path/to/your/cache -python scripts/data_generation_offline.py ... -``` - -### Troubleshooting - -1. **Out of memory during hidden state extraction** - - - Reduce `--batch-size` - - Reduce `--seq-length` - - Increase `--tensor-parallel-size` - -2. **Layer index out of bounds** - - - Check model's actual number of layers - - Auto-selection uses: `[2, num_layers // 2, num_layers - 3]` - -3. **No assistant response spans found** - - - Ensure tokenizer has a chat template (supports `apply_chat_template`) - - Check that conversations have assistant responses in correct format (role/content keys) - -4. **Cache invalidation** - - - Delete cache directory if changing preprocessing parameters - - Ensure `--seed` matches between runs for reproducibility - -## Vocab Mapping - -`scripts/build_vocab_mapping.py` Uses the token frequency distribution file to build `d2t` (draft to target) and `t2d` (target to draft) vocabulary mappings. - -### Quick Start - -Generate vocab mapping using Llama 3.1 8B: - -by specifying `target-vocab-size` manually: - -```bash - python scripts/build_vocab_mapping.py \ - --token-freq-path ./token_freq.pt \ - --draft-vocab-size 32000 \ - --target-vocab-size 128256 \ - --output-path ./vocab_mapping -``` - -or by using `target-model-path` to automatically infer the target vocab size: - -```bash - python scripts/build_vocab_mapping.py \ - --token-freq-path ./token_freq.pt \ - --draft-vocab-size 32000 \ - --target-model-path meta-llama/Llama-3.1-8B-Instruct \ - --output-path ./vocab_mapping -``` - -If not specified, the default location for token frequency file is `./token_freq.pt`. Make sure `target-vocab-size` match the verifier model vocab size exactly. Once complete, this step will generate and save `t2d.npy` and `d2t.npy` files to disk. - -## Training - -`scripts/train.py` provides the main entry point for training Eagle3 models. - -### Quick Start - -To run in a single-node multi-GPU distributed training setup with FSDP, the scripts should be launched with `torchrun`: - -```bash -torchrun --standalone --nproc_per_node= scripts/train.py -``` - -For single GPU training (useful for debugging), the script can be run directly: - -```bash -python scripts/train.py -``` - -> [!NOTE] -> Use `CUDA_VISIBLE_DEVICES=` to control which GPUS are visible to the script. - -### Arguments - -The scripts has one required argument: `--verifier-name-or-path`, which is the name or path of the verifier model to use. - -The scripts has the following optional arguments: - -- `--data-path`: The path to the data directory. Defaults to `./data`. The script will collect all `.pt` files in this directory or its subdirectories and use them as training data. -- `--save-path`: The path to save the checkpoints. Defaults to `./checkpoints`. The script will create subdirectories for each epoch to save the model weights and optimizer states. e.g. `./checkpoints/0/` -- `--epochs`: The number of epochs to train for. Defaults to 20. -- `--lr`: The learning rate to use. Defaults to 1e-4. -- `--no-resume-from-checkpoint`: If set, the script will not resume from the last checkpoint if it exists, and will instead start from scratch and overwrite existing checkpoints. -- `--logger`: The logger to use. Defaults to empty string, which means no logging. Supported loggers are `trackio`, `wandb`, and `tensorboard`. -- `--total-seq-len`: The total sequence length to use. Defaults to 8192. -- `--log-dir`: The path to save the logs. Defaults to `./logs`. -- `--run-name`: The name of the run. Defaults to None. -- `--num-layers`: The number of layers to use. Defaults to 1. -- `--d2t-path`: The path to the d2t tensor. Defaults to `d2t.npy`. -- `--t2d-path`: The path to the t2d tensor. Defaults to `t2d.npy`. -- `--ttt-steps`: The number of TTT steps to use. Defaults to 3. -- `--ttt-step-loss-decay`: The loss decay factor to use for the TTT steps. Defaults to 1.0. - -### Example Command - -```bash -torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py \ - --verifier-name-or-path "meta-llama/Llama-3.1-8B-Instruct" \ - --data-path "./data/llama-3.1-8b_sharegpt/gen/" \ - --save-path "./checkpoints/llama-3.1-8b.eagle3" \ - --epochs 10 \ - --lr 1e-4 \ - --no-resume-from-checkpoint \ - --logger "tensorboard" \ - --total-seq-len 8192 \ - --log-dir "./logs/llama-3.1-8b.eagle3" \ - --run-name "llama-3.1-8b.eagle3" \ - --num-layers 1 \ - --d2t-path "./data/llama-3.1-8b_sharegpt/d2t.npy" \ - --t2d-path "./data/llama-3.1-8b_sharegpt/t2d.npy" \ - --ttt-steps 3 \ - --ttt-step-loss-decay 1.0 -``` - -## E2E Pipeline - -### Overview - -`scripts/gen_and_train.py` can be used to run the full pipeline in one command. It also ensures each script is run with the correct arguments and dependencies. - -Internally it calls the following scripts in order: - -1. scripts/data_generation_offline.py -2. scripts/build_vocab_mapping.py -3. scripts/train.py - -Using `uv` to produce ephemeral environments for each script. - -### Prerequisites: - -- python 3.10+ -- uv (`pip install uv`) - -### Usage: - -> [!IMPORTANT] -> Update the script arguments section in the script file itself before running. - -Then run: - -```bash -python scripts/gen_and_train.py -``` - -> [!NOTE] -> You can call the script with environment variables (like `CUDA_VISIBLE_DEVICES` and `HF_HOME`) to control the behavior of the scripts. By default the script will use all available GPUs. - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gen_and_train.py -``` diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 67c54d966..e52ccdafb 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -1,58 +1,60 @@ #!/usr/bin/env python3 """ -Offline EAGLE Training Data Generation Pipeline +Offline Hidden States Generation Pipeline -This script generates training data for EAGLE models by: -1. Automatically preprocessing data if needed (or loading from cache) -2. Using vLLM to extract hidden states from target model -3. Saving each data point as a separate .pt file - -Preprocessing is cached automatically by HuggingFace datasets. -Token frequencies are saved in the current directory by default. +This script generates hidden states and saves them to disk for offline training. Usage: python data_generation_offline.py \ - --target-model-path meta-llama/Llama-3.1-8B-Instruct \ - --train-data-path sharegpt \ - --output-dir ./training_data \ - --hf-cache-dir /path/to/cache \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --preprocessed-data sharegpt \ + --output ./training_data \ --max-samples 5000 """ import argparse -import json +import asyncio import logging -from concurrent.futures import ThreadPoolExecutor, as_completed +import os +import shutil +import sys from pathlib import Path +from typing import Any -import torch -from datasets import config as datasets_config -from tqdm import tqdm # type: ignore[import-untyped] +import openai +from datasets import load_from_disk +from safetensors import safe_open +from tqdm import tqdm -# Set vLLM to use 'spawn' instead of 'fork' -# to prevent "Cannot re-initialize CUDA in forked subprocess" errors -from vllm import envs +from speculators.data_generation.vllm_client import ( + DEFAULT_MAX_RETRIES, + DEFAULT_REQUEST_TIMEOUT, + generate_hidden_states_async, +) +from speculators.train.logger import setup_root_logger -envs.VLLM_WORKER_MULTIPROC_METHOD = "spawn" +logger = logging.getLogger(__name__) -from speculators.data_generation.config_generator import ( # noqa: E402 - DataGenerationConfig, -) -from speculators.data_generation.logging_utils import PipelineLogger # noqa: E402 -from speculators.data_generation.preprocessing import ( # noqa: E402 - load_and_preprocess_dataset, -) -from speculators.data_generation.vllm_hidden_states_generator import ( # noqa: E402 - VllmHiddenStatesGenerator, -) -# Constants -MAX_IO_WORKERS = 4 # Number of parallel file save operations +class _FailureTracker: + """Tracks consecutive sample failures across async workers. -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -log = PipelineLogger(__name__) + When the number of consecutive failures (with no successes in between) + reaches ``threshold``, the tracker signals that the run should abort. + Because asyncio is single-threaded, no locking is needed. + """ + + def __init__(self, threshold: int): + self.threshold = threshold + self._consecutive = 0 + + def record_success(self) -> None: + self._consecutive = 0 + + def record_failure(self) -> bool: + """Record a failure. Returns True when the threshold is reached.""" + self._consecutive += 1 + return self._consecutive >= self.threshold def parse_args(): @@ -60,37 +62,31 @@ def parse_args(): # Model arguments parser.add_argument( - "--target-model-path", + "--model", type=str, - required=True, - help="HuggingFace model ID or local path for target model", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=torch.accelerator.device_count(), - help="Tensor parallel size for target model (default: 1)", + default=None, + help=( + "HuggingFace model ID or local path for target model (default auto select)." + "For verification purposes only." + ), ) parser.add_argument( - "--gpu-memory-utilization", - type=float, - default=0.8, - help="Target GPU memory utilization (default: 0.8)", + "--endpoint", + type=str, + default="http://localhost:8000/v1", + help=( + "The address of the vLLM instance to use for hidden states generation " + "(default: 'http://localhost:8000/v1'). " + "Note: the vLLM instance must be configured for hidden states extraction." + ), ) # Data arguments parser.add_argument( - "--train-data-path", + "--preprocessed-data", type=str, - action="append", required=True, - help="Path to training data (same as used in preprocessing)", - ) - parser.add_argument( - "--seq-length", - type=int, - default=2048, - help="Maximum sequence length for preprocessing and model (default: 2048)", + help="Path to preprocessed dataset (dataset produced by prepare_data.py)", ) parser.add_argument( "--max-samples", @@ -98,272 +94,381 @@ def parse_args(): default=None, help="Maximum number of samples to process (default: None, process all)", ) + + # Output arguments parser.add_argument( - "--token-freq-path", + "--output", type=str, - default="./token_freq.pt", - help="Path to save token frequency distribution (default: ./token_freq.pt)", + default=None, + help=( + "Directory to generated hidden states files " + "(default args.preprocessed_data / 'hidden_states')" + ), ) + + # Hidden states generation arguments parser.add_argument( - "--hf-cache-dir", - type=str, + "--layer-ids", + type=int, + nargs="+", default=None, help=( - "Directory for HuggingFace datasets cache. " - "If not specified, uses HF_DATASETS_CACHE env var or default location. " - "(default: None)" + "List of layer IDs from which to capture hidden states " + "(default: auto-select)" ), ) parser.add_argument( - "--assistant-pattern", - type=str, - default=None, + "--concurrency", + type=int, + default=32, help=( - "Custom regex pattern for matching assistant responses. " - "If not provided, auto-detected from chat template." + "Number of active vLLM requests at a time." + "Note: number of async workers set to 2*concurrency" ), ) parser.add_argument( - "--turn-dropout", + "--validate-outputs", action="store_true", help=( - "Enable turn dropout: randomly keeps first N consecutive turns " - "per conversation for data augmentation." + "Load generated safetensor files and check output token ids match prompt" + " tokens and hidden states seq_len matches num tokens" ), ) - - # Output arguments parser.add_argument( - "--output-dir", type=str, required=True, help="Directory to save .pt files" + "--request-timeout", + type=float, + default=DEFAULT_REQUEST_TIMEOUT, + help=( + "Timeout in seconds for each individual vLLM request " + f"(default: {DEFAULT_REQUEST_TIMEOUT})" + ), ) - - # Hidden states generation arguments parser.add_argument( - "--layer-ids", + "--max-retries", type=int, - nargs="+", - default=None, + default=DEFAULT_MAX_RETRIES, help=( - "List of layer IDs from which to capture hidden states " - "(default: auto-select)" + "Maximum number of retry attempts per request on failure " + f"(default: {DEFAULT_MAX_RETRIES})" ), ) parser.add_argument( - "--batch-size", - type=int, - default=8, - help="Batch size for hidden states generation (default: 8)", + "--fail-on-error", + action="store_true", + help=( + "Abort when a request fails after all retries. " + "By default, failed samples are skipped." + ), ) - - # Processing arguments parser.add_argument( - "--seed", + "--max-consecutive-errors", type=int, - default=0, - help="Random seed (must match preprocessing seed, default: 0)", + default=None, + help=( + "Abort after this many consecutive sample failures (each sample " + "already retried --max-retries times). Prevents silently churning " + "through the entire dataset when the server is down. " + "Ignored when --fail-on-error is set. " + "(default: value of --concurrency)" + ), ) + + # Processing arguments parser.add_argument( "--start-idx", type=int, default=0, help="Starting index for output files (default: 0)", ) - parser.add_argument( - "--num-preprocessing-workers", - type=int, - default=8, - help="Number of CPU processes for dataset preprocessing (default: 8)", - ) - parser.add_argument( - "--minimum-valid-tokens", - type=int, - default=None, - help=( - "Drop samples whose loss mask contains fewer than this many " - "trainable tokens." - ), - ) return parser.parse_args() -def find_last_checkpoint(output_dir: str) -> int: - """Find the last successfully saved file index by scanning existing files.""" - output_path = Path(output_dir) +def get_existing_hidden_state_indices(output_path: Path) -> list[int]: + """Find existing `hs_i.safetensors` files (where i is the file index)""" + + existing_file_indices = [] + if not output_path.exists(): - return 0 + return existing_file_indices - max_index = -1 for file_path in output_path.iterdir(): - if file_path.name.startswith("data_") and file_path.name.endswith(".pt"): - index_str = file_path.stem[5:] # Remove "data_" prefix + if file_path.name.startswith("hs_") and file_path.name.endswith(".safetensors"): + index_str = file_path.stem[3:] # Remove "hs_" prefix try: - index = int(index_str) - max_index = max(max_index, index) + file_index = int(index_str) + existing_file_indices.append(file_index) except ValueError: continue - return max_index + 1 - - -def save_sample_to_disk(data_dict, output_path): - """Save a single sample to disk for async execution.""" - torch.save(data_dict, output_path) - return output_path - - -def save_config(args, generator, num_samples, output_dir): - """Save metadata config file for reproducibility.""" - log.subsection("Saving configuration metadata") - - cache_dir = ( - args.hf_cache_dir if args.hf_cache_dir else datasets_config.HF_DATASETS_CACHE - ) - - config = DataGenerationConfig.from_generator( - generator=generator, - train_data_path=args.train_data_path, - seq_length=args.seq_length, - cache_dir=str(cache_dir), - num_samples=num_samples, - max_samples=args.max_samples, - seed=args.seed, - ) - - config_path = Path(output_dir) / "data_config.json" - config_path.write_text(json.dumps(config.to_dict(), indent=2)) - log.info(f"Saved config v{config.version} to {config_path}") - - -def generate_and_save_hidden_states(args, dataset): - """Generate hidden states and save each sample as a .pt file""" - Path(args.output_dir).mkdir(parents=True, exist_ok=True) + return sorted(existing_file_indices) + + +def get_indices_to_process( + num_samples: int, max_samples: int | None, existing: list[int] +) -> list[int]: + """Determines which indices should be processed. If max_samples is None + returns all dataset indices not in existing. Otherwise gets the first + `max_samples - len(existing)` samples not already in existing. + + Args: + num_samples: Total size of preprocessed dataset + max_samples: (Optional) limit for number of samples to process + existing: list of ids that have already been processed + + Returns: + list of dataset indices to process + """ + + if len(existing) >= num_samples: + logger.info("All samples already processed!") + return [] + if max_samples and len(existing) >= max_samples: + logger.info("At least max_samples already processed!") + return [] + + if len(existing) > 0: + logger.info(f"Found {len(existing)} existing samples.") + + existing_s = set(existing) + if max_samples is None: + return [i for i in range(num_samples) if i not in existing_s] + + num_remaining = min(max_samples, num_samples) - len(existing) + to_process = [] + cur = 0 + while num_remaining > 0 and cur < num_samples: + if cur not in existing_s: + to_process.append(cur) + num_remaining -= 1 + + cur += 1 + + return to_process + + +def check_safetensors_file(path: Path, tokens: list[int]): + with safe_open(path, "pt") as f: + t_ids = f.get_tensor("token_ids").tolist() + if t_ids != tokens: + raise ValueError( + f"Token ids in {path} don't match expected token ids {tokens}" + ) + + hs_slice = f.get_slice("hidden_states") + hs_shape = list(hs_slice.get_shape()) + if len(tokens) != hs_shape[0]: + raise ValueError( + f"Sequence length of hidden states {hs_shape[0]} in {path}" + f" doesn't match num tokens {len(tokens)}" + ) + + +async def worker( + client, + model: str, + queue: "asyncio.Queue[dict[str, Any]]", + pbar: tqdm, + vllm_semaphore: asyncio.Semaphore, + write_semaphore: asyncio.Semaphore, + hidden_states_output_dir: Path, + validate_outputs: bool, + request_timeout: float | None, + max_retries: int, + fail_on_error: bool, + skipped_indices: list[int], + cancel_event: asyncio.Event, + failure_tracker: _FailureTracker | None, +): + """Worker that pulls items from queue and sends them to the vLLM endpoint.""" + while True: + item = await queue.get() + if item is None: + queue.task_done() + return + + idx = item["idx"] + + # Drain remaining items quickly after cancellation + if cancel_event.is_set(): + queue.task_done() + continue + + input_ids = item["input_ids"].tolist() + + target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" + + try: + async with vllm_semaphore: # Limit number of active generate calls + hidden_states_path = await generate_hidden_states_async( + client, + model, + input_ids, + timeout=request_timeout, + max_retries=max_retries, + ) + async with write_semaphore: # Limit number of active disk writes + await asyncio.to_thread( + shutil.move, hidden_states_path, target_hidden_states_path + ) + if validate_outputs: + await asyncio.to_thread( + check_safetensors_file, target_hidden_states_path, input_ids + ) + except Exception as e: + if fail_on_error: + logger.exception( + "Fatal: sample %d failed with --fail-on-error: %s", idx, e + ) + logging.shutdown() + os._exit(1) + logger.warning("Skipping sample %d due to error: %s", idx, e) + skipped_indices.append(idx) + if failure_tracker is not None and failure_tracker.record_failure(): + cancel_event.set() + raise RuntimeError( + f"Aborting: {failure_tracker.threshold} consecutive samples " + "failed. The vLLM server may be unreachable." + ) from e + else: + if failure_tracker is not None: + failure_tracker.record_success() + finally: + pbar.update(1) + queue.task_done() + + +async def _feed_queue(to_process, dataset, queue, cancel_event): + """Feed dataset items into the worker queue, respecting cancellation.""" + for i in to_process: + if cancel_event.is_set(): + break + item = dataset[i] + # Check cancel_event while waiting for queue space to avoid + # deadlocking when all workers have died. + while not cancel_event.is_set(): + try: + queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) + break + except asyncio.QueueFull: + await asyncio.sleep(0.1) + + +async def _shutdown_workers(workers, queue, cancel_event): + """Shut down workers and propagate the first real exception.""" + logger.info("Waiting for remaining file saves to complete...") + if cancel_event.is_set(): + # Workers may be dead or draining — cancel any that are + # still alive so we don't deadlock on sentinel puts. + for w in workers: + if not w.done(): + w.cancel() + else: + # Normal shutdown: send sentinel values so workers exit + for _ in range(len(workers)): + await queue.put(None) + results = await asyncio.gather(*workers, return_exceptions=True) + + # Propagate the first real worker exception (skip CancelledError) + for result in results: + if isinstance(result, Exception) and not isinstance( + result, asyncio.CancelledError + ): + raise result - start_file_idx = find_last_checkpoint(args.output_dir) - # Load existing sample lengths to preserve them on resume - sample_lengths_output_path = Path(args.output_dir) / "sample_lengths.json" - if start_file_idx > 0 and sample_lengths_output_path.exists(): - with open(sample_lengths_output_path) as f: - sample_lengths = json.load(f) - log.subsection( - f"Resuming: {start_file_idx} files already exist, " - f"loaded {len(sample_lengths)} existing sample lengths" - ) +async def generate_and_save_hidden_states(args, dataset): + if args.output is None: + hidden_states_dir = Path(args.preprocessed_data) / "hidden_states" else: - sample_lengths = {} - if start_file_idx > 0: - log.subsection(f"Resuming: {start_file_idx} files already exist") + hidden_states_dir = Path(args.output) + hidden_states_dir.mkdir(parents=True, exist_ok=True) + existing_file_indices = get_existing_hidden_state_indices(hidden_states_dir) num_samples = len(dataset) - start_sample_idx = start_file_idx - args.start_idx - - if start_sample_idx >= num_samples: - log.info("All samples already processed!") - return 0 - - log.subsection("Initializing vLLM hidden states generator") - generator = VllmHiddenStatesGenerator( - model_path=args.target_model_path, - layer_ids=args.layer_ids, - max_model_len=args.seq_length, - gpu_memory_utilization=args.gpu_memory_utilization, - tensor_parallel_size=args.tensor_parallel_size, - ) - log.info(f"Processing {num_samples - start_sample_idx}/{num_samples} samples") - file_idx = start_file_idx - - num_batches = ( - num_samples - start_sample_idx + args.batch_size - 1 - ) // args.batch_size - - # Use ThreadPoolExecutor for async file I/O - max_io_workers = MAX_IO_WORKERS - - pbar = tqdm( - range(start_sample_idx, num_samples, args.batch_size), - desc="Generating hidden states", - total=num_batches, + to_process = get_indices_to_process( + num_samples, args.max_samples, existing_file_indices ) - - with ThreadPoolExecutor(max_workers=max_io_workers) as thread_executor: - futures = [] - - for i in pbar: - batch_end = min(i + args.batch_size, num_samples) - batch = dataset[i:batch_end] - batch_input_ids = batch["input_ids"] - batch_loss_mask = batch["loss_mask"] - - results = generator.generate(batch_input_ids) - - # Submit save operations to thread pool (async I/O) - for j, result in enumerate(results): - # Truncate loss_mask to match input_ids length (generator may truncate) - input_len = len(result["input_ids"]) - sample_lengths[str(file_idx)] = input_len - loss_mask = batch_loss_mask[j][:input_len] - - result_cleaned = { - "input_ids": result["input_ids"], - "hidden_states": [h.contiguous() for h in result["hidden_states"]], - "loss_mask": loss_mask, - } - output_path = Path(args.output_dir) / f"data_{file_idx}.pt" - future = thread_executor.submit( - save_sample_to_disk, result_cleaned, output_path + if not to_process: + return + + logger.info(f"Processing {len(to_process)} samples") + + queue: asyncio.Queue = asyncio.Queue(maxsize=args.concurrency * 4) + vllm_semaphore = asyncio.Semaphore(args.concurrency) + write_semaphore = asyncio.Semaphore(args.concurrency) + + skipped_indices: list[int] = [] + cancel_event = asyncio.Event() + + max_consec = args.max_consecutive_errors + if max_consec is None: + max_consec = args.concurrency + failure_tracker = _FailureTracker(max_consec) if not args.fail_on_error else None + + async with openai.AsyncOpenAI( + base_url=args.endpoint, api_key="EMPTY", max_retries=0 + ) as client: + list_models = await client.models.list() + model_id = list_models.data[0].id + if args.model and args.model != model_id: + raise ValueError( + f"An explicit model name was passed ({args.model}) which doesn't match" + "found model_id {model_id}." + "Please make sure --endpoint is set to the correct vllm instance." + ) + + with tqdm(total=len(to_process)) as pbar: + workers = [ + asyncio.create_task( + worker( + client, + model_id, + queue, + pbar, + vllm_semaphore, + write_semaphore, + hidden_states_dir, + args.validate_outputs, + args.request_timeout, + args.max_retries, + args.fail_on_error, + skipped_indices, + cancel_event, + failure_tracker, + ) ) - futures.append(future) - file_idx += 1 - - log.info("Waiting for remaining file saves to complete...") - for future in tqdm( - as_completed(futures), total=len(futures), desc="Saving files" - ): - future.result() - - samples_saved = file_idx - start_file_idx + for _ in range(args.concurrency * 2) + ] - with open(sample_lengths_output_path, "w") as f: - json.dump(sample_lengths, f, indent=2) + await _feed_queue(to_process, dataset, queue, cancel_event) + await _shutdown_workers(workers, queue, cancel_event) - log.info(f"Saved {samples_saved} new data points to {args.output_dir}") - - save_config(args, generator, num_samples, args.output_dir) - - return samples_saved + num_saved = len(to_process) - len(skipped_indices) + logger.info(f"Saved {num_saved} new data points to {args.output}") + if skipped_indices: + logger.warning( + f"Skipped {len(skipped_indices)} samples due to errors: {skipped_indices}" + ) def main(): args = parse_args() + setup_root_logger() - log.section("EAGLE Offline Data Generation") - log.config( - { - "Target Model": args.target_model_path, - "Dataset": args.train_data_path, - "Output Dir": args.output_dir, - "Tensor Parallel": args.tensor_parallel_size, - "Batch Size": args.batch_size, - } - ) + logger.info("EAGLE Offline Data Generation") - dataset, _ = load_and_preprocess_dataset( - target_model_path=args.target_model_path, - train_data_paths=args.train_data_path, - seq_length=args.seq_length, - build_dataset_num_proc=args.num_preprocessing_workers, - seed=args.seed, - max_samples=args.max_samples, - token_freq_path=args.token_freq_path, - assistant_pattern=args.assistant_pattern, - turn_dropout=args.turn_dropout, - minimum_valid_tokens=args.minimum_valid_tokens, - ) - num_saved = generate_and_save_hidden_states(args, dataset) + dataset = load_from_disk(args.preprocessed_data) + + try: + asyncio.run(generate_and_save_hidden_states(args, dataset)) + except KeyboardInterrupt: + sys.exit(130) + except Exception: + logger.exception("Data generation failed") + sys.exit(1) - log.section("Data generation complete!") - log.info(f"Saved {num_saved} files to {args.output_dir}") + logger.info("Data generation complete!") if __name__ == "__main__": diff --git a/scripts/data_generation_offline2.py b/scripts/data_generation_offline2.py deleted file mode 100644 index e52ccdafb..000000000 --- a/scripts/data_generation_offline2.py +++ /dev/null @@ -1,475 +0,0 @@ -#!/usr/bin/env python3 -""" -Offline Hidden States Generation Pipeline - -This script generates hidden states and saves them to disk for offline training. - -Usage: - python data_generation_offline.py \ - --model meta-llama/Llama-3.1-8B-Instruct \ - --preprocessed-data sharegpt \ - --output ./training_data \ - --max-samples 5000 -""" - -import argparse -import asyncio -import logging -import os -import shutil -import sys -from pathlib import Path -from typing import Any - -import openai -from datasets import load_from_disk -from safetensors import safe_open -from tqdm import tqdm - -from speculators.data_generation.vllm_client import ( - DEFAULT_MAX_RETRIES, - DEFAULT_REQUEST_TIMEOUT, - generate_hidden_states_async, -) -from speculators.train.logger import setup_root_logger - -logger = logging.getLogger(__name__) - - -class _FailureTracker: - """Tracks consecutive sample failures across async workers. - - When the number of consecutive failures (with no successes in between) - reaches ``threshold``, the tracker signals that the run should abort. - Because asyncio is single-threaded, no locking is needed. - """ - - def __init__(self, threshold: int): - self.threshold = threshold - self._consecutive = 0 - - def record_success(self) -> None: - self._consecutive = 0 - - def record_failure(self) -> bool: - """Record a failure. Returns True when the threshold is reached.""" - self._consecutive += 1 - return self._consecutive >= self.threshold - - -def parse_args(): - parser = argparse.ArgumentParser(description="Generate EAGLE training data offline") - - # Model arguments - parser.add_argument( - "--model", - type=str, - default=None, - help=( - "HuggingFace model ID or local path for target model (default auto select)." - "For verification purposes only." - ), - ) - parser.add_argument( - "--endpoint", - type=str, - default="http://localhost:8000/v1", - help=( - "The address of the vLLM instance to use for hidden states generation " - "(default: 'http://localhost:8000/v1'). " - "Note: the vLLM instance must be configured for hidden states extraction." - ), - ) - - # Data arguments - parser.add_argument( - "--preprocessed-data", - type=str, - required=True, - help="Path to preprocessed dataset (dataset produced by prepare_data.py)", - ) - parser.add_argument( - "--max-samples", - type=int, - default=None, - help="Maximum number of samples to process (default: None, process all)", - ) - - # Output arguments - parser.add_argument( - "--output", - type=str, - default=None, - help=( - "Directory to generated hidden states files " - "(default args.preprocessed_data / 'hidden_states')" - ), - ) - - # Hidden states generation arguments - parser.add_argument( - "--layer-ids", - type=int, - nargs="+", - default=None, - help=( - "List of layer IDs from which to capture hidden states " - "(default: auto-select)" - ), - ) - parser.add_argument( - "--concurrency", - type=int, - default=32, - help=( - "Number of active vLLM requests at a time." - "Note: number of async workers set to 2*concurrency" - ), - ) - parser.add_argument( - "--validate-outputs", - action="store_true", - help=( - "Load generated safetensor files and check output token ids match prompt" - " tokens and hidden states seq_len matches num tokens" - ), - ) - parser.add_argument( - "--request-timeout", - type=float, - default=DEFAULT_REQUEST_TIMEOUT, - help=( - "Timeout in seconds for each individual vLLM request " - f"(default: {DEFAULT_REQUEST_TIMEOUT})" - ), - ) - parser.add_argument( - "--max-retries", - type=int, - default=DEFAULT_MAX_RETRIES, - help=( - "Maximum number of retry attempts per request on failure " - f"(default: {DEFAULT_MAX_RETRIES})" - ), - ) - parser.add_argument( - "--fail-on-error", - action="store_true", - help=( - "Abort when a request fails after all retries. " - "By default, failed samples are skipped." - ), - ) - parser.add_argument( - "--max-consecutive-errors", - type=int, - default=None, - help=( - "Abort after this many consecutive sample failures (each sample " - "already retried --max-retries times). Prevents silently churning " - "through the entire dataset when the server is down. " - "Ignored when --fail-on-error is set. " - "(default: value of --concurrency)" - ), - ) - - # Processing arguments - parser.add_argument( - "--start-idx", - type=int, - default=0, - help="Starting index for output files (default: 0)", - ) - return parser.parse_args() - - -def get_existing_hidden_state_indices(output_path: Path) -> list[int]: - """Find existing `hs_i.safetensors` files (where i is the file index)""" - - existing_file_indices = [] - - if not output_path.exists(): - return existing_file_indices - - for file_path in output_path.iterdir(): - if file_path.name.startswith("hs_") and file_path.name.endswith(".safetensors"): - index_str = file_path.stem[3:] # Remove "hs_" prefix - try: - file_index = int(index_str) - existing_file_indices.append(file_index) - except ValueError: - continue - - return sorted(existing_file_indices) - - -def get_indices_to_process( - num_samples: int, max_samples: int | None, existing: list[int] -) -> list[int]: - """Determines which indices should be processed. If max_samples is None - returns all dataset indices not in existing. Otherwise gets the first - `max_samples - len(existing)` samples not already in existing. - - Args: - num_samples: Total size of preprocessed dataset - max_samples: (Optional) limit for number of samples to process - existing: list of ids that have already been processed - - Returns: - list of dataset indices to process - """ - - if len(existing) >= num_samples: - logger.info("All samples already processed!") - return [] - if max_samples and len(existing) >= max_samples: - logger.info("At least max_samples already processed!") - return [] - - if len(existing) > 0: - logger.info(f"Found {len(existing)} existing samples.") - - existing_s = set(existing) - if max_samples is None: - return [i for i in range(num_samples) if i not in existing_s] - - num_remaining = min(max_samples, num_samples) - len(existing) - to_process = [] - cur = 0 - while num_remaining > 0 and cur < num_samples: - if cur not in existing_s: - to_process.append(cur) - num_remaining -= 1 - - cur += 1 - - return to_process - - -def check_safetensors_file(path: Path, tokens: list[int]): - with safe_open(path, "pt") as f: - t_ids = f.get_tensor("token_ids").tolist() - if t_ids != tokens: - raise ValueError( - f"Token ids in {path} don't match expected token ids {tokens}" - ) - - hs_slice = f.get_slice("hidden_states") - hs_shape = list(hs_slice.get_shape()) - if len(tokens) != hs_shape[0]: - raise ValueError( - f"Sequence length of hidden states {hs_shape[0]} in {path}" - f" doesn't match num tokens {len(tokens)}" - ) - - -async def worker( - client, - model: str, - queue: "asyncio.Queue[dict[str, Any]]", - pbar: tqdm, - vllm_semaphore: asyncio.Semaphore, - write_semaphore: asyncio.Semaphore, - hidden_states_output_dir: Path, - validate_outputs: bool, - request_timeout: float | None, - max_retries: int, - fail_on_error: bool, - skipped_indices: list[int], - cancel_event: asyncio.Event, - failure_tracker: _FailureTracker | None, -): - """Worker that pulls items from queue and sends them to the vLLM endpoint.""" - while True: - item = await queue.get() - if item is None: - queue.task_done() - return - - idx = item["idx"] - - # Drain remaining items quickly after cancellation - if cancel_event.is_set(): - queue.task_done() - continue - - input_ids = item["input_ids"].tolist() - - target_hidden_states_path = hidden_states_output_dir / f"hs_{idx}.safetensors" - - try: - async with vllm_semaphore: # Limit number of active generate calls - hidden_states_path = await generate_hidden_states_async( - client, - model, - input_ids, - timeout=request_timeout, - max_retries=max_retries, - ) - async with write_semaphore: # Limit number of active disk writes - await asyncio.to_thread( - shutil.move, hidden_states_path, target_hidden_states_path - ) - if validate_outputs: - await asyncio.to_thread( - check_safetensors_file, target_hidden_states_path, input_ids - ) - except Exception as e: - if fail_on_error: - logger.exception( - "Fatal: sample %d failed with --fail-on-error: %s", idx, e - ) - logging.shutdown() - os._exit(1) - logger.warning("Skipping sample %d due to error: %s", idx, e) - skipped_indices.append(idx) - if failure_tracker is not None and failure_tracker.record_failure(): - cancel_event.set() - raise RuntimeError( - f"Aborting: {failure_tracker.threshold} consecutive samples " - "failed. The vLLM server may be unreachable." - ) from e - else: - if failure_tracker is not None: - failure_tracker.record_success() - finally: - pbar.update(1) - queue.task_done() - - -async def _feed_queue(to_process, dataset, queue, cancel_event): - """Feed dataset items into the worker queue, respecting cancellation.""" - for i in to_process: - if cancel_event.is_set(): - break - item = dataset[i] - # Check cancel_event while waiting for queue space to avoid - # deadlocking when all workers have died. - while not cancel_event.is_set(): - try: - queue.put_nowait({"idx": i, "input_ids": item["input_ids"]}) - break - except asyncio.QueueFull: - await asyncio.sleep(0.1) - - -async def _shutdown_workers(workers, queue, cancel_event): - """Shut down workers and propagate the first real exception.""" - logger.info("Waiting for remaining file saves to complete...") - if cancel_event.is_set(): - # Workers may be dead or draining — cancel any that are - # still alive so we don't deadlock on sentinel puts. - for w in workers: - if not w.done(): - w.cancel() - else: - # Normal shutdown: send sentinel values so workers exit - for _ in range(len(workers)): - await queue.put(None) - results = await asyncio.gather(*workers, return_exceptions=True) - - # Propagate the first real worker exception (skip CancelledError) - for result in results: - if isinstance(result, Exception) and not isinstance( - result, asyncio.CancelledError - ): - raise result - - -async def generate_and_save_hidden_states(args, dataset): - if args.output is None: - hidden_states_dir = Path(args.preprocessed_data) / "hidden_states" - else: - hidden_states_dir = Path(args.output) - hidden_states_dir.mkdir(parents=True, exist_ok=True) - - existing_file_indices = get_existing_hidden_state_indices(hidden_states_dir) - num_samples = len(dataset) - - to_process = get_indices_to_process( - num_samples, args.max_samples, existing_file_indices - ) - if not to_process: - return - - logger.info(f"Processing {len(to_process)} samples") - - queue: asyncio.Queue = asyncio.Queue(maxsize=args.concurrency * 4) - vllm_semaphore = asyncio.Semaphore(args.concurrency) - write_semaphore = asyncio.Semaphore(args.concurrency) - - skipped_indices: list[int] = [] - cancel_event = asyncio.Event() - - max_consec = args.max_consecutive_errors - if max_consec is None: - max_consec = args.concurrency - failure_tracker = _FailureTracker(max_consec) if not args.fail_on_error else None - - async with openai.AsyncOpenAI( - base_url=args.endpoint, api_key="EMPTY", max_retries=0 - ) as client: - list_models = await client.models.list() - model_id = list_models.data[0].id - if args.model and args.model != model_id: - raise ValueError( - f"An explicit model name was passed ({args.model}) which doesn't match" - "found model_id {model_id}." - "Please make sure --endpoint is set to the correct vllm instance." - ) - - with tqdm(total=len(to_process)) as pbar: - workers = [ - asyncio.create_task( - worker( - client, - model_id, - queue, - pbar, - vllm_semaphore, - write_semaphore, - hidden_states_dir, - args.validate_outputs, - args.request_timeout, - args.max_retries, - args.fail_on_error, - skipped_indices, - cancel_event, - failure_tracker, - ) - ) - for _ in range(args.concurrency * 2) - ] - - await _feed_queue(to_process, dataset, queue, cancel_event) - await _shutdown_workers(workers, queue, cancel_event) - - num_saved = len(to_process) - len(skipped_indices) - logger.info(f"Saved {num_saved} new data points to {args.output}") - if skipped_indices: - logger.warning( - f"Skipped {len(skipped_indices)} samples due to errors: {skipped_indices}" - ) - - -def main(): - args = parse_args() - setup_root_logger() - - logger.info("EAGLE Offline Data Generation") - - dataset = load_from_disk(args.preprocessed_data) - - try: - asyncio.run(generate_and_save_hidden_states(args, dataset)) - except KeyboardInterrupt: - sys.exit(130) - except Exception: - logger.exception("Data generation failed") - sys.exit(1) - - logger.info("Data generation complete!") - - -if __name__ == "__main__": - main() diff --git a/scripts/gen_and_train.py b/scripts/gen_and_train.py deleted file mode 100644 index ce60ca231..000000000 --- a/scripts/gen_and_train.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -Combined EAGLE3 Data Generation and Training Pipeline - -This script is a convenience wrapper around the following scripts: - 1. scripts/data_generation_offline.py - 2. scripts/build_vocab_mapping.py - 3. scripts/train.py - -It can be used to run the full pipeline in one command. It also ensures each script is -run with the correct arguments and dependencies. - -Prerequisites: - - python 3.10+ - - uv (`pip install uv`) - -Usage: - Update arguments below. Then run: - python scripts/gen_and_train.py - - Note: You can call the script with environment variables (like - `CUDA_VISIBLE_DEVICES` and `HF_HOME`) to control the behavior of the scripts. -""" - -import enum -import os -import shutil -import subprocess -import sys -import time -from pathlib import Path -from typing import Any, NamedTuple - -import psutil -import torch - -from speculators.train.vocab_mapping import ( - combine_token_frequency_distributions, -) -from speculators.utils.util import is_npu_available - - -class _NS(enum.Enum): - """Class containing a sentinel value used to indicate unset arguments.""" - - # https://github.com/python/typing/issues/236#issuecomment-227180301 - value = 0 - - -_NOTSET = _NS.value # sentinel value - - -# Output structure: -# output_path/ -# gen/ -# / -# data_config.json -# data_0.pt -# data_1.pt -# ... -# / -# data_config.json -# data_0.pt -# data_1.pt -# ... -# ... -# vocab_mapping/ -# token_freq_.pt -# token_freq_.pt -# ... -# token_freq_combined.pt -# d2t.npy -# t2d.npy -# checkpoints/ -# 0/ -# config.json -# eagle3.py -# generation_config.json -# model.safetensors -# optimizer_state_dict.pt -# scheduler_state_dict.pt -# 1/ -# config.json -# eagle3.py -# generation_config.json -# model.safetensors -# optimizer_state_dict.pt -# scheduler_state_dict.pt -# ... -# logs/ - - -class DataGenArgs(NamedTuple): - """Arguments for data generation.""" - - train_data_path: str - """The path to the training data. Can be one of ["sharegpt", "ultrachat"] or a - huggingface dataset path or a local JSON/JSONL file.""" - dataset_name: str | None = None - """The name of the dataset to generate data for. Used exclusively for logging and - output path generation. If None and train_data_path is sharegpt or ultrachat, the - dataset name will be inferred from the train_data_path.""" - turn_dropout: bool = False - seq_length: int | _NS = _NOTSET - max_samples: int | _NS = _NOTSET - tensor_parallel_size: int | _NS = _NOTSET - gpu_memory_utilization: float | _NS = _NOTSET - hf_cache_dir: str | _NS = _NOTSET - layer_ids: list[int] | _NS = _NOTSET - batch_size: int | _NS = _NOTSET - seed: int | _NS = _NOTSET - start_idx: int | _NS = _NOTSET - num_preprocessing_workers: int | _NS = _NOTSET - - -class VocabMappingArgs(NamedTuple): - draft_vocab_size: int - target_vocab_size: int - - -class TrainArgs(NamedTuple): - run_name: str - logger: str | _NS = _NOTSET - lr: float | _NS = _NOTSET - total_seq_len: int | _NS = _NOTSET - ttt_steps: int | _NS = _NOTSET - epochs: int | _NS = _NOTSET - no_resume_from_checkpoint: bool | _NS = _NOTSET - num_layers: int | _NS = _NOTSET - ttt_step_loss_decay: float | _NS = _NOTSET - use_off_policy_tokens: bool | _NS = _NOTSET - scheduler_type: str | _NS = _NOTSET - scheduler_warmup_steps: int | _NS = _NOTSET - scheduler_total_steps: int | _NS = _NOTSET - scheduler_num_cosine_cycles: float | _NS = _NOTSET - norm_before_fc: bool | _NS = _NOTSET - - -### END OF SCRIPT ARGUMENTS ### - - -def prepare_args(args: dict[str, Any]) -> list[str]: - args_list = [] - for key, value in args.items(): - if value is _NOTSET: - continue - # Convert snake_case to kebab-case for command line arguments. - dashed_key = key.replace("_", "-") - # Handle boolean flags (action="store_true") - if isinstance(value, bool): - if value: - args_list.append(f"--{dashed_key}") - # If False, don't add the flag at all - else: - args_list.append(f"--{dashed_key}") - args_list.append(str(value)) - return args_list - - -def print_block(title: str, content: str): - title = f" {title} " - term_width, _terminal_height = shutil.get_terminal_size((80, 20)) - print( - "\n", - "#" * ((term_width - len(title)) // 2), - title, - "#" * ((term_width - len(title) + 1) // 2), - "\n", - sep="", - ) - print(content) - print("\n", "#" * term_width, "\n", sep="") - - -def run_script( - script_name: str, - script_args: list[str], - requires: list[str], - python_alt: str = "python", - use_uv: bool = True, -): - command = [] - if use_uv: - command = [ - "uv", - "run", - "--no-sync", - "--no-dev", - "--no-default-groups", - "--isolated", - ] - for i, package in enumerate(requires): - command.append("--with-editable" if i == 0 else "--with") - command.append(package) - - command.extend(python_alt.split()) - - script_path = (Path(__file__).parent / script_name).absolute() - command.append(str(script_path)) - command.extend(script_args) - - print_block(f"RUNNING {script_name}", " ".join(command)) - - start_time = time.perf_counter() - try: - process = subprocess.Popen(command, stdout=sys.stdout, stderr=sys.stderr) # noqa: S603 - process.wait() - except KeyboardInterrupt: - # Clean up subprocesses - print( - f"Received KeyboardInterrupt. Terminating process {process.pid} " - "and its children." - ) - end_time = time.perf_counter() - print_block( - f"CANCELLED {script_name}", - f"Time taken: {end_time - start_time:.2f} seconds", - ) - - for child in psutil.Process(process.pid).children(recursive=True): - child.terminate() - process.terminate() - - for _ in range(10): - remaining_children = list( - psutil.Process(process.pid).children(recursive=True) - ) - if not remaining_children: - break - time.sleep(1) - else: - print(f"Failed to terminate all children of process {process.pid}.") - print("Retrying...") - for child in psutil.Process(process.pid).children(recursive=True): - child.kill() # escalate to SIGKILL - process.kill() # escalate to SIGKILL - - sys.exit(1) - - end_time = time.perf_counter() - print_block( - f"COMPLETED {script_name}", - ( - f"Time taken: {end_time - start_time:.2f} seconds. " - f"Exit code: {process.returncode}" - ), - ) - - if process.returncode != 0: - raise subprocess.CalledProcessError(process.returncode, command) - - -def run_e2e( - verifier_name_or_path: str, - output_path: str, - data_gen_args: DataGenArgs | list[DataGenArgs], - vocab_mapping_args: VocabMappingArgs | None, - train_args: TrainArgs, -): - """Run the full pipeline in one command.""" - output_path = Path(output_path) - - # Data Generation - if isinstance(data_gen_args, DataGenArgs): - data_gen_args = [data_gen_args] - - token_freq_paths = [] - num_datasets = len(data_gen_args) - - for dga_obj in data_gen_args: - dga_dict = dga_obj._asdict() - dga_dict["target-model-path"] = verifier_name_or_path - - dataset_name = dga_dict["dataset_name"] - if dataset_name is None: - if dga_dict["train_data_path"] in ["sharegpt", "ultrachat"]: - dataset_name = dga_dict["train_data_path"] - else: - raise ValueError( - f"Dataset name is required for {dga_dict['train_data_path']}" - ) - del dga_dict[ - "dataset_name" - ] # Remove name so it isn't passed as argument to data_generation_offline.py - - token_freq_path = ( - output_path / "vocab_mapping" / f"token_freq_{dataset_name}.pt" - ) - dga_dict["token-freq-path"] = str(token_freq_path) - token_freq_paths.append(token_freq_path) - dga_dict["output-dir"] = str(output_path / "gen" / dataset_name) - - dga_list = prepare_args(dga_dict) - run_script( - "data_generation_offline.py", - dga_list, - [".[datagen]"], - use_uv=not is_npu_available(), - ) - - # Combine token frequency files from all datasets into a single file. - if num_datasets > 1: - combined_token_freq_path = ( - output_path / "vocab_mapping" / "token_freq_combined.pt" - ) - combine_token_frequency_distributions( - token_freq_paths, combined_token_freq_path - ) - else: - combined_token_freq_path = token_freq_paths[0] - - # Vocab Mapping (optional) - ta_dict = { - **train_args._asdict(), - "verifier-name-or-path": verifier_name_or_path, - "data-path": str(output_path / "gen"), - "save-path": str(output_path / "checkpoints"), - "log-dir": str(output_path / "logs"), - } - if vocab_mapping_args is not None: - vma_dict = vocab_mapping_args._asdict() - vma_dict["token-freq-path"] = str(combined_token_freq_path) - vma_dict["output-path"] = str(output_path / "vocab_mapping") - vma_list = prepare_args(vma_dict) - run_script( - "build_vocab_mapping.py", - vma_list, - [".[datagen]"], - use_uv=not is_npu_available(), - ) - ta_dict["d2t-path"] = str(output_path / "vocab_mapping" / "d2t.npy") - ta_dict["t2d-path"] = str(output_path / "vocab_mapping" / "t2d.npy") - - ta_list = prepare_args(ta_dict) - ta_list.append("--legacy-data") - - # Get additional packages to install if loggers are specified. - packages = ["."] - loggers = ta_dict["logger"] - if loggers and loggers is not _NOTSET: - if isinstance(loggers, str): - loggers = loggers.split(",") - loggers = [logger.strip() for logger in loggers] - packages.extend(loggers) - device_count = torch.accelerator.device_count() - - local_train_env = is_npu_available() or bool(os.environ.get("LOCAL_TRAIN_ENV", "")) - - run_script( - "train.py", - ta_list, - packages, - python_alt=f"torchrun --standalone --nproc_per_node={device_count}", - use_uv=not local_train_env, - ) diff --git a/src/speculators/data_generation/config_generator.py b/src/speculators/data_generation/config_generator.py deleted file mode 100644 index fbc9cd039..000000000 --- a/src/speculators/data_generation/config_generator.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Configuration generator for EAGLE data generation pipeline. - -Provides type-safe configuration generation with reproducibility tracking -and schema documentation. -""" - -from __future__ import annotations - -import sys -from dataclasses import asdict, dataclass, field -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar - -import torch -from transformers import AutoConfig - -from speculators.data_generation.logging_utils import PipelineLogger -from speculators.utils.util import get_device_name - -if TYPE_CHECKING: - from speculators.data_generation.vllm_hidden_states_generator import ( - VllmHiddenStatesGenerator, - ) - -__all__ = ["DataGenerationConfig", "PackageVersions"] - -log = PipelineLogger(__name__) - - -def _get_gpu_info() -> str: - """Get GPU information string. - - :return: GPU model and count or NPU model and count, - or "CPU only" if no GPU/NPU available - """ - device_count = torch.accelerator.device_count() - device_name = get_device_name(0) - if device_name == "NO ACCELERATOR": - return "CPU ONLY" - else: - return device_name if device_count == 1 else f"{device_count}x {device_name}" - - -@dataclass -class PackageVersions: - """Package versions for full reproducibility of data generation.""" - - torch: str - vllm: str - transformers: str - speculators: str - - @classmethod - def from_environment(cls) -> PackageVersions: - """Detect package versions from current environment. - - :return: PackageVersions with all detected versions - """ - from importlib.metadata import version # noqa: PLC0415 - - import transformers # noqa: PLC0415 - import vllm # noqa: PLC0415 - - return cls( - torch=torch.__version__, - vllm=vllm.__version__, - transformers=transformers.__version__, - speculators=version("speculators"), - ) - - -@dataclass -class ReproducibilityInfo: - """Information needed to reproduce the data generation run.""" - - command: str - package_versions: PackageVersions - gpu: str = field(default_factory=_get_gpu_info) - - -@dataclass -class ModelConfig: - """Model configuration for the target model.""" - - target_model_path: str - tensor_parallel_size: int - gpu_memory_utilization: float - hidden_size: int - - -@dataclass -class DataConfig: - """Dataset and preprocessing configuration.""" - - train_data_path: str - seq_length: int - max_samples: int | None - num_samples: int - seed: int - chat_template_note: str = "Uses tokenizer's built-in chat template" - - -@dataclass -class HiddenStatesConfig: - """Configuration for which hidden states to extract.""" - - layer_ids: list[int] - description: str = "Layers selected for EAGLE3 fusion and target logits" - - -@dataclass -class GenerationConfig: - """Runtime generation parameters.""" - - cache_dir: str - - -@dataclass -class FormatConfig: - """Output format specification for generated data files.""" - - file_pattern: str - schema: dict[str, dict[str, Any]] - - @classmethod - def create_default(cls, num_layers: int, hidden_size: int) -> FormatConfig: - """Create default format config with schema documentation. - - :param num_layers: Number of hidden state layers being saved - :param hidden_size: Dimension of each hidden state tensor - :return: FormatConfig with complete schema information - """ - return cls( - file_pattern="data_{idx}.pt", - schema={ - "input_ids": { - "dtype": "torch.long", - "shape": "[seq_len]", - "description": "Tokenized input sequence", - }, - "hidden_states": { - "dtype": "list[torch.bfloat16]", - "shape": f"list of [seq_len, {hidden_size}]", - "num_tensors": num_layers, - "description": f"Hidden states from {num_layers} layers", - }, - "loss_mask": { - "dtype": "torch.long", - "shape": "[seq_len]", - "description": "1 for assistant tokens to train on, 0 elsewhere", - }, - }, - ) - - -@dataclass -class DataGenerationConfig: - """Complete configuration for EAGLE data generation run. - - Saved alongside generated data for full reproducibility. - """ - - VERSION: ClassVar[str] = "2.0" - - version: str - generated_at: str - speculators_version: str - reproducibility: ReproducibilityInfo - model: ModelConfig - data: DataConfig - hidden_states: HiddenStatesConfig - generation: GenerationConfig - format: FormatConfig - - def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for JSON serialization. - - Handles Path objects by converting them to strings. - - :return: Dictionary representation of the config - """ - - def serialize_value(obj: Any) -> Any: - """Recursively convert Path objects to strings.""" - if isinstance(obj, Path): - return str(obj) - if isinstance(obj, dict): - return {k: serialize_value(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [serialize_value(item) for item in obj] - return obj - - config_dict = asdict(self) - return serialize_value(config_dict) - - @classmethod - def from_generator( - cls, - generator: VllmHiddenStatesGenerator, - train_data_path: str, - seq_length: int, - cache_dir: str, - num_samples: int, - max_samples: int | None = None, - seed: int = 0, - ) -> DataGenerationConfig: - """Create config from an initialized VllmHiddenStatesGenerator. - - :param generator: Initialized VllmHiddenStatesGenerator instance - :param train_data_path: Path or HF dataset name used for training data - :param seq_length: Maximum sequence length used in preprocessing - :param cache_dir: Directory where preprocessed data is cached - :param num_samples: Total number of samples generated - :param max_samples: Maximum samples to process (None = all) - :param seed: Random seed used - :return: Complete DataGenerationConfig ready to save as JSON - """ - log.subsection("Generating configuration metadata") - - package_versions = PackageVersions.from_environment() - log.info( - f"Packages: torch={package_versions.torch}, vllm={package_versions.vllm}" - ) - - hidden_size = _get_hidden_size_from_model(generator.model_path) - log.info(f"Hidden size: {hidden_size}") - log.info(f"GPU: {_get_gpu_info()}") - - config = cls( - version=cls.VERSION, - generated_at=datetime.now(timezone.utc).isoformat(), - speculators_version=package_versions.speculators, - reproducibility=ReproducibilityInfo( - command=" ".join([Path(sys.argv[0]).name, *sys.argv[1:]]), - package_versions=package_versions, - ), - model=ModelConfig( - target_model_path=generator.model_path, - tensor_parallel_size=generator.tensor_parallel_size, - gpu_memory_utilization=generator.vllm_config.cache_config.gpu_memory_utilization, - hidden_size=hidden_size, - ), - data=DataConfig( - train_data_path=train_data_path, - seq_length=seq_length, - max_samples=max_samples, - num_samples=num_samples, - seed=seed, - ), - hidden_states=HiddenStatesConfig(layer_ids=generator.layer_ids), - generation=GenerationConfig(cache_dir=cache_dir), - format=FormatConfig.create_default( - num_layers=len(generator.layer_ids), hidden_size=hidden_size - ), - ) - - log.success("Configuration generated") - return config - - -def _get_hidden_size_from_model(model_path: str) -> int: - """Extract hidden size from model config. - - :param model_path: HuggingFace model ID or local path - :return: Hidden state dimension - :raises ValueError: If hidden size cannot be determined - """ - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - - if hidden_size := getattr(config, "hidden_size", None): - return hidden_size - - if text_config := getattr(config, "text_config", None): - if hidden_size := getattr(text_config, "hidden_size", None): - return hidden_size - - raise ValueError( - f"Could not determine hidden size for {model_path}. " - f"Expected 'hidden_size' or 'text_config.hidden_size' attribute" - ) diff --git a/src/speculators/data_generation/custom_worker.py b/src/speculators/data_generation/custom_worker.py deleted file mode 100644 index 7b0e92063..000000000 --- a/src/speculators/data_generation/custom_worker.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Custom worker extension for hidden states capture.""" - -import logging -import types -from collections import defaultdict -from itertools import islice - -import torch -from vllm.distributed import get_pp_group, get_tp_group -from vllm.sequence import IntermediateTensors - -__all__ = ["HiddenStatesWorkerExtension"] - -logger = logging.getLogger(__name__) - - -def _patched_forward( - self, - input_ids, - positions, - intermediate_tensors: dict[str, torch.Tensor] | None = None, - inputs_embeds=None, - **_kwargs, -): - """Patched forward pass that captures hidden states from specified layers. - - This function is bound to base_model instances via types.MethodType. - It expects base_model to have an _extension attribute pointing to the - HiddenStatesWorkerExtension instance. - - Args: - deepstack_input_embeds: For multimodal models with deepstack (Qwen3VL) - """ - if get_pp_group().is_first_rank: - hidden_states = ( - inputs_embeds - if inputs_embeds is not None - else self.embed_input_ids(input_ids) - ) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - aux_hidden_states = [] - extension = self._extension # noqa: SLF001 - # Only capture on TP rank 0 to avoid duplicates - should_capture = get_tp_group().rank_in_group == 0 - target_layers = extension._layer_ids if should_capture else frozenset() # noqa: SLF001 - - for idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): - hidden_states, residual = layer( - hidden_states=hidden_states, positions=positions, residual=residual - ) - absolute_layer_idx = self.start_layer + idx - - # Capture intermediate layers (not the last) before norm - if absolute_layer_idx in target_layers: - aux_hidden_states.append((hidden_states + residual).clone()) - - # Return early if not last PP rank - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} # type: ignore[dict-item] - ) - - hidden_states, _ = self.norm(hidden_states, residual) - if should_capture and aux_hidden_states: - extension._store_captured_states(aux_hidden_states) # noqa: SLF001 - - return hidden_states - - -class HiddenStatesWorkerExtension: - """Worker extension that adds hidden states capture functionality. - - This extension hooks into VLLM's Worker initialization by being specified - in ParallelConfig.worker_extension_cls. It patches the model's forward pass - to intercept and capture intermediate layer hidden states during inference. - - Key behaviors: - - Only captures on tensor parallel (TP) rank 0 to avoid duplicate data when - using tensor parallelism. All TP ranks compute the same hidden states, so - capturing from rank 0 is sufficient. - - Stores captured states in GPU memory during batch processing as lists of - tensors, concatenating them only when retrieved via _get_captured_states(). - - Supports pipeline parallelism by handling IntermediateTensors correctly. - - Attributes: - _layer_ids: Frozenset of layer indices for O(1) lookup during capture - _captured_states: Accumulated hidden states per layer (GPU tensors) - model_runner: Reference to the VLLM model runner - """ - - def _store_captured_states(self, aux_hidden_states): - if self._captured_states is None: # type: ignore[has-type] - self._captured_states = [[h] for h in aux_hidden_states] - else: - for i, h in enumerate(aux_hidden_states): - self._captured_states[i].append(h) - - metadata = getattr(self, "_current_request_metadata", None) - if metadata is not None: - # Sort by vLLM's actual batch position (vLLM reorders requests internally) - input_batch = self.model_runner.input_batch # type: ignore[attr-defined] - sorted_metadata = sorted( - metadata.items(), - key=lambda item: input_batch.req_id_to_index.get(item[0], float("inf")), - ) - self._request_metadata.append(sorted_metadata) # type: ignore[has-type] - - def _setup_hidden_states_capture(self, layer_ids: list[int]): - """Setup model to capture auxiliary hidden states from specific layers""" - self._layer_ids = frozenset(layer_ids) # Convert once for O(1) lookup - self._captured_states = None # type: ignore[assignment] - - model = self.model_runner.model # type: ignore[attr-defined] - - # Vision-language models - if hasattr(model, "get_language_model"): - base_model = model.get_language_model().model - # Text models - elif hasattr(model, "model") and hasattr(model.model, "layers"): - base_model = model.model - else: - attrs = [a for a in dir(model) if not a.startswith("_")] - raise AttributeError( - f"Could not find base model with 'layers' attribute. " - f"Model type: {type(model).__name__}, " - f"Available attributes: {attrs}" - ) - - base_model._extension = self # noqa: SLF001 - base_model.forward = types.MethodType(_patched_forward, base_model) - logger.info(f"Hidden states capture setup complete for layers {layer_ids}") - - def _set_request_metadata(self, request_metadata: dict[str, int]): - """Set request metadata for the next forward pass. - - Args: - request_metadata: Dict mapping request_id -> num_prefill_tokens - """ - self._current_request_metadata = request_metadata # type: ignore[assignment] - - def _reset_capture(self): - """Reset captured states before starting a new batch""" - if not hasattr(self, "_layer_ids"): - raise RuntimeError( - "Must call _setup_hidden_states_capture before capturing states" - ) - self._captured_states = None # type: ignore[assignment] - self._request_metadata = [] # type: ignore[assignment] - self._current_request_metadata = None # type: ignore[assignment] - - def _get_captured_states(self): - """Get the captured hidden states organized by request ID. - - Returns: - Dict mapping request_id to list of tensors (one per layer), - or None if no states captured. - - Track which tokens belong to which request across chunked prefill iterations. - """ - if self._captured_states is None: - return None - - # Concatenate captured states from all scheduler iterations - concatenated_layers = [ - torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states - ] - - # Slice and group by request - request_chunks: defaultdict[str, list[list[torch.Tensor]]] = defaultdict( - lambda: [[] for _ in range(len(concatenated_layers))] - ) - current_idx = 0 - - for metadata in self._request_metadata: # type: ignore[has-type] - for req_id, num_tok in metadata: - for layer_idx, layer_tensor in enumerate(concatenated_layers): - chunk = layer_tensor[current_idx : current_idx + num_tok].clone() - request_chunks[req_id][layer_idx].append(chunk) - current_idx += num_tok - - # Concatenate chunks for each request - result: dict[str, list[torch.Tensor]] = { - req_id: [torch.cat(chunks, dim=0) for chunks in layer_chunks] - for req_id, layer_chunks in request_chunks.items() - } - - # Clear intermediate storage - self._captured_states = None # type: ignore[assignment] - self._request_metadata = [] # type: ignore[assignment] - return result diff --git a/src/speculators/data_generation/vllm_hidden_states_generator.py b/src/speculators/data_generation/vllm_hidden_states_generator.py deleted file mode 100644 index 0c987aab7..000000000 --- a/src/speculators/data_generation/vllm_hidden_states_generator.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Extract hidden states from intermediate layers during prefill using vLLM.""" - -import warnings -from typing import Literal - -import torch -from transformers import AutoConfig, AutoTokenizer -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoadConfig, - ModelConfig, - ParallelConfig, - SchedulerConfig, - VllmConfig, -) -from vllm.sampling_params import SamplingParams -from vllm.utils.hashing import get_hash_fn_by_name -from vllm.v1.core.kv_cache_utils import ( - _get_kv_cache_groups_uniform_spec, - get_kv_cache_config_from_groups, - get_request_block_hasher, - init_none_hash, - unify_hybrid_kv_cache_specs, -) -from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.executor.multiproc_executor import MultiprocExecutor -from vllm.v1.request import Request, RequestStatus -from vllm.v1.structured_output import StructuredOutputManager - -from speculators.utils.util import empty_cache, is_npu_available, mem_get_info - -from .logging_utils import PipelineLogger - -__all__ = ["VllmHiddenStatesGenerator"] - -# Constants -CACHE_MEMORY_FRACTION = 0.2 # Fraction of GPU memory for KV cache -VLLM_BLOCK_SIZE: Literal[1, 8, 16, 32, 64, 128, 256] = ( - 128 if is_npu_available() else 16 -) # Block size for KV cache -MAX_NUM_SEQS = 32 # Maximum sequences for prefill-only workload -MIN_MAX_BATCHED_TOKENS = 8192 # Minimum batched tokens threshold -MAX_DECODE_TOKENS = 1 # Maximum tokens to generate (prefill only) -SAMPLING_TEMPERATURE = 0.0 # Temperature for sampling (greedy) -INITIAL_ARRIVAL_TIME = 0.0 # Initial request arrival time - -log = PipelineLogger(__name__) - - -class VllmHiddenStatesGenerator: - """Extracts hidden states from intermediate layers during prefill only. - - This module provides a generator for extracting hidden states from - transformer models during the prefill phase using VLLM's inference engine. - It is designed for generating training data for speculative decoding models - like EAGLE3. - - The generator: - - Uses VLLM's multiprocess executor for efficient batch inference - - Patches model forward pass to capture intermediate layer hidden states - - Operates in prefill-only mode (max_tokens=1) for data generation - - Supports tensor parallelism for large models - - Automatically manages KV cache and memory allocation - - Example: - generator = VllmHiddenStatesGenerator( - model_path="meta-llama/Llama-3.1-8B-Instruct", - layer_ids=[10, 20, 30], - tensor_parallel_size=2 - ) - - results = generator.generate(token_ids) - for result in results: - input_ids = result["input_ids"] - hidden_states = result["hidden_states"] # List of tensors per layer` - """ - - def __init__( - self, - model_path: str, - layer_ids: list[int] | None = None, - max_model_len: int = 2048, - gpu_memory_utilization: float = 0.8, - tensor_parallel_size: int = 1, - max_num_batched_tokens: int | None = None, - ): - warnings.warn( - "VllmHiddenStatesGenerator and the associated data_generation_offline.py" - " script are deprecatd and will be removed shortly.", - DeprecationWarning, - stacklevel=2, - ) - self.model_path = model_path - self.tensor_parallel_size = tensor_parallel_size - self._request_counter = 0 - - log.info(f"Initializing hidden states generator for {model_path}") - log.info(f"Tensor parallel size: {tensor_parallel_size}") - - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - if hasattr(config, "num_hidden_layers"): - num_layers = config.num_hidden_layers - elif hasattr(config, "text_config"): - num_layers = config.text_config.num_hidden_layers - else: - raise ValueError("Cannot determine num_layers from config") - - log.info(f"Model has {num_layers} layers") - - if layer_ids is None: - self.layer_ids = [2, num_layers // 2, num_layers - 3, num_layers - 1] - log.info( - f"Auto-selected layers: {self.layer_ids} " - f"(from {num_layers} total layers)" - ) - else: - self.layer_ids = layer_ids - log.info(f"Using specified layers: {layer_ids}") - - for layer_id in self.layer_ids: - if layer_id < 0 or layer_id >= num_layers: - raise ValueError( - f"Layer index {layer_id} out of bounds [0, {num_layers - 1}]" - ) - - self.vllm_config = self._create_vllm_config( - model_path=model_path, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tensor_parallel_size, - max_num_batched_tokens=max_num_batched_tokens, - ) - - log.info("Initializing executor...") - self.executor = MultiprocExecutor(vllm_config=self.vllm_config) - - log.info("Setting up hidden states capture...") - self._setup_capture() - - log.info("Creating scheduler...") - kv_cache_spec_list = self.executor.collective_rpc("get_kv_cache_spec") - kv_cache_spec = kv_cache_spec_list[0] - # Normalize hybrid KV cache specs for models with non-uniform attention - # (e.g., GPT-OSS with sliding/full attention layers) - unify_hybrid_kv_cache_specs(kv_cache_spec) - kv_cache_groups = _get_kv_cache_groups_uniform_spec(kv_cache_spec) - - free_memory, _ = mem_get_info() - cache_memory = int(free_memory * gpu_memory_utilization * CACHE_MEMORY_FRACTION) - - kv_cache_config = get_kv_cache_config_from_groups( - vllm_config=self.vllm_config, - kv_cache_groups=kv_cache_groups, - available_memory=cache_memory, - ) - - self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks - structured_output_manager = StructuredOutputManager( - vllm_config=self.vllm_config - ) - - self.scheduler = Scheduler( - vllm_config=self.vllm_config, - kv_cache_config=kv_cache_config, - structured_output_manager=structured_output_manager, - block_size=VLLM_BLOCK_SIZE, - ) - - log.info("Initializing KV cache on all workers...") - kv_cache_configs = [kv_cache_config] * tensor_parallel_size - self.executor.initialize_from_config(kv_cache_configs) - - # Create block hasher for request KV cache management - # Following vLLM's pattern in v1/engine/core.py - caching_hash_fn = get_hash_fn_by_name( - self.vllm_config.cache_config.prefix_caching_hash_algo - ) - init_none_hash(caching_hash_fn) - - self.block_hasher = get_request_block_hasher( - self.vllm_config.cache_config.block_size, - caching_hash_fn, - ) - - def _create_vllm_config( - self, - model_path: str, - max_model_len: int, - gpu_memory_utilization: float, - tensor_parallel_size: int, - max_num_batched_tokens: int | None = None, - ) -> VllmConfig: - """Create VllmConfig with hidden states worker extension""" - cache_config = CacheConfig( - block_size=VLLM_BLOCK_SIZE, - gpu_memory_utilization=gpu_memory_utilization, - # disable to prevent cache state leakage - enable_prefix_caching=False, - ) - - # For prefill-only workloads, use conservative scheduler limits - # to reduce warmup memory allocation. max_num_seqs controls the - # warmup allocation size (see gpu_worker.py:441-444). - # We set it to a small value since we only do prefill in batches. - max_num_seqs = MAX_NUM_SEQS - if not max_num_batched_tokens: - max_num_batched_tokens = max(MIN_MAX_BATCHED_TOKENS, max_model_len) - - return VllmConfig( - model_config=ModelConfig( - model=model_path, - tokenizer=model_path, - trust_remote_code=True, - dtype="auto", - max_model_len=max_model_len, - enforce_eager=True, - ), - cache_config=cache_config, - parallel_config=ParallelConfig( - tensor_parallel_size=tensor_parallel_size, - worker_extension_cls="speculators.data_generation.custom_worker.HiddenStatesWorkerExtension", - ), - scheduler_config=SchedulerConfig( - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - max_num_batched_tokens=max_num_batched_tokens, - is_encoder_decoder=False, - ), - device_config=DeviceConfig(), - load_config=LoadConfig(), - ) - - def _setup_capture(self): - self.executor.collective_rpc( - "_setup_hidden_states_capture", - args=(self.layer_ids,), - ) - - def generate(self, token_ids: list[list[int]] | torch.Tensor) -> list[dict]: # noqa: PLR0912, PLR0915 - """Extract hidden states from prefill phase only. - - Args: - token_ids: Batch of token ID sequences as list[list[int]] or Tensor - - Returns: - List of dicts with keys: input_ids, hidden_states, loss_mask - """ - if isinstance(token_ids, torch.Tensor): - input_ids_list = token_ids.tolist() - else: - if not token_ids: - raise ValueError("token_ids cannot be empty") - input_ids_list = token_ids - - log.debug(f"Generating hidden states for {len(input_ids_list)} sequences") - # Account for max_tokens=1 in sampling params - # (vLLM enforces: len(prompt) + max_tokens <= max_model_len) - max_len = self.vllm_config.model_config.max_model_len - MAX_DECODE_TOKENS - input_ids_list = [ids[:max_len] for ids in input_ids_list] - - # Track request IDs and prompt lengths for proper token attribution - request_id_to_idx = {} - request_id_to_prompt_len = {} - - for i, ids in enumerate(input_ids_list): - # Ensure ids is a list (not tensor) for vLLM Request - ids_list = ids.tolist() if isinstance(ids, torch.Tensor) else ids - req_id = f"req_{self._request_counter}_{i}" - request_id_to_idx[req_id] = i - request_id_to_prompt_len[req_id] = len(ids_list) - - req = Request( - request_id=req_id, - prompt_token_ids=ids_list, - sampling_params=SamplingParams( - max_tokens=MAX_DECODE_TOKENS, temperature=SAMPLING_TEMPERATURE - ), - pooling_params=None, - eos_token_id=self.tokenizer.eos_token_id, - arrival_time=INITIAL_ARRIVAL_TIME, - block_hasher=self.block_hasher, - ) - self.scheduler.add_request(req) - - # Increment to ensure unique request IDs across calls - # (prevents KV cache corruption with delayed block freeing) - self._request_counter += 1 - self.executor.collective_rpc("_reset_capture") - - # Track progress for each request to distinguish prefill from decode - request_num_computed = dict.fromkeys(request_id_to_idx, 0) - schedule_iterations = 0 - all_prefill_complete = False - - while ( - scheduler_output := self.scheduler.schedule() - ).total_num_scheduled_tokens > 0 and not all_prefill_complete: - schedule_iterations += 1 - - # Calculate prefill tokens for each request (ignore decode tokens) - prefill_metadata = {} - for req_id, num_tokens in scheduler_output.num_scheduled_tokens.items(): - num_already_computed = request_num_computed[req_id] - num_prompt = request_id_to_prompt_len[req_id] - num_prefill = max(0, min(num_tokens, num_prompt - num_already_computed)) - - if num_prefill > 0: - prefill_metadata[req_id] = num_prefill - - request_num_computed[req_id] += num_tokens - - all_prefill_complete = all( - request_num_computed[req_id] >= request_id_to_prompt_len[req_id] - for req_id in request_id_to_idx - ) - - if prefill_metadata: - self.executor.collective_rpc( - "_set_request_metadata", args=(prefill_metadata,) - ) - - model_output = self.executor.execute_model(scheduler_output) - self.executor.sample_tokens(model_output) - - # Abort all requests (prefill complete, don't need decode) - self.scheduler.finish_requests( - list(request_id_to_idx.keys()), RequestStatus.FINISHED_ABORTED - ) - - # Get captured states organized by request ID - request_states_dict = self.executor.collective_rpc( - "_get_captured_states", - unique_reply_rank=0, - ) - - if not request_states_dict: - raise RuntimeError("Failed to capture hidden states from worker") - - log.debug(f"Captured states for {len(request_states_dict)} requests") - - # Map results back to original input order - results = [] - for req_id in sorted(request_id_to_idx.keys()): - i = request_id_to_idx[req_id] - - if req_id not in request_states_dict: - raise RuntimeError( - f"Request {req_id} not found in captured states. " - f"Available: {list(request_states_dict.keys())}" - ) - - layer_states = [h.clone().cpu() for h in request_states_dict[req_id]] - input_ids_tensor = torch.as_tensor(input_ids_list[i], dtype=torch.long) - - results.append( - { - "input_ids": input_ids_tensor, - "hidden_states": layer_states, - "loss_mask": None, - } - ) - - empty_cache() - return results - - def __del__(self): - if hasattr(self, "executor"): - try: - self.executor.shutdown() - except Exception: - log.warning("Exception during executor shutdown") diff --git a/tests/datagen/test_config_generator.py b/tests/datagen/test_config_generator.py deleted file mode 100644 index 52144f815..000000000 --- a/tests/datagen/test_config_generator.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Unit tests for the config_generator module. -""" - -from __future__ import annotations - -from unittest import mock - -import pytest - -from speculators.data_generation.config_generator import DataGenerationConfig - -TRAIN_DATA_PATH = "sharegpt" -SEQ_LENGTH = 2048 -CACHE_DIR = "/cache" -NUM_SAMPLES = 100 - - -@pytest.fixture -def mock_vllm_generator(): - """Mock VllmHiddenStatesGenerator for testing.""" - generator = mock.MagicMock() - generator.model_path = "meta-llama/Llama-3.1-8B-Instruct" - generator.layer_ids = [2, 16, 29, 31] - generator.tensor_parallel_size = 1 - generator.vllm_config.cache_config.gpu_memory_utilization = 0.8 - return generator - - -def _create_model_config_fixture(hidden_size=None, text_config_hidden_size=None): - """Factory for creating model config mocks with different hidden_size.""" - config = mock.MagicMock() - config.hidden_size = hidden_size - - if text_config_hidden_size is not None: - config.text_config = mock.MagicMock() - config.text_config.hidden_size = text_config_hidden_size - else: - config.text_config = None - - return config - - -@pytest.fixture -def model_config_direct(): - """Create AutoConfig mock with direct hidden_size attribute.""" - config = _create_model_config_fixture(hidden_size=4096) - with mock.patch("transformers.AutoConfig.from_pretrained", return_value=config): - yield - - -@pytest.fixture -def model_config_text_config(): - """Create AutoConfig mock with text_config.hidden_size attribute.""" - config = _create_model_config_fixture(text_config_hidden_size=2048) - with mock.patch("transformers.AutoConfig.from_pretrained", return_value=config): - yield - - -@pytest.fixture -def model_config_missing(): - """Create AutoConfig mock with no hidden_size attribute.""" - config = _create_model_config_fixture() - with mock.patch("transformers.AutoConfig.from_pretrained", return_value=config): - yield - - -def create_config(generator): - """Helper to create config with consistent test parameters.""" - return DataGenerationConfig.from_generator( - generator=generator, - train_data_path=TRAIN_DATA_PATH, - seq_length=SEQ_LENGTH, - cache_dir=CACHE_DIR, - num_samples=NUM_SAMPLES, - ) - - -@pytest.mark.smoke -def test_config_from_generator_with_direct_hidden_size( - mock_vllm_generator, model_config_direct -): - """Test config extraction with direct hidden_size attribute.""" - config = create_config(mock_vllm_generator) - assert config.model.hidden_size == 4096 - - -@pytest.mark.smoke -def test_config_from_generator_with_text_config_hidden_size( - mock_vllm_generator, model_config_text_config -): - """Test config extraction with text_config.hidden_size attribute.""" - config = create_config(mock_vllm_generator) - assert config.model.hidden_size == 2048 - - -@pytest.mark.smoke -def test_config_from_generator_extracts_all_settings( - mock_vllm_generator, model_config_direct -): - """Test config extracts all settings from generator.""" - config = create_config(mock_vllm_generator) - - assert config.model.target_model_path == mock_vllm_generator.model_path - assert config.hidden_states.layer_ids == mock_vllm_generator.layer_ids - assert config.model.tensor_parallel_size == mock_vllm_generator.tensor_parallel_size - - -@pytest.mark.smoke -def test_config_tracks_reproducibility_metadata( - mock_vllm_generator, model_config_direct -): - """Test config captures reproducibility metadata.""" - config = create_config(mock_vllm_generator) - - assert config.reproducibility.command - assert config.reproducibility.package_versions - assert config.reproducibility.gpu - - -@pytest.mark.sanity -def test_config_from_generator_fails_with_missing_hidden_size( - mock_vllm_generator, model_config_missing -): - """Test config generation fails with helpful error when hidden_size not found.""" - with pytest.raises(ValueError, match="Could not determine hidden size"): - create_config(mock_vllm_generator) diff --git a/tests/datagen/test_vllm_hidden_states.py b/tests/datagen/test_vllm_hidden_states.py deleted file mode 100644 index 4ff26f1ed..000000000 --- a/tests/datagen/test_vllm_hidden_states.py +++ /dev/null @@ -1,349 +0,0 @@ -"""Tests for vLLM hidden states generator accuracy against HuggingFace baseline.""" - -import gc -import logging -import os -import time - -import pytest -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from speculators.data_generation.vllm_hidden_states_generator import ( - VllmHiddenStatesGenerator, -) - -logger = logging.getLogger(__name__) - -# Set vLLM multiprocessing method to spawn for CUDA compatibility -# Must be set before vLLM imports to avoid CUDA re-initialization errors -os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - - -@pytest.fixture(autouse=True) -def cleanup_memory(): - """Fixture to clean up GPU memory before and after each test.""" - # Cleanup before test - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - yield # Run the test - - # Cleanup after test - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - time.sleep(1) # Give time for cleanup - - -@pytest.mark.regression -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize( - ("model_path", "tensor_parallel_size"), - [ - ("Qwen/Qwen2-0.5B", 1), - ], -) -def test_vllm_vs_huggingface_accuracy(model_path, tensor_parallel_size): - """Test vLLM hidden states match HuggingFace baseline within tolerance.""" - - test_prompts = [ - ( - "The future of artificial intelligence is bright and full " - "of possibilities that will transform humanity." - ), - ( - "In a world where technology advances rapidly, we must " - "carefully consider the ethical implications." - ), - ] - - logger.info("=" * 80) - logger.info(f"Testing {model_path}") - logger.info(f"Prompts: {len(test_prompts)}") - logger.info("=" * 80) - - # HuggingFace baseline Implementation, adapted from research/eagle3/ge_data - logger.info("[1/2] Running HuggingFace baseline...") - tokenizer = AutoTokenizer.from_pretrained(model_path) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - hf_model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - trust_remote_code=True, - ).to("cuda") # type: ignore[arg-type] - num_layers = len(hf_model.model.layers) - logger.info(f"Model has {num_layers} layers") - - inputs = tokenizer( - test_prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=2048, - ).to(hf_model.device) - logger.info(f"Input shape: {inputs['input_ids'].shape}") - - with torch.no_grad(): - hf_output = hf_model(**inputs, output_hidden_states=True) - - # Extract layers using EAGLE3 pattern - # Feature fusion: layers 2, num_layers//2, num_layers-3 (before norm) - # Excluding the last layer (after norm) which has different behavior - expected_layer_ids = [2, num_layers // 2, num_layers - 3] - hf_layers = [ - hf_output.hidden_states[3], # layer 2 (before norm) - hf_output.hidden_states[ - num_layers // 2 + 1 - ], # layer num_layers//2 (before norm) - hf_output.hidden_states[num_layers - 2], # layer num_layers-3 (before norm) - ] - - hf_concat = torch.cat(hf_layers, dim=-1).cpu() - logger.info(f"HuggingFace layers {expected_layer_ids}: {hf_concat.shape}") - - # Cleanup HuggingFace model - aggressive cleanup - del hf_model, hf_output, hf_layers, inputs, tokenizer - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - gc.collect() - time.sleep(3) - - logger.info( - f"GPU memory freed, available: {torch.cuda.mem_get_info()[0] / 1024**3:.2f} GiB" - ) - - # 2. vLLM implementation - logger.info("[2/2] Running vLLM implementation...") - # Only test feature fusion layers (before norm), exclude the last layer (after norm) - test_layer_ids = [2, num_layers // 2, num_layers - 3] - generator = VllmHiddenStatesGenerator( - model_path=model_path, - layer_ids=test_layer_ids, - max_model_len=2048, - gpu_memory_utilization=0.3, # Conservative to avoid OOM after HF cleanup - tensor_parallel_size=tensor_parallel_size, - ) - - try: - # Tokenize prompts for vLLM (current implementation expects token_ids) - # IMPORTANT: Use the SAME tokenizer that was used for HuggingFace - # to ensure identical tokenization - vllm_tokenizer = AutoTokenizer.from_pretrained(model_path) - if vllm_tokenizer.pad_token is None: - vllm_tokenizer.pad_token = vllm_tokenizer.eos_token - - # Tokenize with padding to match HuggingFace behavior - vllm_inputs = vllm_tokenizer( - test_prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=2048, - ) - token_ids_batch = vllm_inputs["input_ids"].tolist() - - vllm_results = generator.generate(token_ids=token_ids_batch) - if not isinstance(vllm_results, list): - vllm_results = [vllm_results] - - vllm_concat_per_seq = [] - for r in vllm_results: - seq_concat = torch.cat(r["hidden_states"], dim=-1) - vllm_concat_per_seq.append(seq_concat) - vllm_concat = torch.stack(vllm_concat_per_seq).cpu() - logger.info(f"vLLM layers {expected_layer_ids}: {vllm_concat.shape}") - - # Check layer IDs before cleanup - actual_layer_ids = generator.layer_ids - finally: - del generator - gc.collect() - torch.cuda.empty_cache() - time.sleep(1) - - # Verify layer IDs - assert actual_layer_ids == expected_layer_ids, ( - f"Layer IDs mismatch! Got {actual_layer_ids}, expected {expected_layer_ids}" - ) - - # Verify shapes - assert hf_concat.shape == vllm_concat.shape, ( - f"Shape mismatch! HF: {hf_concat.shape}, vLLM: {vllm_concat.shape}" - ) - - # Verify EAGLE3 output format - for result in vllm_results: - assert "input_ids" in result - assert "hidden_states" in result - assert "loss_mask" in result - assert isinstance(result["hidden_states"], list) - for layer_state in result["hidden_states"]: - assert layer_state.shape[0] == result["input_ids"].shape[0], ( - "Sequence length mismatch" - ) - - # Numerical comparison - max_diff = torch.abs(hf_concat - vllm_concat).max().item() - mean_diff = torch.abs(hf_concat - vllm_concat).mean().item() - logger.info(f"Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}") - - assert mean_diff < 0.02, ( - f"Mean difference {mean_diff} too large. " - f"Expected layer_ids={expected_layer_ids}" - ) - - -@pytest.mark.regression -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize( - ("model_path", "tensor_parallel_size"), - [ - ("Qwen/Qwen2-0.5B", 1), - ], -) -def test_batch_vs_individual_consistency( # noqa: C901 - model_path, tensor_parallel_size -): - """Test that batch processing matches individual processing. - - Regression test for GitHub issue #279: VllmHiddenStatesGenerator returns - silently wrong hidden states with batch_size > 1 or repeated calls. - - This test verifies: - 1. No KV cache state leakage between calls (Bug 1) - 2. Correct token ordering in chunked prefill (Bug 2) - """ - # 8 distinct prompts of varying length to trigger chunked prefill - test_prompts = [ - "What is 2+2?", - "Explain the theory of relativity in simple terms.", - "Write a haiku about the ocean.", - "What are the main differences between Python and JavaScript?", - "Hello!", - "Translate 'good morning' to French, Spanish, and German.", - "What is the capital of Brazil?", - "Describe the process of photosynthesis step by step.", - ] - - logger.info(f"Testing batch vs individual consistency: {model_path}") - - # Initialize generator with aggressive chunking to properly test the fix - # This forces multi-iteration chunked prefill which exposes token ordering bugs - generator = VllmHiddenStatesGenerator( - model_path=model_path, - layer_ids=[10], # Single layer for faster testing - max_model_len=2048, - gpu_memory_utilization=0.3, - tensor_parallel_size=tensor_parallel_size, - max_num_batched_tokens=100, # Force chunking: ~212 tokens / 100 = 3 iterations - ) - - try: - # Tokenize prompts - tokenizer = AutoTokenizer.from_pretrained(model_path) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # Use chat template if available, otherwise just tokenize - all_ids: list[list[int]] = [] - for text in test_prompts: - try: - # Try chat template first (for instruct models) - msgs = [{"role": "user", "content": text}] - ids: torch.Tensor | dict[str, torch.Tensor] = ( - tokenizer.apply_chat_template( - msgs, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - padding=False, - ) # type: ignore[assignment] - ) - if isinstance(ids, dict): - ids = ids["input_ids"] - assert isinstance(ids, torch.Tensor) # typing - all_ids.append(ids.squeeze(0).tolist()) - except (ValueError, AttributeError): - # Fallback for base models without chat template - ids = tokenizer(text, return_tensors="pt")["input_ids"] - assert isinstance(ids, torch.Tensor) # typing - all_ids.append(ids.squeeze(0).tolist()) - - seq_lens = [len(ids) for ids in all_ids] - logger.info(f"Sequence lengths: {seq_lens}") - logger.info(f"Total tokens: {sum(seq_lens)}") - - # --- Ground truth: process each sequence individually --- - logger.info("Processing sequences individually...") - individual_results = [] - for i, i_ids in enumerate(all_ids): - results = generator.generate([i_ids]) - individual_results.append(results[0]) - hs = results[0]["hidden_states"][0] - logger.info( - f" Seq {i}: input_len={seq_lens[i]:3d}, hs_shape={list(hs.shape)}" - ) - - # --- Batch processing --- - logger.info("Processing all sequences as batch...") - batch_results = generator.generate(all_ids) - - # --- Verify results match --- - misaligned = 0 - empty = 0 - for i in range(len(all_ids)): - individual_hs = individual_results[i]["hidden_states"][0] - batch_hs = batch_results[i]["hidden_states"][0] - - expected_shape = list(individual_hs.shape) - got_shape = list(batch_hs.shape) - - # Check for empty results - if batch_hs.numel() == 0: - empty += 1 - logger.error(f" Seq {i}: EMPTY (bug reproduced)") - continue - - # Check for shape mismatch - if got_shape != expected_shape: - misaligned += 1 - logger.error( - f" Seq {i}: WRONG SHAPE " - f"(got {got_shape}, expected {expected_shape})" - ) - continue - - # Check for value mismatch - if individual_hs.shape[0] > 0 and batch_hs.shape[0] > 0: - mean_diff = torch.abs(individual_hs - batch_hs).mean().item() - - if mean_diff > 0.01: # Tolerance for numerical differences - misaligned += 1 - logger.error(f" Seq {i}: WRONG VALUES (mean_diff={mean_diff:.6f})") - continue - - # Assert no errors - total_errors = empty + misaligned - assert total_errors == 0, ( - f"Batch processing returned wrong hidden states: " - f"{empty} empty, {misaligned} misaligned out of {len(all_ids)} sequences. " - f"This indicates bug #279 regression." - ) - - logger.info( - f"SUCCESS: All {len(all_ids)} sequences matched between " - f"individual and batch processing" - ) - - finally: - del generator - gc.collect() - torch.cuda.empty_cache() - time.sleep(1) diff --git a/tests/e2e/regression/test_eagle3_offline_acceptance.py b/tests/e2e/regression/test_eagle3_offline_acceptance.py index dbf509c85..30b83b3be 100644 --- a/tests/e2e/regression/test_eagle3_offline_acceptance.py +++ b/tests/e2e/regression/test_eagle3_offline_acceptance.py @@ -3,7 +3,7 @@ Exercises the full offline pipeline: 1. Prepare data (scripts/prepare_data.py) 2. Launch a vLLM server for hidden-state extraction (scripts/launch_vllm.py) - 3. Generate hidden states offline (scripts/data_generation_offline2.py) + 3. Generate hidden states offline (scripts/data_generation_offline.py) 4. Stop the vLLM server 5. Train a draft model using pre-generated hidden states (scripts/train.py) 6. Validate the trained checkpoint via vLLM inference (run_vllm_engine) diff --git a/tests/e2e/smoke/test_offline_training.py b/tests/e2e/smoke/test_offline_training.py index 8f3c212b7..70582cebe 100644 --- a/tests/e2e/smoke/test_offline_training.py +++ b/tests/e2e/smoke/test_offline_training.py @@ -3,7 +3,7 @@ Exercises the full offline pipeline: 1. Prepare data (scripts/prepare_data.py) 2. Launch a vLLM server for hidden-state extraction (scripts/launch_vllm.py) - 3. Generate hidden states offline (scripts/data_generation_offline2.py) + 3. Generate hidden states offline (scripts/data_generation_offline.py) 4. Stop the vLLM server 5. Train a draft model using pre-generated hidden states (scripts/train.py) 6. Validate the trained checkpoint via vLLM inference (run_vllm_engine) @@ -15,7 +15,7 @@ from tests.e2e.utils import ( launch_vllm_server_context, - run_data_generation_offline2, + run_data_generation_offline, run_prepare_data, run_training, run_vllm_engine, @@ -90,7 +90,7 @@ def run_offline_e2e( target_layer_ids=target_layer_ids, ): # Step 2: Generate hidden states offline - run_data_generation_offline2( + run_data_generation_offline( data_path, offline_hidden_states, port, diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index 3dc0d90ea..c887d89e9 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -21,7 +21,7 @@ from tests.e2e.utils import ( SCRIPTS_DIR, launch_vllm_server_context, - run_data_generation_offline2, + run_data_generation_offline, run_prepare_data, ) @@ -86,7 +86,7 @@ def test_resume_after_checkpoint_best(tmp_path: Path): # Step 2: Generate hidden states offline with launch_vllm_server_context(MODEL, VLLM_PORT, str(tmp_path / "hidden_states")): - run_data_generation_offline2(data_path, hidden_states_path, port=VLLM_PORT) + run_data_generation_offline(data_path, hidden_states_path, port=VLLM_PORT) # Step 3: Train 1 epoch with --save-best result = _run_distributed_training( diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index a34e09a14..0746f01e9 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -17,7 +17,7 @@ "VLLM_PYTHON", "launch_vllm_server", "launch_vllm_server_context", - "run_data_generation_offline2", + "run_data_generation_offline", "run_prepare_data", "run_training", "run_vllm_engine", @@ -164,7 +164,7 @@ def run_prepare_data( assert result.returncode == 0, "prepare_data.py failed" -def run_data_generation_offline2( +def run_data_generation_offline( data_path: Path, hidden_states_path: Path | None = None, port: int = 8321, @@ -176,7 +176,7 @@ def run_data_generation_offline2( ): datagen_cmd = [ sys.executable, - str(SCRIPTS_DIR / "data_generation_offline2.py"), + str(SCRIPTS_DIR / "data_generation_offline.py"), "--preprocessed-data", str(data_path), "--endpoint", @@ -199,7 +199,7 @@ def run_data_generation_offline2( datagen_cmd, stderr=subprocess.PIPE, text=True, check=False, timeout=timeout ) assert result.returncode == 0, ( - f"data_generation_offline2.py failed:\n{result.stderr}" + f"data_generation_offline.py failed:\n{result.stderr}" ) From fee40d0528d8af5a2413fab3912285802cd2a05e Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Fri, 17 Apr 2026 16:14:17 +0000 Subject: [PATCH 2/7] min diff Signed-off-by: shanjiaz --- tests/e2e/smoke/test_resume_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index c887d89e9..ba12fae54 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -75,7 +75,6 @@ def _run_distributed_training( @pytest.mark.e2e @pytest.mark.slow -@requires_multi_gpu def test_resume_after_checkpoint_best(tmp_path: Path): data_path = tmp_path / "data" hidden_states_path = tmp_path / "offline_hidden_states" From 3945770cca7ca75891ec3364a1ca122cc5f52691 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 20 Apr 2026 17:51:27 +0000 Subject: [PATCH 3/7] move datagen tests to integration Signed-off-by: shanjiaz --- tests/{ => integration}/datagen/__init__.py | 0 tests/{ => integration}/datagen/test_preprocessing.py | 0 tests/{ => integration}/datagen/test_regex_patterns.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => integration}/datagen/__init__.py (100%) rename tests/{ => integration}/datagen/test_preprocessing.py (100%) rename tests/{ => integration}/datagen/test_regex_patterns.py (100%) diff --git a/tests/datagen/__init__.py b/tests/integration/datagen/__init__.py similarity index 100% rename from tests/datagen/__init__.py rename to tests/integration/datagen/__init__.py diff --git a/tests/datagen/test_preprocessing.py b/tests/integration/datagen/test_preprocessing.py similarity index 100% rename from tests/datagen/test_preprocessing.py rename to tests/integration/datagen/test_preprocessing.py diff --git a/tests/datagen/test_regex_patterns.py b/tests/integration/datagen/test_regex_patterns.py similarity index 100% rename from tests/datagen/test_regex_patterns.py rename to tests/integration/datagen/test_regex_patterns.py From deaab3a1e20764ccd134a85881d7a88499d08346 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 21 Apr 2026 00:14:34 +0000 Subject: [PATCH 4/7] min diff Signed-off-by: shanjiaz --- tests/e2e/smoke/test_resume_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e/smoke/test_resume_optimizer.py b/tests/e2e/smoke/test_resume_optimizer.py index ba12fae54..c887d89e9 100644 --- a/tests/e2e/smoke/test_resume_optimizer.py +++ b/tests/e2e/smoke/test_resume_optimizer.py @@ -75,6 +75,7 @@ def _run_distributed_training( @pytest.mark.e2e @pytest.mark.slow +@requires_multi_gpu def test_resume_after_checkpoint_best(tmp_path: Path): data_path = tmp_path / "data" hidden_states_path = tmp_path / "offline_hidden_states" From 38df3d3c0b2711dc1c725d254569c879167475ab Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 21 Apr 2026 00:17:47 +0000 Subject: [PATCH 5/7] fix doc failure Signed-off-by: shanjiaz --- docs/scripts/gen_files.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/scripts/gen_files.py b/docs/scripts/gen_files.py index 499081586..3e07c4691 100644 --- a/docs/scripts/gen_files.py +++ b/docs/scripts/gen_files.py @@ -151,12 +151,6 @@ def migrate_developer_docs(): title="Evaluate", weight=4, ), - ProcessFile( - root_path=Path("scripts/README.md"), - docs_path=Path("train.md"), - title="Train", - weight=-8, - ), ProcessFile( root_path=Path("scripts/response_regeneration/README.md"), docs_path=Path("response_regeneration.md"), From aeb42aa6b1a63b69c55b144574905c1767eb08d6 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 21 Apr 2026 00:26:31 +0000 Subject: [PATCH 6/7] fix links Signed-off-by: shanjiaz --- docs/scripts/gen_files.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/scripts/gen_files.py b/docs/scripts/gen_files.py index 3e07c4691..0ea911ea8 100644 --- a/docs/scripts/gen_files.py +++ b/docs/scripts/gen_files.py @@ -151,6 +151,12 @@ def migrate_developer_docs(): title="Evaluate", weight=4, ), + ProcessFile( + root_path=Path("examples/ONLINE_TRAINING.md"), + docs_path=Path("train.md"), + title="Train", + weight=-8, + ), ProcessFile( root_path=Path("scripts/response_regeneration/README.md"), docs_path=Path("response_regeneration.md"), From 6c7c09334bd0104a829c17a38d2fe589b0dc0933 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 21 Apr 2026 14:17:16 +0000 Subject: [PATCH 7/7] remove vllm dependency Signed-off-by: shanjiaz --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3426391a8..ecee14f71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,8 +111,6 @@ dev = [ "mkdocs-linkcheck~=1.0.6", ] -datagen = ["vllm>=0.12.0,<=0.16.0"] - [project.entry-points.console_scripts] speculators = "speculators.__main__:app"