
Update data generation to use vLLM extract_hidden_states#353

Merged
fynnsu merged 17 commits into main from update_datagen
Mar 26, 2026
Conversation

Collaborator

@fynnsu fynnsu commented Mar 19, 2026


Purpose

This PR adds support for vLLM's extract_hidden_states system introduced in vllm-project/vllm#33736. It also includes a major rework of our data generation system and of the way our data is formatted on disk.

These changes enable online training (also introduced by this PR), offline training, and hybrid methods.

Description

Data format

As discussed in #335, some changes to the data format are required to enable online training. The previous system generated data files containing aligned token ids, loss masks, and hidden states and saved them all to disk. The new system instead separates the hidden states from the token ids and loss masks.

I initially proposed in the RFC that the data files could contain a file path pointing to the cached hidden states (if they exist). Although this works well, it requires updating the files as the data is generated. We would also need one file per data sample (two in the offline case) stored on disk, which is a substantial number of files.

To address this, I looked into different formats for storing the data and realized that if we store each sample's hidden states in a file whose name is derived from the sample's index (e.g. hs_{i}.safetensors), we don't need to store the file path explicitly. And since we're not updating the files manually, there is no need for simultaneous reads/writes, so we can store the entire preprocessed data (just the token ids + loss masks) in a single Arrow dataset, which is already produced as an intermediate step during generation. That is what the new system uses.
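The index-linked naming scheme can be sketched in a few lines. This is a minimal illustration of the idea, not the actual implementation; the helper names are hypothetical.

```python
from pathlib import Path


def hidden_states_path(hidden_states_dir: Path, index: int) -> Path:
    """Derive the hidden-states file path from the sample's dataset index.

    Because the file name is a pure function of the index, the Arrow
    dataset never needs to store (or update) an explicit file path.
    """
    return hidden_states_dir / f"hs_{index}.safetensors"


def has_cached_hidden_states(hidden_states_dir: Path, index: int) -> bool:
    # A sample's hidden states are cached iff its index-linked file exists.
    return hidden_states_path(hidden_states_dir, index).is_file()
```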

Training flow

Abridged version. See examples/ONLINE_TRAINING.md for more details.

  1. Prepare data
python scripts/prepare_data.py --model Qwen/Qwen3-8B --data sharegpt --output ./output
  2. Launch vLLM
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/launch_vllm.py Qwen/Qwen3-8B -- --data-parallel-size 4 --port 8000
  3. (Optional) Run offline hidden states generation
python scripts/data_generation_offline2.py --endpoint http://localhost:8000/v1 --preprocessed-data ./output --validate-outputs
  4. Run training
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --standalone --nproc_per_node 4 scripts/train.py --verifier-name-or-path Qwen/Qwen3-8B --data-path ./output --vllm-endpoint http://localhost:8000/v1 --save-path ./output/checkpoint --draft-model-size 32000

Online training enablement

The following options were added to scripts/train.py to enable online/hybrid generation:

  • --hidden-states-path. Defaults to args.data_path / "hidden_states". Where to look for and/or cache hidden states.
  • --vllm-endpoint. Defaults to http://localhost:8000/v1. Endpoint of the running, configured vLLM instance used to generate new hidden states. Only required if on_missing="generate" and hidden states are missing.
  • --on-missing. Choices: ["generate", "skip", "warn", "raise"]. Defaults to "generate". What the data loader should do if it doesn't find hidden states for a sample.
  • --on-generate. Choices: ["cache", "delete"]. Defaults to "delete". What to do with newly generated hidden states after loading.
  • --legacy-data. Store-true flag. Use the old data system with individual files for each sample. Will be deprecated and removed.
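The --on-missing policies above amount to a small dispatch in the data loader. A hedged sketch, assuming a cache lookup and a generate callback (both names hypothetical; the real loader works on tensors and talks to the vLLM endpoint):

```python
import warnings
from typing import Callable, Optional


def resolve_hidden_states(
    index: int,
    load_cached: Callable[[int], Optional[list]],
    generate: Callable[[int], list],
    on_missing: str = "generate",
) -> Optional[list]:
    """Return hidden states for a sample, applying the --on-missing policy."""
    cached = load_cached(index)
    if cached is not None:
        return cached
    if on_missing == "generate":
        # e.g. request fresh hidden states from the running vLLM instance
        return generate(index)
    if on_missing == "skip":
        return None  # caller silently drops the sample
    if on_missing == "warn":
        warnings.warn(f"No hidden states for sample {index}; skipping.")
        return None
    if on_missing == "raise":
        raise FileNotFoundError(f"No hidden states for sample {index}")
    raise ValueError(f"Unknown on_missing policy: {on_missing!r}")
```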

The following options were added to scripts/train.py to enable automatic vocab mapping generation:

  • --token-freq-path. Defaults to args.data_path / "token_freq.pt". Path of the token frequency file to load.
  • --draft-vocab-size. Defaults to None. If not provided, the full verifier vocab is used (same as existing behavior), with a warning.
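One plausible way to derive a vocab mapping from the token frequencies is to keep the draft_vocab_size most frequent verifier tokens. This is a sketch of that idea with plain dicts, not the logic train.py actually uses:

```python
from typing import Dict, List


def build_vocab_mapping(token_freq: Dict[int, int], draft_vocab_size: int) -> List[int]:
    """Map draft token ids [0, draft_vocab_size) to verifier token ids.

    Keeps the most frequent verifier tokens; ties are broken by token id
    so the mapping is deterministic across runs.
    """
    ranked = sorted(token_freq, key=lambda t: (-token_freq[t], t))
    return ranked[:draft_vocab_size]
```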

Related Issue

#335

Tests

I've run online training with both --data-parallel-size and --tensor-parallel-size set, for 1 epoch (similar to the instructions provided in the example README). We will need to add/update e2e tests for this new pipeline (though we may delay this to a future PR).

I have filled in:

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan/results, such as providing test command and pasting the results.
  • (Optional) The necessary documentation update.
  • I (a human) have written or reviewed the code in this PR to the best of my ability.

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 19, 2026
@fynnsu fynnsu force-pushed the update_datagen branch 2 times, most recently from 0b0932d to 519b1b3 Compare March 19, 2026 14:47

github-actions bot commented Mar 19, 2026

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/23611849228/artifacts/6130486965.
They will be retained for up to 30 days.
Commit: f1f6dff

Collaborator Author

fynnsu commented Mar 19, 2026

Performance benchmarks for offline data generation

Qwen/Qwen3-8B, 5000 sharegpt samples, max seq length 8192

New system

Usage:

# vllm env
CUDA_VISIBLE_DEVICES=0 python scripts/launch_vllm.py Qwen/Qwen3-8B
# speculators env
python scripts/data_generation_offline2.py --model Qwen/Qwen3-8B --preprocessed-data ./output/

| Setup | Generation time |
| --- | --- |
| 1×A100 | 17 min 8 s |
| 1×A100 w/ vllm-project/vllm#37374 and --no-enable-chunked-prefill | 7 min 3 s |
| 8×A100, data parallel 8 | 2 min 33 s |
| 8×A100 w/ vllm-project/vllm#37374 and --no-enable-chunked-prefill | 2 min 6 s |

Old system

CUDA_VISIBLE_DEVICES=0 python scripts/data_generation_offline.py --target-model-path Qwen/Qwen3-8B --train-data-path sharegpt --seq-length 8192 --output-dir ./output2 --max-samples 5000

| Setup | Generation time |
| --- | --- |
| 1×A100 | 15 min 16 s |
| 8×A100, data parallel 8 | N/A (data parallel not officially supported) |

Notes

The new system is currently slower on a single GPU. We expect this to improve once async hidden-states writing is added to vLLM: the draft async connector PR was confirmed to more than halve generation time on a single GPU. The 8-GPU case sees a much smaller improvement, but this is likely because the disk is already saturated at that level.

Multi-GPU was not officially supported previously. As a workaround, it was possible to launch multiple separate instances with different target indices, and I believe this did work to some degree, but it also crashed some servers, so it is not considered a supported approach.

There was no performance tuning done for either implementation.

@dsikka dsikka mentioned this pull request Mar 24, 2026
Collaborator

@brian-dellabetta brian-dellabetta left a comment


Thanks for covering this in the deep dive. The changes make sense to me; just one suggestion on the CLI args.

Collaborator

@shanjiaz shanjiaz left a comment


There might be merge conflicts now. We discussed most of these design choices offline, and all the comments I had have been addressed. Looks great!

fynnsu added 6 commits March 26, 2026 17:20
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
We're going to keep this script / implementation until a new version of vllm is released that supports the new offline system.

At that stage we will remove this script.

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>

mergify bot commented Mar 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fynnsu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2026
fynnsu added 8 commits March 26, 2026 17:35
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
- Add openai>=2.0.0 as main dependency in pyproject.toml
- Move httpx logging suppression to setup_root_logger() for reusability
- Update default seq_length in prepare_data.py from 2048 to 8192 to match train.py


Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This is a convenience wrapper script for launching vLLM for hidden states extraction.

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Instructions for running online training in speculators.

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
fynnsu added 3 commits March 26, 2026 17:35
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Collaborator

@brian-dellabetta brian-dellabetta left a comment


🔥

@fynnsu fynnsu merged commit 2a1443c into main Mar 26, 2026
12 checks passed
@fynnsu fynnsu deleted the update_datagen branch March 26, 2026 20:15
