Skip materialised causal attn_bias on FSDPA for non-GDN hybrid models#1413
Merged
Conversation
Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR optimizes HPU prompt/prefill attention for non-GDN hybrid (Mamba/SSM + Transformer) models by avoiding construction of a large materialized causal attn_bias when FusedSDPA can apply causality natively, including during chunked-prefill.
Changes:
- Add an
is_non_gdn_hybridtopology flag in bothHPUModelRunnerandHPUAttentionMetadataProcessor. - Extend the FusedSDPA causal short-circuit in
HPUModelRunner.set_attn_biasto cover chunked-prefill (non-Noneblock_list) under additional “plain-causal” constraints. - Extend the corresponding short-circuit in
HPUAttentionMetadataProcessor._set_attn_biasgated byprefill_use_fusedsdpaandis_non_gdn_hybrid.
Comment on lines
+1102
to
+1105
| # Non-GDN hybrid: at least one mamba/linear-style layer and zero GDN | ||
| # (gdn_attention / linear_attention) layers. Used to gate optimizations | ||
| # that have only been validated on non-GDN hybrid topologies | ||
| # (e.g. Granite-4 Mamba2+Transformer). |
Comment on lines
+6687
to
+6695
| # FusedSDPA handles a purely causal mask natively (is_causal=True + | ||
| # valid_seq_lengths). Skip materialising a [bs, 1, q_len, | ||
| # total_kv_len] attn_bias when the model is plain-causal (no | ||
| # sliding-window / chunked-attention). This removes a sizable | ||
| # add_bf16 from the attention critical path during long-context | ||
| # chunked prefill. interleaved_sliding_window and chunked-attention | ||
| # bias paths (window_attn_bias / chunked_attn_bias) are populated | ||
| # later in process_metadata and used by hpu_attn instead. | ||
| # Conservative scope: only non-GDN hybrid models (e.g. Granite-4). |
✅ CI PassedAll checks passed successfully against the following vllm commit: |
jbyczkow
approved these changes
May 7, 2026
ksmusz
added a commit
that referenced
this pull request
May 11, 2026
The PR fixes the condition introduced in #1413 Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Skip materialised causal
attn_biason FSDPA for non-GDN hybrid modelsWhat
Extend the FSDPA-native causal short-circuit in both
HPUModelRunner.set_attn_biasand
HPUAttentionMetadataProcessor._set_attn_biasso the materialised[bs, 1, q_len, total_kv_len]attn_biasis no longer built for chunked-prefillsteps (i.e.
block_list is not None) on non-GDN hybrid topologies (Mamba/SSM +Transformer).
The previous short-circuit only fired when
block_list is None(single-shotprefill); any chunked-prefill step still allocated the bias and ran an
elementwise add against the attention scores. FusedSDPA can encode the same
causal mask natively via
is_causal=True+valid_seq_lengths, so the bias ispure overhead on the attention critical path — and grows with context length.
Why
For long-context prefills, the materialised bias and its associated elementwise
add become a significant exposed cost on the attention critical path. There is
also a secondary effect in
vllm_gaudi/extension/ops.py: when an explicitattn_biasis passed, FSDPA's optimised causal kernel is disabled(
is_causal=Falsefallback). Skipping the bias re-enables the faster causalkernel.
A general performance improvement is expected on long-context executions
(both throughput and TTFT / TPOT). Short-context executions are unaffected
because the bias tensor is small and other ops dominate.
Scope (conservative)
The optimisation only fires when all of the following hold:
prefill_use_fusedsdpais activeis_causaland not poolingsliding_windowmodel_has_chunked_attentionalibi_slopesnum_mamba_like_layers > 0 and num_gdn == 0)The non-GDN hybrid gate reuses the runner's existing
num_mamba_like_layers/num_gdncounters and re-derives the same flag in the metadata processor viamodel_config.get_num_layers_by_block_type, exposed as a singleself.is_non_gdn_hybridattribute.num_gdn > 0num_mamba_like_layers == 0Files changed
vllm_gaudi/v1/worker/hpu_model_runner.pyHPUModelRunner.__init__: defineself.is_non_gdn_hybridonce, reusingself.num_mamba_like_layers/self.num_gdn.HPUModelRunner.set_attn_bias: extend the early-return guard to allplain-causal cases and gate on
self.is_non_gdn_hybrid.HPUAttentionMetadataProcessor.__init__: computeself.is_non_gdn_hybridviaget_num_layers_by_block_type.HPUAttentionMetadataProcessor._set_attn_bias: tighten guard toprefill_use_fusedsdpa and not interleaved_sliding_window and is_non_gdn_hybrid.Risk and validation
measurable improvement in throughput and TTFT / TPOT / E2EL, with identical
input / generated token counts and no failed requests.
does not fire).
transformers: accuracy validation (e.g. GSM8K / lm-eval), a sliding-window
sanity check, and a long-context perplexity check.