Skip materialised causal attn_bias on FSDPA for non-GDN hybrid models #1413

Merged
ksmusz merged 1 commit into main from dev/ksmusz/skip_materialised_causal_attn_bias on May 7, 2026

@ksmusz ksmusz commented May 5, 2026

Skip materialised causal attn_bias on FSDPA for non-GDN hybrid models

What

Extend the FSDPA-native causal short-circuit in both HPUModelRunner.set_attn_bias
and HPUAttentionMetadataProcessor._set_attn_bias so the materialised
[bs, 1, q_len, total_kv_len] attn_bias is no longer built for chunked-prefill
steps (i.e. block_list is not None) on non-GDN hybrid topologies (Mamba/SSM +
Transformer).

The previous short-circuit only fired when block_list is None (single-shot
prefill); any chunked-prefill step still allocated the bias and ran an
elementwise add against the attention scores. FusedSDPA can encode the same
causal mask natively via is_causal=True + valid_seq_lengths, so the bias is
pure overhead on the attention critical path, and that overhead grows with
context length.
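
To make the growth concrete, here is a minimal sketch of the footprint of a dense [bs, 1, q_len, total_kv_len] bias. It assumes a bf16 bias (2 bytes per element); the helper name and the example sizes are illustrative and are not taken from the PR.

```python
def attn_bias_bytes(bs: int, q_len: int, total_kv_len: int,
                    bytes_per_elem: int = 2) -> int:
    """Size in bytes of a dense [bs, 1, q_len, total_kv_len] bias tensor.

    bytes_per_elem=2 assumes bf16; that is an assumption of this sketch.
    """
    return bs * 1 * q_len * total_kv_len * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical chunked-prefill steps: a fixed 8192-token chunk at
    # increasing total context lengths.
    for total_kv_len in (8192, 32768, 131072):
        size = attn_bias_bytes(bs=1, q_len=8192, total_kv_len=total_kv_len)
        print(f"total_kv_len={total_kv_len:>6}: {size / 2**20:.0f} MiB")
```

For a fixed chunk size the footprint scales linearly with total_kv_len, so under these assumptions the last chunk of a 128K-token prompt needs a 2 GiB bias per sequence.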

Why

For long-context prefills, the materialised bias and its associated elementwise
add become a significant exposed cost on the attention critical path. There is
also a secondary effect in vllm_gaudi/extension/ops.py: when an explicit
attn_bias is passed, FSDPA's optimised causal kernel is disabled
(is_causal=False fallback). Skipping the bias re-enables the faster causal
kernel.
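
The equivalence that makes the bias redundant can be illustrated with a small pure-Python sketch. It compares attention probabilities computed with a materialised additive bias against probabilities computed by restricting each query row to its visible keys. This is not the FusedSDPA kernel or the code in vllm_gaudi/extension/ops.py; all function names are hypothetical, and the bottom-right mask alignment (the last query row sees every key) is an assumption of the sketch.

```python
import math


def softmax(row):
    """Numerically stable softmax that tolerates -inf entries."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_bias(q_len, total_kv_len):
    """Dense additive bias: 0 where a query may attend, -inf elsewhere.

    Query row i sits at absolute position (total_kv_len - q_len + i),
    i.e. the mask is bottom-right aligned (an assumption of this sketch).
    """
    ctx_len = total_kv_len - q_len
    return [[0.0 if j <= ctx_len + i else -math.inf
             for j in range(total_kv_len)]
            for i in range(q_len)]


def probs_with_bias(scores, bias):
    """Path 1: elementwise add of the materialised bias, then softmax."""
    return [softmax([s + b for s, b in zip(s_row, b_row)])
            for s_row, b_row in zip(scores, bias)]


def probs_native_causal(scores):
    """Path 2: no bias tensor; each row only looks at its visible keys."""
    q_len, total_kv_len = len(scores), len(scores[0])
    ctx_len = total_kv_len - q_len
    out = []
    for i, row in enumerate(scores):
        visible = ctx_len + i + 1
        out.append(softmax(row[:visible]) + [0.0] * (total_kv_len - visible))
    return out


# 2 new query tokens attending over 5 keys (3 cached + 2 new).
# The score values are arbitrary.
scores = [[0.1, 0.7, -0.3, 0.4, 0.9],
          [0.5, -0.2, 0.8, 0.0, 0.3]]
with_bias = probs_with_bias(scores, causal_bias(2, 5))
native = probs_native_causal(scores)
```

Both paths yield the same probabilities; the second never allocates the q_len x total_kv_len tensor or performs the extra add.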

A general performance improvement is expected on long-context executions
(both throughput and TTFT / TPOT). Short-context executions are unaffected
because the bias tensor is small and other ops dominate.

Scope (conservative)

The optimisation only fires when all of the following hold:

  • prefill_use_fusedsdpa is active
  • model is causal (is_causal) and is not a pooling model
  • no sliding_window
  • no model_has_chunked_attention
  • no alibi_slopes
  • model is a non-GDN hybrid (num_mamba_like_layers > 0 and num_gdn == 0)

The non-GDN hybrid gate reuses the runner's existing num_mamba_like_layers /
num_gdn counters and re-derives the same flag in the metadata processor via
model_config.get_num_layers_by_block_type, exposed as a single
self.is_non_gdn_hybrid attribute.

| Model class | Hits short-circuit? |
| --- | --- |
| Mamba/SSM + Transformer hybrids (non-GDN) | yes |
| GDN hybrids | no (num_gdn > 0) |
| Plain transformer (dense / MoE) | no (num_mamba_like_layers == 0) |
| Sliding-window models | no (already excluded) |
| Chunked-attention models | no (already excluded) |
| Pooling / encoder | no (already excluded) |
| ALiBi | no (already excluded) |
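
The scope conditions and the table above can be sketched as a single predicate. This is an illustrative reconstruction from the description, not the guard as merged: AttnConfig, is_pooling_model, and skips_materialised_bias are hypothetical names, the layer counts are made up, and the real checks live in HPUModelRunner.set_attn_bias and HPUAttentionMetadataProcessor._set_attn_bias.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnConfig:
    """Illustrative stand-in for the state the guard consults.

    Not a real vllm_gaudi class; field names mirror the PR description.
    """
    prefill_use_fusedsdpa: bool = True
    is_causal: bool = True
    is_pooling_model: bool = False
    sliding_window: Optional[int] = None
    model_has_chunked_attention: bool = False
    alibi_slopes: Optional[list] = None
    num_mamba_like_layers: int = 0
    num_gdn: int = 0


def is_non_gdn_hybrid(cfg: AttnConfig) -> bool:
    """At least one Mamba-like layer and no GDN layers."""
    return cfg.num_mamba_like_layers > 0 and cfg.num_gdn == 0


def skips_materialised_bias(cfg: AttnConfig) -> bool:
    """True only when every condition in the scope list holds."""
    return (cfg.prefill_use_fusedsdpa
            and cfg.is_causal
            and not cfg.is_pooling_model
            and cfg.sliding_window is None
            and not cfg.model_has_chunked_attention
            and cfg.alibi_slopes is None
            and is_non_gdn_hybrid(cfg))


# One hypothetical configuration per table row.
rows = {
    "non-GDN hybrid": AttnConfig(num_mamba_like_layers=36),
    "GDN hybrid": AttnConfig(num_mamba_like_layers=36, num_gdn=12),
    "plain transformer": AttnConfig(),
    "sliding window": AttnConfig(num_mamba_like_layers=36, sliding_window=4096),
    "chunked attention": AttnConfig(num_mamba_like_layers=36,
                                    model_has_chunked_attention=True),
    "pooling": AttnConfig(num_mamba_like_layers=36, is_pooling_model=True),
    "ALiBi": AttnConfig(num_mamba_like_layers=36, alibi_slopes=[0.5, 0.25]),
}

if __name__ == "__main__":
    for name, cfg in rows.items():
        print(f"{name:18} -> {skips_materialised_bias(cfg)}")
```

Writing the gate as one conjunction keeps the conservative default visible: any configuration that fails a single condition falls through to the existing bias-building path.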

Files changed

  • vllm_gaudi/v1/worker/hpu_model_runner.py
    • HPUModelRunner.__init__: define self.is_non_gdn_hybrid once, reusing
      self.num_mamba_like_layers / self.num_gdn.
    • HPUModelRunner.set_attn_bias: extend the early-return guard to all
      plain-causal cases and gate on self.is_non_gdn_hybrid.
    • HPUAttentionMetadataProcessor.__init__: compute
      self.is_non_gdn_hybrid via get_num_layers_by_block_type.
    • HPUAttentionMetadataProcessor._set_attn_bias: tighten guard to
      prefill_use_fusedsdpa and not interleaved_sliding_window and is_non_gdn_hybrid.

Risk and validation

  • A/B benchmark on a long-context workload of a non-GDN hybrid model showed a
    measurable improvement in throughput and TTFT / TPOT / E2EL, with identical
    input / generated token counts and no failed requests.
  • Topologies outside the gate are byte-for-byte unchanged (the early return
    does not fire).
  • Recommended follow-ups before broadening the gate to plain-causal
    transformers: accuracy validation (e.g. GSM8K / lm-eval), a sliding-window
    sanity check, and a long-context perplexity check.

Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>

Copilot AI left a comment

Pull request overview

This PR optimizes HPU prompt/prefill attention for non-GDN hybrid (Mamba/SSM + Transformer) models by avoiding construction of a large materialized causal attn_bias when FusedSDPA can apply causality natively, including during chunked-prefill.

Changes:

  • Add an is_non_gdn_hybrid topology flag in both HPUModelRunner and HPUAttentionMetadataProcessor.
  • Extend the FusedSDPA causal short-circuit in HPUModelRunner.set_attn_bias to cover chunked-prefill (non-None block_list) under additional “plain-causal” constraints.
  • Extend the corresponding short-circuit in HPUAttentionMetadataProcessor._set_attn_bias gated by prefill_use_fusedsdpa and is_non_gdn_hybrid.

Comment on lines +1102 to +1105
# Non-GDN hybrid: at least one mamba/linear-style layer and zero GDN
# (gdn_attention / linear_attention) layers. Used to gate optimizations
# that have only been validated on non-GDN hybrid topologies
# (e.g. Granite-4 Mamba2+Transformer).
Comment on lines +6687 to +6695
# FusedSDPA handles a purely causal mask natively (is_causal=True +
# valid_seq_lengths). Skip materialising a [bs, 1, q_len,
# total_kv_len] attn_bias when the model is plain-causal (no
# sliding-window / chunked-attention). This removes a sizable
# add_bf16 from the attention critical path during long-context
# chunked prefill. interleaved_sliding_window and chunked-attention
# bias paths (window_attn_bias / chunked_attn_bias) are populated
# later in process_metadata and used by hpu_attn instead.
# Conservative scope: only non-GDN hybrid models (e.g. Granite-4).

github-actions Bot commented May 5, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
3975eb6de6ea914b9d7b27fd517e0c971ddeb6fc

@ksmusz ksmusz merged commit 808dbfa into main May 7, 2026
6 checks passed
@ksmusz ksmusz deleted the dev/ksmusz/skip_materialised_causal_attn_bias branch May 7, 2026 13:23
ksmusz added a commit that referenced this pull request May 11, 2026
The PR fixes the condition introduced in
#1413

Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>
mgawarkiewicz-intel pushed a commit that referenced this pull request May 26, 2026
#1481)

…d models (#1413)"

This reverts commit 808dbfa.

Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com>
Co-authored-by: Jakub Byczkowski <jbyczkowski@habana.ai>
iboiko-habana added a commit that referenced this pull request May 27, 2026
#1482)

…d models (#1413)"

This reverts commit 808dbfa.

Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com>
Co-authored-by: Iryna Boiko <iryna.boiko@intel.com>