skip HPU graphs for long (query + context) prefills #1346
Conversation
Pull request overview
Adjusts the HPU graph-bypass heuristic in HPUModelRunner to avoid capturing graphs for compute-/memory-heavy long prefills, reducing unnecessary memory footprint and OOM risk during long-context runs.
Changes:
- Default `max_cudagraph_capture_size` to `max_num_batched_tokens` when unset.
- Update the "skip graphs" threshold logic to account for context in addition to query length (a rough sketch follows below).
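A minimal sketch of the adjusted heuristic, reconstructed from the diff further down; the exact control flow in `HPUModelRunner._use_graphs` may differ, and the comparison against the capture size is an assumption here:

```python
def _use_graphs(self, attn_metadata, batch_size):
    # Eager mode never captures HPU graphs.
    if self.model_config.enforce_eager:
        return False
    if attn_metadata is not None and attn_metadata.is_prompt:
        # Query tokens of this chunk plus the context tokens already in the KV cache.
        seq_len = attn_metadata.seq_len()
        num_blocks = attn_metadata.num_blocks()
        total_tokens = batch_size * seq_len + num_blocks * attn_metadata.block_size
        # max_cudagraph_capture_size now defaults to max_num_batched_tokens when unset.
        if total_tokens > self.max_cudagraph_capture_size:
            # Long (query + context) prefills are compute-bound: skip graph capture.
            return False
    return True
```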
Signed-off-by: copilot <copilot@github.com>

Tests cover the four PRs addressing long-context bucketing:
- PR #762: Padding-aware bucketing strategy (warmup ranges, configs, generation)
- PR #1122: Exponential decode block formula, limit cap, filter, linear fix
- PR #1155: FusedSDPA slicing contract (pad_max bounds, strategy selection)
- PR #1346: HPU graph capture skip (cudagraph size, warmup clamp scenarios)
- Cross-PR integration: end-to-end 256K scenario, fallback, regressions

49 test functions organized in 6 test classes.

Co-authored-by: michalkuligowski <23379006+michalkuligowski@users.noreply.github.com>
Remove all production code changes from PRs #1122, #1155, #1346 and keep only the two test files created for issue #1347:
- tests/unit_tests/test_bucketing_issue_1347.py
- tests/unit_tests/test_bucketing_warmup_time.py

Signed-off-by: GitHub Copilot <copilot@github.com>
Co-authored-by: michalkuligowski <23379006+michalkuligowski@users.noreply.github.com>
```diff
-    def _use_graphs(self):
-        return not self.model_config.enforce_eager
+    def _use_graphs(self, attn_metadata=None, batch_size=0):
```
Why `batch_size=0` as a default? Someone could forget to pass it and get weird values.
Fixed by removing the default values, thanks.
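For context on the concern: if a caller silently relies on a default of `batch_size=0`, the query-token term drops out of the total and long prefills can be undercounted. A hypothetical illustration, not actual runner code:

```python
batch_size, seq_len = 0, 8192       # batch_size accidentally left at its default of 0
num_blocks, block_size = 64, 128    # 8192 cached context tokens

total_tokens = batch_size * seq_len + num_blocks * block_size
print(total_tokens)  # 8192 -- the 8192 query tokens were silently dropped from the count
```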
```python
if attn_metadata is not None and attn_metadata.is_prompt:
    seq_len = attn_metadata.seq_len()
    num_blocks = attn_metadata.num_blocks()
    total_tokens = (batch_size * seq_len + num_blocks * attn_metadata.block_size)
```
I don't understand this. Isn't total tokens `num_blocks * block_size` (with padding included)? Same for `batch_size * seq_len`.
No, `num_blocks * block_size` is the total context tokens, and `batch_size * seq_len` is the total query tokens. Take the chunked prefill of a prompt with 10240 tokens, `max_num_batched_tokens=8192` and `block_size=128` as an example:
- The prefill `[bs, seq_len, num_blocks]` for the first chunk will be `[1, 8192, 0]`.
- For the second chunk it will be `[1, 2048, 8192/128=64]`, where `2048` is the query length (`q_len` in FusedSDPA) and `8192` is the context length (`kv_len - q_len` in FusedSDPA).
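To make the arithmetic concrete, a small sketch using the numbers from this example (purely illustrative; the real values come from `attn_metadata`):

```python
block_size = 128

# Chunk 1: [bs, seq_len, num_blocks] = [1, 8192, 0] -- no cached context yet.
chunk1_total = 1 * 8192 + 0 * block_size
# Chunk 2: [bs, seq_len, num_blocks] = [1, 2048, 64] -- 64 blocks (8192 tokens) of cached context.
chunk2_total = 1 * 2048 + 64 * block_size

print(chunk1_total, chunk2_total)  # 8192 10240
```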
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Motivation
The current implementation uses the following logic to skip HPU graphs for long prefills:
vllm-gaudi/vllm_gaudi/v1/worker/hpu_model_runner.py, lines 3105 to 3106 at bcff6c8.
However:
- `self.max_cudagraph_capture_size` is not set by default;
- `batch_size * seq_len` misses the context length, which is comparable to or even larger than the query length, since chunked prefill and APC are enabled by default.

These lead to unnecessary HPU graph capture for compute-bound long prefills and a much larger memory footprint, which may cause an OOM crash.
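Using the chunked-prefill example from the review discussion above, the gap between the old query-only metric and a query-plus-context metric looks like this (illustrative numbers only):

```python
# Second chunk of a 10240-token prompt, block_size=128, 8192 tokens already cached.
batch_size, seq_len = 1, 2048
num_blocks, block_size = 64, 128

query_only = batch_size * seq_len                             # 2048  -> looks "short"
query_plus_context = query_only + num_blocks * block_size     # 10240 -> actually long
```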
Changes:
- Default `self.max_cudagraph_capture_size` to `self.max_num_batched_tokens` if it is not set (a minimal sketch follows below).
- Count the context tokens in addition to the query tokens when deciding whether a prefill is long enough to skip HPU graph capture.
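A minimal sketch of the defaulting, assuming the attribute is `None` when unset; the exact init-time location in `HPUModelRunner` is not shown in this excerpt:

```python
# Hypothetical placement; the real change lives in HPUModelRunner initialization.
if self.max_cudagraph_capture_size is None:
    self.max_cudagraph_capture_size = self.max_num_batched_tokens
```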