Add support for chunked attention #821
Changes from all commits: bc27d3f, 4318b82, 3ec1172, d943557, 77e3d2c, 55f5ed3, bdc4319, e78da09, db9ac06, e20b15e, 62a7b6b, b67e943, 8af3617
```diff
@@ -459,10 +459,12 @@ def forward(self, *args, **kwargs):
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
+        model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False)
         if not self.unified_attn:
             kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'],
                                                                                input_ids.size(0), input_ids.size(1),
-                                                                               input_ids.device, self.dtype)
+                                                                               input_ids.device, self.dtype,
+                                                                               model_has_chunked_attention)
         if self._rotary_prepare_cos_sin is not None:
             self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
         attn_meta = kwargs.pop('attn_metadata')
```
```diff
@@ -587,6 +589,11 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object:
         'window_block_usage',
         'window_block_groups',
         'window_attn_bias',
+        'chunked_block_mapping',
+        'chunked_attn_bias',
+        'chunked_block_list',
+        'chunked_block_usage',
+        'chunked_block_groups',
     ])
     return attention_metadata
```
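The call that consumes this field list sits outside the hunk, so the exact helper is not shown; below is a minimal sketch of the trimming pattern, assuming a `subtuple`-style helper (the name is an assumption) that copies only whitelisted attributes into a lightweight object that is cheap to hand to the compiled attention path:

```python
from collections import namedtuple

# Sketch only: copy the listed attributes (when present) off a heavyweight
# metadata object into a plain namedtuple.
def subtuple(obj, typename, fields):
    values = {f: getattr(obj, f) for f in fields if hasattr(obj, f)}
    return namedtuple(typename, values.keys())(**values)
```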
```diff
@@ -833,6 +840,8 @@ def __init__(
         self.scheduler_output: SchedulerOutput | None = None
         self.warmup_mode: bool = False
         self.batch_changed: bool = False
+        # WA for chunked attention support
+        self.model_has_chunked_attention = False

         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
```
```diff
@@ -1400,6 +1409,18 @@ def _get_num_decodes(self) -> int:
                 num_decodes += 1
         return num_decodes

+    def maybe_set_chunked_attention_layers(self, model):
+        if hasattr(model.config, 'text_config') and \
+           hasattr(model.config.text_config, 'attention_chunk_size') and \
+           model.config.text_config.attention_chunk_size:
+            self.model_has_chunked_attention = True
+            try:
+                for layer in model.language_model.model.layers:
+                    if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
+                        layer.self_attn.attn.impl.is_chunked_attention = True
+            except Exception:
+                pass
+
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
```
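The detection condition above can be exercised standalone; here is a minimal sketch, assuming a HuggingFace-style multimodal config where the text backbone's settings live under `text_config` (the helper name below is hypothetical, not part of the PR):

```python
from types import SimpleNamespace

# Hypothetical helper mirroring the condition in maybe_set_chunked_attention_layers.
def model_uses_chunked_attention(config) -> bool:
    text_config = getattr(config, 'text_config', None)
    # A truthy attention_chunk_size marks the model as chunked-attention.
    return bool(getattr(text_config, 'attention_chunk_size', None))

config = SimpleNamespace(text_config=SimpleNamespace(attention_chunk_size=8192))
assert model_uses_chunked_attention(config)
assert not model_uses_chunked_attention(SimpleNamespace())
```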
```diff
@@ -2116,6 +2137,19 @@ def _create_decode_input_data(self,
                                                   window_block_tables, slot_mapping.tolist(),
                                                   padded_batch_size * num_tokens)

+        if self.model_has_chunked_attention:
+            chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
```
Suggested change:

```diff
-            chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
+            attention_chunk_size = self.model.model.config.text_config.attention_chunk_size
+            if attention_chunk_size < self.block_size:
+                raise ValueError(
+                    f"Configured attention_chunk_size ({attention_chunk_size}) must be at least "
+                    f"as large as block_size ({self.block_size}) when using chunked attention."
+                )
+            chunk_size_in_blocks = attention_chunk_size // self.block_size
```
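For intuition on why the guard matters: floor division silently yields zero when the configured chunk is smaller than a KV-cache block, which is exactly the degenerate case the suggested `ValueError` rules out (the numbers below are illustrative):

```python
# Illustrative values only.
attention_chunk_size, block_size = 64, 128
print(attention_chunk_size // block_size)   # 0 -> degenerate chunk geometry downstream

attention_chunk_size, block_size = 8192, 128
print(attention_chunk_size // block_size)   # 64 blocks per attention chunk
```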
**Copilot AI** · Jan 15, 2026
Division could result in zero or an incorrect value if `block_list.size(-1)` is smaller than `batch_size`, which could lead to an incorrect attention bias calculation. Add validation or use `math.ceil` for the division to ensure proper handling of partial blocks.
Suggested change:

```diff
-        max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
-        max_context_len = max_context_len * self.block_size
+        if block_list is not None and batch_size > 0:
+            # Compute number of blocks per sequence using ceiling division to handle partial blocks.
+            blocks_per_seq = math.ceil(block_list.size(-1) / batch_size)
+            max_context_len = blocks_per_seq * self.block_size
+        else:
+            max_context_len = 0
```
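A quick illustration of the floor-vs-ceiling difference the comment describes (values are illustrative):

```python
import math

# 7 blocks spread across 2 sequences: floor division drops the partial block.
total_blocks, batch_size = 7, 2
print(total_blocks // batch_size)            # 3 -> underestimates max_context_len
print(math.ceil(total_blocks / batch_size))  # 4 -> covers the partial block
```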
**Copilot AI** · Jan 15, 2026
Indexing with `[0]` assumes `which_chunk` has at least one element. While this may be guaranteed by the context, the assumption is not immediately clear. Consider adding a comment explaining why the first element is used, or an assertion to document this assumption.
Suggested change:

```diff
         causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
+        # which_chunk is expected to have at least one row (batch dimension > 0) in this code path.
+        assert which_chunk.size(0) > 0, "which_chunk is expected to have at least one row"
```
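For context, here is a standalone sketch of what a chunked causal mask computes, assuming `which_chunk` holds the per-position chunk index; this is an illustration of the technique, not the PR's implementation:

```python
import torch

# Positions i and j may attend iff j <= i (causal) and they share a chunk,
# i.e. i // chunk_size == j // chunk_size.
def chunked_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    pos = torch.arange(seq_len)
    which_chunk = pos // chunk_size                      # chunk index per position
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_chunk = which_chunk.unsqueeze(0) == which_chunk.unsqueeze(1)
    return causal & same_chunk

print(chunked_causal_mask(6, 3).int())  # two 3x3 lower-triangular blocks
```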
The bare `except Exception: pass` silently suppresses all errors without logging, which makes debugging difficult if the chunked attention setup fails. Add logging to record when this exception occurs, including the exception details.
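One way to act on this comment is to keep the PR's traversal but record the failure; a sketch under that assumption (the module-level `logger` and the function name are illustrative, not from the PR):

```python
import logging

logger = logging.getLogger(__name__)

def mark_chunked_attention_layers(model) -> None:
    """Same traversal as maybe_set_chunked_attention_layers, but the
    failure is logged instead of silently discarded."""
    try:
        for layer in model.language_model.model.layers:
            backend_name = layer.self_attn.attn.get_attn_backend().__name__
            if "ChunkedLocalAttention" in backend_name:
                layer.self_attn.attn.impl.is_chunked_attention = True
    except Exception:
        logger.exception("Chunked attention layer setup failed; "
                         "chunked attention flags may be incomplete")
```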