diff --git a/tests/unit_tests/worker/test_hpu_model_runner.py b/tests/unit_tests/worker/test_hpu_model_runner.py index 57431da52a..36fc80cc26 100644 --- a/tests/unit_tests/worker/test_hpu_model_runner.py +++ b/tests/unit_tests/worker/test_hpu_model_runner.py @@ -3,6 +3,7 @@ import pytest import torch +from types import SimpleNamespace import habana_frameworks.torch # noqa: F401 from habana_frameworks.torch.utils.internal import is_lazy from vllm.model_executor.model_loader import get_model @@ -645,3 +646,58 @@ def assert_compilation(model, layer_name, module): assert_compilation(model, "lm_head", VocabParallelEmbedding) assert_compilation(model, "model.decoder.final_layer_norm", LayerNorm) assert_compilation(model, "model.decoder.embed_tokens", VocabParallelEmbedding) + + +def test_max_cudagraph_capture_size_defaults_to_max_num_batched_tokens(model_runner): + """max_cudagraph_capture_size defaults to max_num_batched_tokens when not configured.""" + assert model_runner.max_cudagraph_capture_size == model_runner.max_num_batched_tokens + + +def test_max_cudagraph_capture_size_uses_explicit_value(): + """max_cudagraph_capture_size uses the configured value when explicitly set.""" + vllm_config = get_vllm_config() + vllm_config.compilation_config.max_cudagraph_capture_size = 256 + with set_current_vllm_config(vllm_config): + environment.set_vllm_config(vllm_config) + num_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(num_heads, head_size, 0.1) + runner = HPUModelRunner(vllm_config, DEVICE) + assert runner.max_cudagraph_capture_size == 256 + + +@pytest.mark.parametrize( + "is_prompt,batch_size,seq_len,num_blocks,block_size,max_capture,expected", + [ + # Prefill within limits → use graphs + (True, 1, 128, 0, 128, 512, True), + # Prefill exceeding limits → skip graphs + (True, 1, 256, 4, 128, 512, False), + # Prefill at exact boundary → use graphs + (True, 1, 256, 2, 128, 512, True), + # Prefill just over boundary → skip graphs + (True, 1, 256, 2, 128, 511, False), + # Decode never skips graphs even with many tokens + (False, 256, 1, 100, 128, 512, True), + # Decode with many blocks → still use graphs + (False, 64, 1, 1000, 128, 512, True), + ]) +def test_use_graphs(model_runner, is_prompt, batch_size, seq_len, num_blocks, block_size, max_capture, expected): + model_runner.max_cudagraph_capture_size = max_capture + attn_metadata = SimpleNamespace(is_prompt=is_prompt, + block_size=block_size, + seq_len=lambda: seq_len, + num_blocks=lambda: num_blocks) + result = model_runner._use_graphs(attn_metadata, batch_size) + assert result == expected + + +def test_use_graphs_enforce_eager(model_runner): + """When enforce_eager is set, never use graphs.""" + orig = model_runner.model_config.enforce_eager + try: + model_runner.model_config.enforce_eager = True + attn_metadata = SimpleNamespace(is_prompt=False, block_size=128, seq_len=lambda: 1, num_blocks=lambda: 0) + assert model_runner._use_graphs(attn_metadata, 1) is False + finally: + model_runner.model_config.enforce_eager = orig diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index caafb5e0b0..edd2034a30 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1207,14 +1207,15 @@ def __init__( self.use_hpu_graph = not self.model_config.enforce_eager self.max_batch_size = self.scheduler_config.max_num_seqs self.max_num_seqs = self.scheduler_config.max_num_seqs - self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size if prompt_profile_cfg: self.max_prefill_batch_size = prompt_profile_cfg[0] else: self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1) self.seen_configs: set = set() - self.max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size + if self.max_cudagraph_capture_size is None: + self.max_cudagraph_capture_size = self.max_num_batched_tokens self.use_prefix_caching = (self.vllm_config.cache_config.enable_prefix_caching) self.bucketing_manager = HPUBucketingManager() max_num_prefill_seqs = self.max_num_seqs if self.use_merged_prefill \ @@ -3264,9 +3265,7 @@ def _execute_model_generic(self, self._check_config(batch_size, seq_len, num_blocks, attn_metadata, warmup_mode) additional_kwargs = {} if htorch.utils.internal.is_lazy(): - use_graphs = self._use_graphs() - if self.max_cudagraph_capture_size is not None and batch_size * seq_len > self.max_cudagraph_capture_size: - use_graphs = False + use_graphs = self._use_graphs(attn_metadata, batch_size) additional_kwargs.update({"bypass_hpu_graphs": not use_graphs}) else: # no hpu graphs for t.compile? @@ -4583,8 +4582,20 @@ def _compile_region(self, model, name, module): def _compile(self, module): return torch.compile(module, **self.compile_config.get_compile_args()) - def _use_graphs(self): - return not self.model_config.enforce_eager + def _use_graphs(self, attn_metadata, batch_size): + if self.model_config.enforce_eager: + return False + # skip HPU graphs for long (query + context) prefills + if attn_metadata is not None and attn_metadata.is_prompt: + seq_len = attn_metadata.seq_len() + num_blocks = attn_metadata.num_blocks() + total_tokens = (batch_size * seq_len + num_blocks * attn_metadata.block_size) + if total_tokens > self.max_cudagraph_capture_size: + logger.debug_once(f"Skipping HPU graph capture for prompt with [bs, query, num_blocks] = " + f"[{batch_size}, {seq_len}, {num_blocks}] due to total token count " + f"{total_tokens} exceeding the threshold of {self.max_cudagraph_capture_size}.") + return False + return True def _get_model_layers(self): """Return the decoder layers from the model, handling both