
Conversation

@NickLucche (Collaborator) commented Feb 27, 2025

This PR aims to enable the support of the V1 Sampler for TPUModelRunner.

TL;DR: top-k/top-p is currently far too slow on TPU, but the other options should interface cleanly with the current Sampler code.

The idea is to add the sampling operations to the model's XLA graph. Since re-using a subgraph (the model forward) is not supported in XLA (it would have been useful for compiling several different sampling graphs on top of the model), a single traced sampling function has to be recorded.
This means we have to provide default values for all sampling options, so that unneeded ones can be toggled off by making them evaluate to an identity operation (e.g. top_p=1.0).
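To illustrate the "default value as identity" idea, here is a minimal, hypothetical sketch (not the PR's actual code) of how a neutral default such as top_p=1.0 lets a single traced sampling graph also cover requests that don't use the option:

```python
import torch


def apply_top_p(logits: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    # Standard top-p masking; with top_p == 1.0 the mask never fires,
    # so the function reduces to a plain softmax (an identity filter).
    probs_sorted, idx = logits.softmax(dim=-1).sort(dim=-1, descending=True)
    cum_probs = probs_sorted.cumsum(dim=-1)
    mask = (cum_probs - probs_sorted) > top_p.unsqueeze(-1)
    probs_sorted = probs_sorted.masked_fill(mask, 0.0)
    return torch.zeros_like(logits).scatter(-1, idx, probs_sorted)


def sample(logits: torch.Tensor, temperature: torch.Tensor,
           top_p: torch.Tensor) -> torch.Tensor:
    # temperature == 1.0 and top_p == 1.0 leave the distribution untouched,
    # so one compiled graph serves requests with and without these options.
    probs = apply_top_p(logits / temperature.unsqueeze(-1), top_p)
    return torch.multinomial(probs, num_samples=1)
```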

The Sampler code is already mostly compatible with being compiled into an XLA graph, apart from a few options: logit_bias, min_tokens and all_greedy.
The main issue is executing the top-k/top-p code on TPU, which causes a significant slowdown in both compilation/init time and run time. For now I've resolved to disable both options until we have a faster kernel.

The rest of the PR makes sure we have an XLA-friendly SamplingMetadata interface, which avoids recompiling on optional parameters and handles the "padding to a pre-compiled shape" that we have on TPUs.
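As a rough illustration (hypothetical names, not the PR's actual interface), padding per-request parameters to the nearest pre-compiled batch size might look like this, with neutral defaults filling the padded slots:

```python
import torch


def pad_sampling_param(values: list[float], padded_batch_size: int,
                       neutral: float, device: torch.device) -> torch.Tensor:
    # Real requests keep their value; padded slots get the neutral default
    # (e.g. top_p=1.0 or temperature=1.0), so the op is an identity there
    # and the tensor shape matches a graph that was already compiled.
    padded = values + [neutral] * (padded_batch_size - len(values))
    return torch.tensor(padded, dtype=torch.float32, device=device)


# Three real requests padded up to a pre-compiled batch size of 8.
top_p = pad_sampling_param([0.9, 1.0, 0.8], padded_batch_size=8,
                           neutral=1.0, device=torch.device("cpu"))
```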

Test sampling with `VLLM_USE_V1=1 python -m pytest -s tests/v1/tpu/test_sampler.py` for "~jit" execution, or `VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct --max-model-len 1024 --max-num-seqs 256` for pre-compiling.

INFO 02-27 18:18:39 [tpu_model_runner.py:1009] Compiling the model with different input shapes for prefill:

INFO 02-27 18:19:08 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 16
INFO 02-27 18:19:08 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 32
INFO 02-27 18:19:09 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 64
INFO 02-27 18:19:09 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 128
INFO 02-27 18:19:10 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 256
INFO 02-27 18:19:10 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 512
INFO 02-27 18:19:11 [tpu_model_runner.py:1020]   batch_size: 1, seq_len: 1024
INFO 02-27 18:19:11 [tpu_model_runner.py:1028]     -- Compilation for prefill done in 31.66 [secs].
INFO 02-27 18:19:11 [tpu_model_runner.py:1057] Compiling the model with different input shapes for decode:
INFO 02-27 18:19:21 [tpu_model_runner.py:1068]   batch_size: 8, seq_len: 1
INFO 02-27 18:19:31 [tpu_model_runner.py:1068]   batch_size: 16, seq_len: 1
INFO 02-27 18:19:43 [tpu_model_runner.py:1068]   batch_size: 32, seq_len: 1
INFO 02-27 18:19:54 [tpu_model_runner.py:1068]   batch_size: 48, seq_len: 1
INFO 02-27 18:20:06 [tpu_model_runner.py:1068]   batch_size: 64, seq_len: 1
INFO 02-27 18:20:17 [tpu_model_runner.py:1068]   batch_size: 80, seq_len: 1
INFO 02-27 18:20:28 [tpu_model_runner.py:1068]   batch_size: 96, seq_len: 1
INFO 02-27 18:20:41 [tpu_model_runner.py:1068]   batch_size: 112, seq_len: 1
INFO 02-27 18:20:52 [tpu_model_runner.py:1068]   batch_size: 128, seq_len: 1
INFO 02-27 18:21:03 [tpu_model_runner.py:1068]   batch_size: 144, seq_len: 1
INFO 02-27 18:21:17 [tpu_model_runner.py:1068]   batch_size: 160, seq_len: 1
INFO 02-27 18:21:28 [tpu_model_runner.py:1068]   batch_size: 176, seq_len: 1
INFO 02-27 18:21:40 [tpu_model_runner.py:1068]   batch_size: 192, seq_len: 1
INFO 02-27 18:21:51 [tpu_model_runner.py:1068]   batch_size: 208, seq_len: 1
INFO 02-27 18:22:05 [tpu_model_runner.py:1068]   batch_size: 224, seq_len: 1
INFO 02-27 18:22:17 [tpu_model_runner.py:1068]   batch_size: 240, seq_len: 1
INFO 02-27 18:22:28 [tpu_model_runner.py:1068]   batch_size: 256, seq_len: 1
INFO 02-27 18:22:28 [tpu_model_runner.py:1075]     -- Compilation for decode done in 197.66 [secs].

TODO:

  • test that different sampling params produce different outputs
  • enable more options, probably in follow-up PRs, to better assess the performance impact
UPDATE: on top-k/top-p performance

I've run some micro-benchmarks on the current top-k/top-p implementation to highlight the difference in run time.
Each benchmark runs a random linear layer (simulating lm_head) followed by sampling. Code here; feel free to double-check the scripts.
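For reference, the torch-xla benchmark is structured roughly like the sketch below (simplified, with assumed sizes; see the linked scripts for the actual code):

```python
import time

import torch
import torch_xla.core.xla_model as xm

# A random linear layer stands in for lm_head, followed by sampling.
device = xm.xla_device()
hidden, vocab = 4096, 32000
lm_head = torch.nn.Linear(hidden, vocab, bias=False).to(device)


def step(batch_size: int) -> None:
    x = torch.randn(batch_size, hidden, device=device)
    logits = lm_head(x)
    probs = torch.softmax(logits, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)
    xm.mark_step()        # cut the graph so XLA executes it
    xm.wait_device_ops()  # block until the device is done before timing


for bs in (1, 4, 16, 32):
    t0 = time.time()
    step(bs)  # first call triggers compilation
    print(f"Compiling/Warmup {bs} elapsed time: {time.time() - t0}")
    for _ in range(4):
        t0 = time.time()
        step(bs)
        print(f"Running {bs} elapsed time: {time.time() - t0}")
```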

vLLM sampler on TPUv6 (torch-xla):

   python vllm/tpu-sampler-benchmark/sampler_microbenchmark.py
INFO 03-03 11:20:11 [__init__.py:207] Automatically detected platform tpu.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Compiling/Warmup 1 elapsed time: 4.442102432250977
Compiling/Warmup 4 elapsed time: 4.5806968212127686
Compiling/Warmup 16 elapsed time: 4.375824928283691
Compiling/Warmup 32 elapsed time: 4.597050428390503
Running 1 elapsed time: 0.02094721794128418
Running 1 elapsed time: 0.020517826080322266
Running 1 elapsed time: 0.020459890365600586
Running 1 elapsed time: 0.02043437957763672
Running 4 elapsed time: 0.0777747631072998
Running 4 elapsed time: 0.07848191261291504
Running 4 elapsed time: 0.07846784591674805
Running 4 elapsed time: 0.07887482643127441
Running 16 elapsed time: 0.25572896003723145
Running 16 elapsed time: 0.2562577724456787
Running 16 elapsed time: 0.25624656677246094
Running 16 elapsed time: 0.2562131881713867
Running 32 elapsed time: 0.5071756839752197
Running 32 elapsed time: 0.5076909065246582
Running 32 elapsed time: 0.5076868534088135
Running 32 elapsed time: 0.5077009201049805

vLLM sampler on CUDA, no flashinfer (forward_native) on L4:

INFO 03-03 10:41:46 [__init__.py:207] Automatically detected platform cuda.
WARNING 03-03 10:41:46 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
Compiling/Warmup 1 elapsed time: 0.8291933536529541
Compiling/Warmup 4 elapsed time: 0.10055041313171387
Compiling/Warmup 16 elapsed time: 0.013581037521362305
Compiling/Warmup 32 elapsed time: 0.013128995895385742
Running 1 elapsed time: 0.008679866790771484
Running 1 elapsed time: 0.008601188659667969
Running 1 elapsed time: 0.008600711822509766
Running 1 elapsed time: 0.008601903915405273
Running 4 elapsed time: 0.009422063827514648
Running 4 elapsed time: 0.009502649307250977
Running 4 elapsed time: 0.009497880935668945
Running 4 elapsed time: 0.00953054428100586
Running 16 elapsed time: 0.01054692268371582
Running 16 elapsed time: 0.010563373565673828
Running 16 elapsed time: 0.01054239273071289
Running 16 elapsed time: 0.01055598258972168
Running 32 elapsed time: 0.013098955154418945
Running 32 elapsed time: 0.013221025466918945
Running 32 elapsed time: 0.013271808624267578

I also benchmarked a barebones JAX implementation of the sampling, for comparison.

JAX on the same TPU:

python vllm/tpu-sampler-benchmark/jax_proto_sampler.py     
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]


Compiling/Warmup 1 40.11301302909851
Compiling/Warmup 4 19.129047632217407
Compiling/Warmup 16 17.657644033432007
Compiling/Warmup 32 17.789708852767944
Running 1 0.0032927989959716797
Running 1 0.003267049789428711
Running 1 0.0032567977905273438
Running 1 0.0032553672790527344
Running 4 0.006731271743774414
Running 4 0.006708860397338867
Running 4 0.0067102909088134766
Running 4 0.006703615188598633
Running 16 0.022034168243408203
Running 16 0.02204728126525879
Running 16 0.022046566009521484
Running 16 0.022042274475097656
Running 32 0.04318523406982422
Running 32 0.04320979118347168
Running 32 0.043207406997680664
Running 32 0.04319953918457031

Signed-off-by: NickLucche <[email protected]>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀


mergify bot commented Mar 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Mar 3, 2025
@hyeygit (Contributor) commented Mar 4, 2025

Following up on our discussion earlier, this PR looks good to me; we'll work on improving the performance as a subsequent step.

@NickLucche (Collaborator, Author) commented

I will have to update the logic a bit to follow the ragged attention kernel. I'll post an update ASAP.

@NickLucche (Collaborator, Author) commented

Closing in favor of updated #14227.
cc @hyeygit if you want to take a look

