
Add support for chunked attention #821

Merged

ksmusz merged 13 commits into vllm-project:main from kfojcik-intel:dev/kfojcik/chunked_attn on Jan 23, 2026

Conversation

@kfojcik-intel (Contributor)

Cherry-pick of 6e1be4e but adapted to recent changes in #526

Signed-off-by: Katarzyna Fojcik <kfojcik@habana.ai>
Copilot AI review requested due to automatic review settings January 15, 2026 06:35
@kfojcik-intel changed the title from "Dev/kfojcik/chunked attn" to "Add support for chunked attention" on Jan 15, 2026
Copilot AI left a comment

Pull request overview

This PR adds support for chunked attention by adapting a cherry-picked commit (6e1be4e) to work with the recent changes from #526. The implementation introduces metadata fields, bias calculations, and block mapping logic specifically for handling chunked attention patterns.

Changes:

  • Added chunked attention metadata fields and processing logic to support models with chunked attention patterns
  • Implemented attention bias calculation for chunked attention in both prefill and decode phases (see the sketch after this list)
  • Added automatic detection and configuration of chunked attention layers based on model config
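
For readers unfamiliar with the pattern, the following is a minimal, self-contained sketch (not the PR's code) of the kind of bias chunked attention implies: each query token attends causally, but only to keys inside its own fixed-size chunk. The function name and values are illustrative assumptions.

```python
import torch

def chunked_causal_bias(seq_len: int, chunk_size: int, dtype=torch.float32):
    # Query i may attend to key j iff j <= i (causal) and i, j fall in the same chunk.
    idx = torch.arange(seq_len)
    causal = idx.unsqueeze(1) >= idx.unsqueeze(0)
    same_chunk = (idx.unsqueeze(1) // chunk_size) == (idx.unsqueeze(0) // chunk_size)
    allowed = causal & same_chunk
    bias = torch.zeros(seq_len, seq_len, dtype=dtype)
    bias.masked_fill_(~allowed, float("-inf"))
    return bias

# With seq_len=8 and chunk_size=4, token 5 attends only to tokens 4 and 5.
print(chunked_causal_bias(8, 4))
```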

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/v1/worker/hpu_model_runner.py | Core implementation including metadata processing, block mapping, bias calculation, and model detection for chunked attention |
| vllm_gaudi/v1/spec_decode/hpu_eagle.py | Added chunked attention metadata parameters to EAGLE speculative decoding |
| vllm_gaudi/v1/attention/backends/hpu_attn.py | Updated attention metadata factory method to include chunked attention parameters |
| vllm_gaudi/attention/backends/hpu_attn.py | Added chunked attention fields to metadata class and implementation logic in attention forward pass |


Comment on lines +1416 to +1426
```python
def maybe_set_chunked_attention_layers(self, model):
    if hasattr(model.config, 'text_config') and \
            hasattr(model.config.text_config, 'attention_chunk_size') and \
            model.config.text_config.attention_chunk_size:
        self.model_has_chunked_attention = True
        try:
            for layer in model.language_model.model.layers:
                if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
                    layer.self_attn.attn.impl.is_chunked_attention = True
        except Exception:
            pass
```
Copilot AI commented on Jan 15, 2026

The bare `except Exception: pass` silently suppresses all errors without logging. This makes debugging difficult if the chunked attention setup fails. Add logging to record when this exception occurs, including the exception details.
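
A possible shape for the logged variant (sketch only; a module-level `logger` is assumed to exist, e.g. created via vLLM's logging helpers):

```python
# Sketch of the suggestion above: log the failure instead of silently swallowing it.
# Assumes `logger` is defined at module level; names mirror the snippet under review.
try:
    for layer in model.language_model.model.layers:
        backend_name = layer.self_attn.attn.get_attn_backend().__name__
        if "ChunkedLocalAttention" in backend_name:
            layer.self_attn.attn.impl.is_chunked_attention = True
except Exception as exc:
    logger.warning("Could not configure chunked attention layers: %s", exc)
```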

```python
                                  padded_batch_size * num_tokens)

if self.model_has_chunked_attention:
    chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
```
Copilot AI commented on Jan 15, 2026

Division could result in zero if `attention_chunk_size` is smaller than `block_size`, leading to incorrect chunking behavior. Add validation to ensure `attention_chunk_size` is at least equal to `block_size` or handle the zero case appropriately.

Suggested change

```diff
-chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
+attention_chunk_size = self.model.model.config.text_config.attention_chunk_size
+if attention_chunk_size < self.block_size:
+    raise ValueError(
+        f"Configured attention_chunk_size ({attention_chunk_size}) must be at least "
+        f"as large as block_size ({self.block_size}) when using chunked attention."
+    )
+chunk_size_in_blocks = attention_chunk_size // self.block_size
```
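
For intuition, a small illustration of why the guard matters (numbers are illustrative assumptions, not values from the PR):

```python
# Floor division gives the intended block count only when the chunk spans at
# least one full block; smaller chunk sizes silently floor to zero.
attention_chunk_size = 8192   # illustrative value
block_size = 128              # illustrative value
print(attention_chunk_size // block_size)  # 64 blocks per chunk
print(100 // block_size)                   # 0 -> the degenerate case the check rejects
```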

Comment on lines +5611 to +5612
```python
max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
```
Copilot AI commented on Jan 15, 2026

Division could result in zero or an incorrect value if `block_list.size(-1)` is smaller than `batch_size`. This could lead to incorrect attention bias calculation. Add validation or use `math.ceil` for the division to ensure proper handling of partial blocks.

Suggested change

```diff
-max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
-max_context_len = max_context_len * self.block_size
+if block_list is not None and batch_size > 0:
+    # Compute number of blocks per sequence using ceiling division to handle partial blocks.
+    blocks_per_seq = math.ceil(block_list.size(-1) / batch_size)
+    max_context_len = blocks_per_seq * self.block_size
+else:
+    max_context_len = 0
```
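
A quick illustration of the floor-versus-ceiling difference the suggestion guards against (numbers are illustrative only):

```python
import math

# 10 blocks shared across 3 sequences: floor division drops the partial block,
# ceiling division accounts for it.
print(10 // 3)             # 3 -> under-counts, truncating part of the context
print(math.ceil(10 / 3))   # 4 -> covers the partial block
```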

```python
    (past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) &
    (past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1)

causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
```
Copilot AI commented on Jan 15, 2026

Indexing with `[0]` assumes `which_chunk` has at least one element. While this may be guaranteed by the context, the assumption is not immediately clear. Consider adding a comment explaining why the first element is used, or an assertion to document this assumption.

Suggested change

```diff
 causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
+# which_chunk is expected to have at least one row (batch dimension > 0) in this code path.
+assert which_chunk.size(0) > 0, "which_chunk is expected to have at least one row"
```

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
66652e8082b69ba7d1e6aca7c234433de55f1b9b

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
4c1c501a7ee1d5efbad945ea62a702ce5cefb799

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
6218034dd7f9a56596e4fd8c8c8fc1d8011ed9c2

@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
6218034dd7f9a56596e4fd8c8c8fc1d8011ed9c2

@ksmusz merged commit 7e97f22 into vllm-project:main on Jan 23, 2026
53 checks passed
wpyszka pushed a commit that referenced this pull request Jan 28, 2026
…#855 (#881)

Cherry pick missing fixes:
  • chunked attention fixes from #821
  • llama4 32k+ context window #855

---------

Signed-off-by: Luca Calabria <luca.calabria@intel.com>
Signed-off-by: Jakub Byczkowski <jbyczkowski@habana.ai>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com>
Signed-off-by: linoy buchnik <lbuchnik@habana.ai>
Signed-off-by: Iryna Boiko <iboiko@habana.ai>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jakub Byczkowski <jbyczkowski@habana.ai>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Radosław Smyrek <radoslawx.smyrek@intel.com>
Co-authored-by: Linoy Buchnik <linoybu@gmail.com>
Co-authored-by: Iryna Boiko <iboiko@habana.ai>
Co-authored-by: Artur Fierka <artur.fierka@intel.com>
testdig pushed a commit to testdig/vllm-gaudi-fork that referenced this pull request Jan 29, 2026
Cherry-pick of vllm-project@6e1be4e but adapted to recent changes in vllm-project#526

---------

Signed-off-by: Katarzyna Fojcik <kfojcik@habana.ai>
Signed-off-by: Wang, Zheng W <zheng.w.wang@intel.com>
slokesha pushed a commit to libinta/vllm-gaudi that referenced this pull request Jan 29, 2026
…ndow fix from vllm-project#855 (vllm-project#881)

Cherry pick missing fixes:
  • chunked attention fixes from vllm-project#821
  • llama4 32k+ context window vllm-project#855

---------

Signed-off-by: Luca Calabria <luca.calabria@intel.com>
Signed-off-by: Jakub Byczkowski <jbyczkowski@habana.ai>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com>
Signed-off-by: linoy buchnik <lbuchnik@habana.ai>
Signed-off-by: Iryna Boiko <iboiko@habana.ai>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jakub Byczkowski <jbyczkowski@habana.ai>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Radosław Smyrek <radoslawx.smyrek@intel.com>
Co-authored-by: Linoy Buchnik <linoybu@gmail.com>
Co-authored-by: Iryna Boiko <iboiko@habana.ai>
Co-authored-by: Artur Fierka <artur.fierka@intel.com>
Signed-off-by: slokesha <slokeshappa@habana.ai>
slokesha pushed a commit to libinta/vllm-gaudi that referenced this pull request Feb 9, 2026
Cherry-pick of vllm-project@6e1be4e but adapted to recent changes in vllm-project#526

---------

Signed-off-by: Katarzyna Fojcik <kfojcik@habana.ai>
Signed-off-by: slokesha <slokeshappa@habana.ai>
adobrzyn pushed a commit that referenced this pull request Mar 31, 2026
Cherry-pick of 6e1be4e but adapted to recent changes in #526

---------

Signed-off-by: Katarzyna Fojcik <kfojcik@habana.ai>