diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 7d3d392d35..c2799faeed 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -783,14 +783,14 @@ def __init__(
         self.use_hpu_graph = not self.model_config.enforce_eager
         self.max_batch_size = self.scheduler_config.max_num_seqs
         self.max_num_seqs = self.scheduler_config.max_num_seqs
-        self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size
         if prompt_profile_cfg:
             self.max_prefill_batch_size = prompt_profile_cfg[0]
         else:
             self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1)
         self.seen_configs: set = set()
-        self.max_num_batched_tokens = \
-            self.scheduler_config.max_num_batched_tokens
+        self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_graph_capture_tokens = self.vllm_config.compilation_config.max_cudagraph_capture_size if \
+            self.vllm_config.compilation_config.max_cudagraph_capture_size is not None else self.max_num_batched_tokens
         self.use_prefix_caching = (self.vllm_config.cache_config.enable_prefix_caching)
         self.bucketing_manager = HPUBucketingManager()
         max_num_prefill_seqs = self.max_num_seqs if self.use_merged_prefill \
@@ -2575,7 +2575,9 @@ def _execute_model_generic(self,
         additional_kwargs = {}
         if htorch.utils.internal.is_lazy():
             use_graphs = self._use_graphs()
-            if self.max_cudagraph_capture_size is not None and batch_size * seq_len > self.max_cudagraph_capture_size:
+            # skip HPU graphs for long prefills
+            if seq_len > 1 and \
+                    batch_size * (seq_len + num_blocks * self.block_size) > self.max_graph_capture_tokens:
                 use_graphs = False
             additional_kwargs.update({"bypass_hpu_graphs": not use_graphs})
         else:
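
Note: the sketch below is only an illustration of the arithmetic behind the new bypass condition; the function name `should_bypass_hpu_graphs` and the literal values are hypothetical, and the real check lives inline in `_execute_model_generic` and reads the runner's own attributes (`block_size`, `max_graph_capture_tokens`).

# Standalone sketch of the new graph-bypass condition (hypothetical helper,
# not part of the patch).
def should_bypass_hpu_graphs(batch_size: int,
                             seq_len: int,
                             num_blocks: int,
                             block_size: int,
                             max_graph_capture_tokens: int) -> bool:
    """Return True when a prefill is too large to run under HPU graphs.

    seq_len > 1 distinguishes prefill from decode; the per-sequence footprint
    counts both the new query tokens and the already-cached context blocks.
    """
    if seq_len <= 1:
        # decode steps are never bypassed by this check
        return False
    effective_tokens = batch_size * (seq_len + num_blocks * block_size)
    return effective_tokens > max_graph_capture_tokens


if __name__ == "__main__":
    # Example with made-up numbers: 2 prefill sequences of 1024 new tokens,
    # block size 128, against an 8192-token capture budget.
    # 8 cached blocks:  2 * (1024 + 1024) = 4096 <= 8192 -> graphs kept
    print(should_bypass_hpu_graphs(2, 1024, 8, 128, 8192))   # False
    # 25 cached blocks: 2 * (1024 + 3200) = 8448 >  8192 -> graphs bypassed
    print(should_bypass_hpu_graphs(2, 1024, 25, 128, 8192))  # True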