fix: piecewise_cuda_graph get correct qo_indptr #21452
Fridge003 merged 8 commits into sgl-project:main
Conversation
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
/tag-and-rerun-ci
There is a performance regression with this PR.
Ran the tests locally on H100 with tp=1 and tp=8; the GSM8K test passes.
Oasis-Git
left a comment
In general the change is reasonable. Here are some suggestions for revision.
num_tokens = len(forward_batch.input_ids)
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
static_num_tokens = self.capture_num_tokens[index]
with enable_piecewise_cuda_graph(num_tokens=static_num_tokens):
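The bisect_left lookup above selects the smallest captured graph size that can hold the current batch; the batch is then padded up to that size. A minimal sketch of this selection rule (the capture sizes here are illustrative, not the actual values sglang captures):

```python
import bisect

# Hypothetical list of captured graph sizes, sorted ascending as in the PR.
capture_num_tokens = [1, 2, 4, 8, 16, 32]

def pick_static_num_tokens(num_tokens: int) -> int:
    # bisect_left returns the index of the smallest captured size
    # that is >= num_tokens; the batch is padded up to that size.
    index = bisect.bisect_left(capture_num_tokens, num_tokens)
    return capture_num_tokens[index]

print(pick_static_num_tokens(5))   # -> 8 (5 real tokens, 3 padding)
print(pick_static_num_tokens(8))   # -> 8 (exact fit, no padding)
```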
I think we can move num_tokens into the ForwardContext. Also, to skip the computation and the sync caused by item(), variables such as num_dummy_pages should be pre-calculated.
Hi, I took your suggestions and updated the code:
- Added self.num_tokens: Optional[int] = None field to ForwardContext
- Eliminated both .item() GPU-CPU syncs in the dummy-request block
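The point of eliminating the .item() calls is that each one forces a GPU-to-CPU synchronization; quantities like the number of dummy pages can instead be derived from CPU-side metadata that is already available. A hypothetical sketch of this pattern (all names here are illustrative, not the actual sglang implementation):

```python
def compute_dummy_request_meta(seq_lens_cpu, static_num_tokens, page_size):
    # seq_lens_cpu: per-request lengths kept as plain Python ints alongside
    # the GPU tensors, so no .item() sync is needed here.
    num_tokens = sum(seq_lens_cpu)
    pad_tokens = static_num_tokens - num_tokens
    # Pre-calculate num_dummy_pages on the host (ceiling division) instead
    # of reading it back from a GPU tensor inside the dummy-request block.
    num_dummy_pages = (pad_tokens + page_size - 1) // page_size
    return pad_tokens, num_dummy_pages

print(compute_dummy_request_meta([3, 4], 16, 4))  # -> (9, 3)
```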
Motivation
#21218
Modifications
For padding tokens, append a fake (bs+1)-th request with pad_tokens extend tokens whose KV indices all point to scratch slot 0. This makes qo_indptr[-1] == static_num_tokens without affecting the causal masks of the real requests.
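The qo_indptr fix above can be sketched as follows (plain-list sketch with illustrative names, not the actual sglang code; in the real implementation the fake request's KV indices would all point to scratch slot 0):

```python
def pad_qo_indptr(qo_indptr, static_num_tokens):
    # qo_indptr has length bs+1; qo_indptr[-1] is the number of real tokens.
    num_tokens = qo_indptr[-1]
    if num_tokens == static_num_tokens:
        return qo_indptr  # exact fit: no fake request needed
    # Append a fake (bs+1)-th request covering the padding tokens, so that
    # qo_indptr[-1] == static_num_tokens. Real requests' boundaries, and
    # hence their causal masks, are unchanged.
    return qo_indptr + [static_num_tokens]

qo_indptr = [0, 3, 7]                 # 2 real requests, 7 real tokens
print(pad_qo_indptr(qo_indptr, 8))    # -> [0, 3, 7, 8]
```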
Accuracy Tests
CUDA graph disabled (baseline):
python -m sglang.launch_server --model-path Qwen/Qwen3-14B --attention-backend flashinfer --disable-cuda-graph
python3 benchmark/gsm8k/bench_sglang.py --num-questions 100
100%|███████████████████████████████| 100/100 [00:05<00:00, 18.24it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 5.524 s
Output throughput: 2255.160 token/s
CUDA graph enabled:
python3 benchmark/gsm8k/bench_sglang.py --num-questions 100
100%|█████████████████████████████| 100/100 [00:04<00:00, 24.05it/s]
Accuracy: 0.940
Invalid: 0.000
Latency: 4.198 s
Output throughput: 2968.234 token/s
python3 benchmark/gsm8k/bench_sglang.py --num-questions 100
100%|███████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.82it/s]
Accuracy: 0.930
Invalid: 0.000
Latency: 4.035 s
Output throughput: 3180.599 token/s
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci