[Attention Metadata Overhaul 2/N] Move metadata processing outside HPUModelAdapter #530
kzawora-intel wants to merge 18 commits
Conversation
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Pull Request Overview
This PR moves HPU attention metadata processing from the HpuModelAdapter into a dedicated HPUAttentionMetadataProcessor class, allowing metadata biases to be computed on CPU and copied asynchronously to HPU. This refactoring removes metadata processing logic from the model forward path and handles it at input preparation time instead.
Key Changes:
- Extracted metadata processing into a standalone HPUAttentionMetadataProcessor class
- Moved metadata processing to occur during input preparation (prefill/decode batch formation) rather than in model forward
- Added support for processing metadata on CPU with async copy to HPU device
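The CPU-side bias computation described above can be sketched roughly as below. This is a hypothetical, simplified helper, not the actual implementation; the real logic in this PR lives in the HPUAttentionMetadataProcessor class:

```python
import torch

def build_causal_attn_bias_cpu(seq_lens, max_seq_len, dtype=torch.float32):
    # Hypothetical sketch: build a padded causal attention bias on CPU so it
    # can later be copied asynchronously to the HPU device.
    batch = len(seq_lens)
    bias = torch.zeros(batch, 1, max_seq_len, max_seq_len, dtype=dtype)
    # Mask out future positions (causal mask): entries above the diagonal.
    causal = torch.triu(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
    bias.masked_fill_(causal, float('-inf'))
    # Mask out padding beyond each sequence's real length.
    for i, n in enumerate(seq_lens):
        bias[i, :, :, n:] = float('-inf')
    return bias
```

Computing the bias like this on the host keeps the masking logic out of the model forward path; the resulting tensor can then be moved to the device with a non-blocking copy at input-preparation time.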
…or' into private/kzawora/metadata_process_cpu
Pull Request Overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.
Pull Request Overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 5 comments.
```python
def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override):
    if trim:
        return custom_tuple_replace(obj, typename, **to_override)

    for key in to_override:
        assert hasattr(obj, key), f"Field {key} must exist in untrimmed metadata."
        setattr(obj, key, to_override[key])
    return obj
```
The function metadata_update_with_trim lacks a docstring explaining its purpose, parameters, return value, and the distinction between trimmed and untrimmed metadata handling. This is especially important given the conditional logic and the use of setattr for dynamic attribute modification.
```python
assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
context_lens_t = prefill_metadata.context_lens_tensor
assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
```
The error message should be more specific by indicating which phase (prefill) or operation is being performed when this assertion fails, to help with debugging.
```diff
- assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
+ assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias during prefill (prompt) phase"
  context_lens_t = prefill_metadata.context_lens_tensor
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias during prefill (prompt) phase"
```
```python
seq_lens_t = prefill_metadata.seq_lens_tensor
assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
context_lens_t = prefill_metadata.context_lens_tensor
assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
```
The error message should be more specific by indicating which phase (prefill) or operation is being performed when this assertion fails, to help with debugging.
```diff
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias during prefill phase"
```
```python
if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:
    context_lens_t = prefill_metadata.context_lens_tensor
    assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
```
The error message should be more specific by indicating this is for sliding window attention to aid debugging when this assertion fails.
```diff
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias for sliding window attention"
```
```python
# NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window
# - we should check if that can be reduced to a single call.
```
This TODO-style comment expresses uncertainty about the implementation. Either investigate and resolve this concern, or rephrase as a clearer explanation if the double call is intentional (e.g., for separate window and non-window blocks).
```diff
- # NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window
- # - we should check if that can be reduced to a single call.
+ # For sliding window, we set block mapping twice: once for the base mapping and once for the sliding window mapping.
+ # This ensures both standard and sliding window block mappings are correctly applied.
```
```python
def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override):
    if trim:
```
I can't find a place where we use the trimmed replace — do we need this boolean?
```python
return _TYPE_CACHE[typename]['type'](**values)  # type: ignore
```

```python
def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override):
```
I can't find a place where we use the trimmed replace — do we need the typename and trim args? If not, we can just rename this function to metadata_update.
```python
                       dst_device: torch.device,
                       dtype: torch.dtype,
                       is_window_block: bool = False,
                       trim: bool = False) -> HPUAttentionMetadataV1:
```

```python
        "TrimmedAttentionMetadata",
        trim=trim,
```

```python
def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
                                      window_size: int, src_device: torch.device, dst_device: torch.device,
                                      dtype: torch.dtype, trim: bool) -> HPUAttentionMetadataV1:
```

```python
def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
                   src_device: torch.device, dst_device: torch.device, dtype: torch.dtype,
                   trim: bool) -> HPUAttentionMetadataV1:
```

```python
    torch.device('cpu'),
    token_ids_device.device,
    self.dtype,
    trim=False)
```

```python
    torch.device('cpu'),
    token_ids.device,
    self.dtype,
    trim=False)
```
…or' into HEAD
Requires #526. This is the next logical step: we remove usage of the metadata postprocessor inside HpuModelAdapter and do the processing at input preparation time instead, on CPU, copying the data asynchronously to HPU. I also needed to change a few things so the processor accepts untrimmed metadata. This works as-is, but unfortunately I've noticed a pretty significant drop in end-to-end performance for small models.
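The asynchronous CPU-to-device copy mentioned above can be sketched as follows. This is a generic illustration with a hypothetical helper name, not the PR's actual code:

```python
import torch

def copy_metadata_async(tensors: dict, device: torch.device) -> dict:
    # Hypothetical helper: move CPU-built metadata tensors to the target
    # device with non_blocking=True so host-side work (e.g. building the
    # next batch's biases) can overlap with the transfer. Note that truly
    # asynchronous host-to-device copies generally require pinned host
    # memory; otherwise the copy may silently fall back to synchronous.
    return {name: t.to(device, non_blocking=True) for name, t in tensors.items()}
```

Because the copy is enqueued rather than waited on, the caller must not mutate the source CPU tensors until the transfer is known to have completed on the device stream.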