diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index caafb5e0b0..f67132a837 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -3830,20 +3830,20 @@ def execute_model( return None def set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): - if attn_metadata is None or not attn_metadata.is_prompt: + if (attn_metadata is None + or (self.prefill_use_fusedsdpa and self.is_causal and attn_metadata.block_list is None) + or not attn_metadata.is_prompt): return attn_metadata - # FusedSDPA can handle a purely causal mask natively via - # is_causal=True + valid_seq_lengths, including chunked prefill where - # block_list is non-None. Skipping the materialised - # [bs, 1, q_len, total_kv_len] attn_bias avoids a large add_bf16 on the - # attention critical path (significant at long context). The original - # short-circuit only covered block_list is None; extend it to all - # plain-causal cases (no sliding-window / no chunked-attention / no - # alibi / not pooling). - # Conservative scope: only non-GDN hybrid models (e.g. Granite-4 - # Mamba2+Transformer). GDN / pure-transformer / other topologies keep - # the materialised bias path until validated. + # Extended FSDPA-native causal short-circuit for non-GDN hybrid models + # (e.g. Granite-4 Mamba2+Transformer). FusedSDPA can encode a purely + # causal mask natively via is_causal=True + valid_seq_lengths, including + # chunked prefill where block_list is non-None. Skipping the + # materialised [bs, 1, q_len, total_kv_len] attn_bias avoids a large + # add_bf16 on the attention critical path (significant at long + # context). Conservative scope: only non-GDN hybrid models; GDN / + # pure-transformer / other topologies keep the materialised bias path + # until validated. if (self.prefill_use_fusedsdpa and self.is_causal and not self.is_pooling_model and not getattr(self, 'sliding_window', None) and not getattr(self, 'model_has_chunked_attention', False) @@ -6689,18 +6689,21 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, Returns: Updated attention metadata with attn_bias set """ - if attn_metadata is None or not attn_metadata.is_prompt: + if (attn_metadata is None or (self.prefill_use_fusedsdpa and attn_metadata.block_list is None) + or not attn_metadata.is_prompt): return attn_metadata - # FusedSDPA handles a purely causal mask natively (is_causal=True + - # valid_seq_lengths). Skip materialising a [bs, 1, q_len, - # total_kv_len] attn_bias when the model is plain-causal (no - # sliding-window / chunked-attention). This removes a sizable - # add_bf16 from the attention critical path during long-context - # chunked prefill. interleaved_sliding_window and chunked-attention - # bias paths (window_attn_bias / chunked_attn_bias) are populated - # later in process_metadata and used by hpu_attn instead. - # Conservative scope: only non-GDN hybrid models (e.g. Granite-4). + # Extended FSDPA-native causal short-circuit for non-GDN hybrid models + # (e.g. Granite-4 Mamba2+Transformer). FusedSDPA handles a purely + # causal mask natively (is_causal=True + valid_seq_lengths). Skip + # materialising a [bs, 1, q_len, total_kv_len] attn_bias even during + # chunked prefill (block_list is non-None) for these topologies; this + # removes a sizable add_bf16 from the attention critical path during + # long-context chunked prefill. interleaved_sliding_window and + # chunked-attention bias paths (window_attn_bias / chunked_attn_bias) + # are populated later in process_metadata and used by hpu_attn + # instead. Conservative scope: only non-GDN hybrid models; all other + # topologies retain the original behaviour. if (self.prefill_use_fusedsdpa and not self.interleaved_sliding_window and self.is_non_gdn_hybrid): return attn_metadata