Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/unit_tests/worker/test_hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch
from types import SimpleNamespace
import habana_frameworks.torch # noqa: F401
from habana_frameworks.torch.utils.internal import is_lazy
from vllm.model_executor.model_loader import get_model
Expand Down Expand Up @@ -645,3 +646,58 @@ def assert_compilation(model, layer_name, module):
assert_compilation(model, "lm_head", VocabParallelEmbedding)
assert_compilation(model, "model.decoder.final_layer_norm", LayerNorm)
assert_compilation(model, "model.decoder.embed_tokens", VocabParallelEmbedding)


def test_max_cudagraph_capture_size_defaults_to_max_num_batched_tokens(model_runner):
"""max_cudagraph_capture_size defaults to max_num_batched_tokens when not configured."""
assert model_runner.max_cudagraph_capture_size == model_runner.max_num_batched_tokens


def test_max_cudagraph_capture_size_uses_explicit_value():
"""max_cudagraph_capture_size uses the configured value when explicitly set."""
vllm_config = get_vllm_config()
vllm_config.compilation_config.max_cudagraph_capture_size = 256
with set_current_vllm_config(vllm_config):
environment.set_vllm_config(vllm_config)
num_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(num_heads, head_size, 0.1)
runner = HPUModelRunner(vllm_config, DEVICE)
assert runner.max_cudagraph_capture_size == 256


@pytest.mark.parametrize(
"is_prompt,batch_size,seq_len,num_blocks,block_size,max_capture,expected",
[
# Prefill within limits → use graphs
(True, 1, 128, 0, 128, 512, True),
# Prefill exceeding limits → skip graphs
(True, 1, 256, 4, 128, 512, False),
# Prefill at exact boundary → use graphs
(True, 1, 256, 2, 128, 512, True),
# Prefill just over boundary → skip graphs
(True, 1, 256, 2, 128, 511, False),
# Decode never skips graphs even with many tokens
(False, 256, 1, 100, 128, 512, True),
# Decode with many blocks → still use graphs
(False, 64, 1, 1000, 128, 512, True),
])
def test_use_graphs(model_runner, is_prompt, batch_size, seq_len, num_blocks, block_size, max_capture, expected):
model_runner.max_cudagraph_capture_size = max_capture
attn_metadata = SimpleNamespace(is_prompt=is_prompt,
block_size=block_size,
seq_len=lambda: seq_len,
num_blocks=lambda: num_blocks)
result = model_runner._use_graphs(attn_metadata, batch_size)
assert result == expected


def test_use_graphs_enforce_eager(model_runner):
"""When enforce_eager is set, never use graphs."""
orig = model_runner.model_config.enforce_eager
try:
model_runner.model_config.enforce_eager = True
attn_metadata = SimpleNamespace(is_prompt=False, block_size=128, seq_len=lambda: 1, num_blocks=lambda: 0)
assert model_runner._use_graphs(attn_metadata, 1) is False
finally:
model_runner.model_config.enforce_eager = orig
27 changes: 19 additions & 8 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,14 +1207,15 @@ def __init__(
self.use_hpu_graph = not self.model_config.enforce_eager
self.max_batch_size = self.scheduler_config.max_num_seqs
self.max_num_seqs = self.scheduler_config.max_num_seqs
self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size
if prompt_profile_cfg:
self.max_prefill_batch_size = prompt_profile_cfg[0]
else:
self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1)
self.seen_configs: set = set()
self.max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size
if self.max_cudagraph_capture_size is None:
self.max_cudagraph_capture_size = self.max_num_batched_tokens
self.use_prefix_caching = (self.vllm_config.cache_config.enable_prefix_caching)
self.bucketing_manager = HPUBucketingManager()
max_num_prefill_seqs = self.max_num_seqs if self.use_merged_prefill \
Expand Down Expand Up @@ -3264,9 +3265,7 @@ def _execute_model_generic(self,
self._check_config(batch_size, seq_len, num_blocks, attn_metadata, warmup_mode)
additional_kwargs = {}
if htorch.utils.internal.is_lazy():
use_graphs = self._use_graphs()
if self.max_cudagraph_capture_size is not None and batch_size * seq_len > self.max_cudagraph_capture_size:
use_graphs = False
use_graphs = self._use_graphs(attn_metadata, batch_size)
additional_kwargs.update({"bypass_hpu_graphs": not use_graphs})
else:
# no hpu graphs for t.compile?
Expand Down Expand Up @@ -4583,8 +4582,20 @@ def _compile_region(self, model, name, module):
def _compile(self, module):
return torch.compile(module, **self.compile_config.get_compile_args())

def _use_graphs(self):
return not self.model_config.enforce_eager
def _use_graphs(self, attn_metadata, batch_size):
if self.model_config.enforce_eager:
return False
# skip HPU graphs for long (query + context) prefills
if attn_metadata is not None and attn_metadata.is_prompt:
seq_len = attn_metadata.seq_len()
num_blocks = attn_metadata.num_blocks()
total_tokens = (batch_size * seq_len + num_blocks * attn_metadata.block_size)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this. Isn't total tokens num_blocks * block_size (with padding included)? Same for batch_size * seq_len

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, num_blocks * block_size is the total context tokens, and batch_size * seq_len is the total query tokens. Take the chunked prefill for a prompt with 10240 tokens, max_num_batch_tokens=8192 and block_size=128 as an example:

  • The prefill [bs, seq_len, num_blocks] for the first chunk will be [1, 8192, 0].
  • For the second chunk it will be [1, 2048, 8192/128=64], where 2048 is the query length (q_len in FSDPA) and8192 is the context length (kv_len - q_len in FSDPA).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, ty

if total_tokens > self.max_cudagraph_capture_size:
Comment thread
kamil-kaczor marked this conversation as resolved.
logger.debug_once(f"Skipping HPU graph capture for prompt with [bs, query, num_blocks] = "
f"[{batch_size}, {seq_len}, {num_blocks}] due to total token count "
f"{total_tokens} exceeding the threshold of {self.max_cudagraph_capture_size}.")
return False
return True

def _get_model_layers(self):
"""Return the decoder layers from the model, handling both
Expand Down
Loading