diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index 746d4408f9..14321b3869 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -169,6 +169,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     window_block_groups: Optional[torch.Tensor] = None
     window_block_usage: Optional[torch.Tensor] = None
     window_attn_bias: Optional[torch.Tensor] = None
+    chunked_slot_mapping: Optional[torch.Tensor] = None
+    chunked_attn_bias: Optional[torch.Tensor] = None
+    chunked_block_mapping: Optional[torch.Tensor] = None
+    chunked_block_list: Optional[torch.Tensor] = None
+    chunked_block_groups: Optional[torch.Tensor] = None
+    chunked_block_usage: Optional[torch.Tensor] = None


 @dataclass
@@ -493,6 +499,8 @@ def __init__(
                                       "is not implemented for "
                                       "HPUAttentionImpl")

+        self.is_chunked_attention = False
+
     def _maybe_init_alibi_biases(
         self,
         max_seq_len,
@@ -630,6 +638,9 @@ def forward(
             attn_bias = None
             window_size = (self.sliding_window, 0)
             common_args['window_size'] = window_size
+            if self.is_chunked_attention and \
+                hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
+                attn_bias = attn_metadata.chunked_attn_bias

         out = ops.prompt_attention(impl=self.prefill_impl,
                                    query=query.view(query_shape),
@@ -650,6 +661,12 @@ def forward(
                 block_groups = attn_metadata.window_block_groups
                 block_mapping = attn_metadata.window_block_mapping
                 attn_bias = attn_metadata.window_attn_bias
+            elif self.is_chunked_attention and \
+                attn_metadata.chunked_block_list is not None:
+                block_list = attn_metadata.chunked_block_list
+                block_groups = attn_metadata.chunked_block_groups
+                block_mapping = attn_metadata.chunked_block_mapping
+                attn_bias = attn_metadata.chunked_attn_bias
             else:
                 block_list = attn_metadata.block_list
                 block_groups = attn_metadata.block_groups
diff --git a/vllm_gaudi/v1/attention/backends/hpu_attn.py b/vllm_gaudi/v1/attention/backends/hpu_attn.py
index 1d2c1d2363..978cec0cd9 100644
--- a/vllm_gaudi/v1/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/v1/attention/backends/hpu_attn.py
@@ -87,6 +87,9 @@ def make_decode_metadata(cls,
                              window_block_list,
                              window_block_usage,
                              window_block_groups,
+                             chunked_block_list,
+                             chunked_block_usage,
+                             chunked_block_groups,
                              query_start_loc=None):
         return cls(is_prompt=False,
                    block_mapping=None,
@@ -100,6 +103,9 @@ def make_decode_metadata(cls,
                    window_block_list=window_block_list,
                    window_block_usage=window_block_usage,
                    window_block_groups=window_block_groups,
+                   chunked_block_list=chunked_block_list,
+                   chunked_block_usage=chunked_block_usage,
+                   chunked_block_groups=chunked_block_groups,
                    input_positions=input_positions,
                    slot_mapping=slot_mapping,
                    block_size=block_size,
diff --git a/vllm_gaudi/v1/spec_decode/hpu_eagle.py b/vllm_gaudi/v1/spec_decode/hpu_eagle.py
index 7a0d508b0a..1c9159a810 100644
--- a/vllm_gaudi/v1/spec_decode/hpu_eagle.py
+++ b/vllm_gaudi/v1/spec_decode/hpu_eagle.py
@@ -237,6 +237,9 @@ def prepare_attn_metadata(
             window_block_list=None,
             window_block_usage=None,
             window_block_groups=None,
+            chunked_block_list=None,
+            chunked_block_usage=None,
+            chunked_block_groups=None,
         )

         return common_attn_metadata
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 7c87f82f89..52edbed073 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -459,10 +459,12 @@ def forward(self, *args, **kwargs):
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
+        model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False)
         if not self.unified_attn:
             kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'],
                                                                                input_ids.size(0), input_ids.size(1),
-                                                                               input_ids.device, self.dtype)
+                                                                               input_ids.device, self.dtype,
+                                                                               model_has_chunked_attention)
         if self._rotary_prepare_cos_sin is not None:
             self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
         attn_meta = kwargs.pop('attn_metadata')
@@ -587,6 +589,11 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object:
         'window_block_usage',
         'window_block_groups',
         'window_attn_bias',
+        'chunked_block_mapping',
+        'chunked_attn_bias',
+        'chunked_block_list',
+        'chunked_block_usage',
+        'chunked_block_groups',
     ])
     return attention_metadata

@@ -833,6 +840,8 @@ def __init__(
         self.scheduler_output: SchedulerOutput | None = None
         self.warmup_mode: bool = False
         self.batch_changed: bool = False
+        # WA for chunked attention support
+        self.model_has_chunked_attention = False
         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'

@@ -1400,6 +1409,18 @@ def _get_num_decodes(self) -> int:
                 num_decodes += 1
         return num_decodes

+    def maybe_set_chunked_attention_layers(self, model):
+        if hasattr(model.config, 'text_config') and \
+            hasattr(model.config.text_config, 'attention_chunk_size') and \
+            model.config.text_config.attention_chunk_size:
+            self.model_has_chunked_attention = True
+            try:
+                for layer in model.language_model.model.layers:
+                    if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
+                        layer.self_attn.attn.impl.is_chunked_attention = True
+            except Exception:
+                pass
+
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
@@ -2116,6 +2137,19 @@ def _create_decode_input_data(self,
                                                       window_block_tables,
                                                       slot_mapping.tolist(),
                                                       padded_batch_size * num_tokens)
+        if self.model_has_chunked_attention:
+            chunk_size_in_blocks = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
+            seq_lens_block = [len(block_table) for block_table in block_tables_list]
+            num_seq_chunks = [math.ceil(sl / chunk_size_in_blocks) - 1 for sl in seq_lens_block]
+            block_tables_chunk = [
+                block_table[num_seq_chunks[i] * chunk_size_in_blocks:]
+                for i, block_table in enumerate(block_tables_list)
+            ]
+            chunked_block_list, chunked_block_groups, chunked_block_usage = \
+                self.get_habana_paged_attn_buffers(
+                    block_tables_chunk, slot_mapping.tolist(),
+                    padded_batch_size * num_tokens)
+
         # CPU<>HPU sync *should not* happen here.
         block_list_device = async_h2d_copy(block_list, device=self.device)
         block_usage_device = async_h2d_copy(block_usage, device=self.device)
@@ -2130,6 +2164,12 @@ def _create_decode_input_data(self,
             window_block_groups_device = async_h2d_copy(
                 window_block_groups,
                 device=self.device) if self.interleaved_sliding_window and self.sliding_window is not None else None
+        chunked_block_list_device = async_h2d_copy(chunked_block_list,
+                                                   device=self.device) if self.model_has_chunked_attention else None
+        chunked_block_usage_device = async_h2d_copy(chunked_block_usage,
+                                                    device=self.device) if self.model_has_chunked_attention else None
+        chunked_block_groups_device = async_h2d_copy(chunked_block_groups,
+                                                     device=self.device) if self.model_has_chunked_attention else None
         token_ids_device = async_h2d_copy(token_ids, device=self.device)

         # when DP also enabled, some DP ranks will exeucte dummy run with empty
@@ -2160,21 +2200,26 @@ def _create_decode_input_data(self,
             spec_decode_metadata = None

         logits_indices_device = async_h2d_copy(logits_indices, device=self.device)
+        attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
+            block_list=block_list_device,
+            block_usage=block_usage_device,
+            block_groups=block_groups_device,
+            input_positions=None,
+            slot_mapping=slot_mapping_device,
+            block_size=self.block_size,
+            window_block_list=window_block_list_device,
+            window_block_usage=window_block_usage_device,
+            window_block_groups=window_block_groups_device,
+            chunked_block_list=chunked_block_list_device,
+            chunked_block_usage=chunked_block_usage_device,
+            chunked_block_groups=chunked_block_groups_device,
+        )
+
         return DecodeInputData(num_decodes=num_decodes,
                                token_ids=token_ids_device,
                                position_ids=positions_device,
                                logits_indices=logits_indices_device,
-                               attn_metadata=HPUAttentionMetadataV1.make_decode_metadata(
-                                   block_list=block_list_device,
-                                   block_usage=block_usage_device,
-                                   block_groups=block_groups_device,
-                                   input_positions=None,
-                                   slot_mapping=slot_mapping_device,
-                                   block_size=self.block_size,
-                                   window_block_list=window_block_list_device,
-                                   window_block_usage=window_block_usage_device,
-                                   window_block_groups=window_block_groups_device,
-                               ),
+                               attn_metadata=attn_metadata,
                                spec_decode_metadata=spec_decode_metadata)

     def _prepare_decode_inputs(self,
@@ -2602,6 +2647,8 @@ def _execute_model_generic(self,
         else:  # no hpu graphs for t.compile?
             use_graphs = False

+        if self.model_has_chunked_attention:
+            additional_kwargs.update({"model_has_chunked_attention": True})
         trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata)
         if self.is_driver_worker:
             model_event_name = ("model_forward_"
@@ -3709,6 +3756,7 @@ def load_model(self) -> None:
             htcore.mark_step()

         apply_model_specific_patches(self.model)
+        self.maybe_set_chunked_attention_layers(self.model)
         hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
         model_config = getattr(self.model, "config", None)
         modify_model_layers(self.model,
@@ -5571,12 +5619,77 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV
         attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias)
         return attn_metadata

+    def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
+                                             chunk_size: int, device: torch.device,
+                                             dtype: torch.dtype) -> HPUAttentionMetadataV1:
+        """Set attention bias for chunked attention.
+
+        Args:
+            attn_metadata (HPUAttentionMetadataV1): The attention metadata.
+            batch_size (int): The batch size.
+            seq_len (int): The sequence length.
+            chunk_size (int): The chunk size.
+            device (torch.device): The device to use.
+            dtype (torch.dtype): The data type.
+
+        Returns:
+            HPUAttentionMetadataV1: The updated attention metadata.
+        """
+
+        if (attn_metadata is None or not attn_metadata.is_prompt):
+            return attn_metadata
+
+        prefill_metadata = attn_metadata
+        shift = 0
+
+        if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:
+
+            context_lens_t = prefill_metadata.context_lens_tensor
+            assert context_lens_t is not None
+            block_list = prefill_metadata.block_list
+            max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
+            max_context_len = max_context_len * self.block_size
+            query_positions = torch.arange(seq_len, device=device)
+            total_token_positions = context_lens_t.unsqueeze(-1) + query_positions.unsqueeze(0)
+            which_chunk = (total_token_positions // chunk_size)
+            chunk_start_positions = which_chunk * chunk_size
+            invalid_lens_t = chunk_start_positions - 1
+
+            past_indices = torch.arange(max_context_len, device=device)
+            past_mask = (
+                (past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) &
+                (past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1)
+
+            causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
+            query_chunk_ids = which_chunk[0]
+            same_chunk_mask = query_chunk_ids.unsqueeze(0) == query_chunk_ids.unsqueeze(1)
+
+            causal_mask = causal_mask & same_chunk_mask
+            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)
+
+            mask = torch.concat((past_mask, causal_mask), dim=-1)
+            attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
+                                    torch.tensor(float('-inf'), dtype=dtype, device=device))
+        else:
+            tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
+            mask = torch.tril(tensor, diagonal=shift)
+            idx = torch.arange(seq_len, device=device)
+            chunk_id = idx // chunk_size
+            same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
+            same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
+            mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
+            attn_bias = torch.log(mask)
+
+        attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias)
+        return attn_metadata
+
     def _set_block_mapping(self,
                            metadata: HPUAttentionMetadataV1,
                            batch_size: int,
                            device: torch.device,
                            dtype: torch.dtype,
-                           is_window_block: bool = False) -> HPUAttentionMetadataV1:
+                           is_window_block: bool = False,
+                           update_for_chunked_attention: bool = False) -> HPUAttentionMetadataV1:
         """
         Set block mapping for decode phase.
@@ -5588,6 +5701,7 @@ def _set_block_mapping(self,
             device: Device to create tensors on
             dtype: Data type for tensors
             is_window_block: Whether this is for window blocks
+            update_for_chunked_attention: Whether to update for chunked attention

         Returns:
             Updated attention metadata with block_mapping and attn_bias set
@@ -5595,6 +5709,9 @@ def _set_block_mapping(self,
         if is_window_block:
             block_usage = metadata.window_block_usage
             block_groups = metadata.window_block_groups
+        elif update_for_chunked_attention:
+            block_usage = metadata.chunked_block_usage
+            block_groups = metadata.chunked_block_groups
         else:
             block_usage = metadata.block_usage
             block_groups = metadata.block_groups
@@ -5625,6 +5742,11 @@ def _set_block_mapping(self,
                                             "TrimmedAttentionMetadata",
                                             window_block_mapping=block_mapping,
                                             window_attn_bias=attn_bias)
+        elif update_for_chunked_attention:
+            metadata = custom_tuple_replace(metadata,
+                                            "TrimmedAttentionMetadata",
+                                            chunked_block_mapping=block_mapping,
+                                            chunked_attn_bias=attn_bias)
         else:
             metadata = custom_tuple_replace(metadata,
                                             "TrimmedAttentionMetadata",
@@ -5632,8 +5754,13 @@ def _set_block_mapping(self,
                                             attn_bias=attn_bias)
         return metadata

-    def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
-                         device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1:
+    def process_metadata(self,
+                         attn_metadata: HPUAttentionMetadataV1,
+                         batch_size: int,
+                         seq_len: int,
+                         device: torch.device,
+                         dtype: torch.dtype,
+                         model_has_chunked_attention: bool = False) -> HPUAttentionMetadataV1:
         """
         Post-process attention metadata with appropriate masks and mappings.
@@ -5647,6 +5774,7 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in
             seq_len: Sequence length (for prompt phase)
             device: Device to create tensors on
             dtype: Data type for tensors
+            model_has_chunked_attention: Whether the model has chunked attention

         Returns:
             Post-processed attention metadata with additional tensors
@@ -5656,8 +5784,18 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in
             if self.interleaved_sliding_window and self.sliding_window is not None:
                 attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len,
                                                                        self.sliding_window, device, dtype)
+            if model_has_chunked_attention:
+                attention_chunk_size = self.vllm_config.model_config.hf_config.text_config.attention_chunk_size
+                attn_metadata = self._set_attn_bias_for_chunked_attention(attn_metadata, batch_size, seq_len,
+                                                                          attention_chunk_size, device, dtype)
         else:
             attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype)
+            if model_has_chunked_attention:
+                attn_metadata = self._set_block_mapping(attn_metadata,
+                                                        batch_size,
+                                                        device,
+                                                        dtype,
+                                                        update_for_chunked_attention=True)
             if self.interleaved_sliding_window and self.sliding_window is not None:
                 attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True)
         return attn_metadata
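
For reference, a minimal standalone sketch of the same-chunk causal mask that _set_attn_bias_for_chunked_attention builds for the non-fused prefill path; the function name and values below are illustrative only and not part of the patch:

import torch

def chunked_causal_bias(seq_len: int, chunk_size: int, dtype=torch.float32) -> torch.Tensor:
    # Query position i may attend to key position j only if j <= i (causal)
    # and both fall in the same attention chunk: i // chunk_size == j // chunk_size.
    chunk_id = torch.arange(seq_len) // chunk_size
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
    allowed = causal & same_chunk
    # Additive-bias convention: 0.0 where attention is allowed, -inf elsewhere,
    # matching the log-of-0/1-mask used in the patch's else branch.
    return torch.where(allowed, torch.tensor(0.0, dtype=dtype),
                       torch.tensor(float('-inf'), dtype=dtype))

# Example: seq_len=8, chunk_size=4 -> position 5 sees only positions 4 and 5.
bias = chunked_causal_bias(8, 4)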
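
The fused-SDPA prefill branch additionally bounds which cached (past) tokens each query may see; a small illustrative calculation of that lower bound derived from context_lens_tensor, with made-up numbers:

import torch

context_len = 10   # tokens already in the KV cache for one sequence
chunk_size = 8
seq_len = 4        # new query tokens in this prefill step

q = torch.arange(seq_len)                              # query offsets 0..3
total_pos = context_len + q                            # absolute positions 10..13
chunk_start = (total_pos // chunk_size) * chunk_size   # tensor([8, 8, 8, 8])
# A cached position p is visible from query q iff chunk_start[q] <= p < context_len;
# the patch encodes the lower bound as invalid_lens_t = chunk_start_positions - 1
# together with a strict '>' comparison.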
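
On the decode side, _create_decode_input_data keeps only the block-table tail that covers the currently active attention chunk; an illustrative run of that slicing with made-up sizes (the real buffers then come from get_habana_paged_attn_buffers):

import math

block_size = 128
attention_chunk_size = 8192
chunk_size_in_blocks = attention_chunk_size // block_size        # 64 blocks per chunk

block_tables_list = [list(range(70)), list(range(200, 230))]     # two sequences

seq_lens_block = [len(bt) for bt in block_tables_list]           # [70, 30]
num_seq_chunks = [math.ceil(n / chunk_size_in_blocks) - 1 for n in seq_lens_block]  # [1, 0]
block_tables_chunk = [
    bt[num_seq_chunks[i] * chunk_size_in_blocks:]
    for i, bt in enumerate(block_tables_list)
]
# Sequence 0 keeps its last 6 blocks (indices 64..69, the still-open second chunk);
# sequence 1 keeps all 30 blocks, since it is still inside its first chunk.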