diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 96e38f3bc8..2049e36cda 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -320,9 +320,12 @@ def _naive_prompt_attention(query: torch.Tensor, htcore.mark_step() attn_weights.add_(position_bias) if attn_bias is not None: - if attn_weights.dtype != attn_bias.dtype: - attn_bias = attn_bias.to(dtype=attn_weights.dtype) - attn_weights.add_(attn_bias) + if attn_bias.dtype == torch.bool: + attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) + else: + if attn_weights.dtype != attn_bias.dtype: + attn_bias = attn_bias.to(dtype=attn_weights.dtype) + attn_weights.add_(attn_bias) if get_config().fp32_softmax: attn_weights = torch.softmax(attn_weights, dim=-1) else: diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ebbe26c781..4dbef5be8a 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1925,7 +1925,6 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul return all_batch_contents, num_pad_across_dp def _make_attn_bias(self, context_groups, token_groups): - dtype = self.dtype is_causal = True # TODO: add support for non-causal tasks context_groups = torch.tensor(context_groups, device='cpu', dtype=torch.int16) context_groups = context_groups.repeat_interleave(self.block_size, dim=-1) @@ -1938,7 +1937,7 @@ def _make_attn_bias(self, context_groups, token_groups): causal_mask = torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool) causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0) attn_mask[:, :, context_len:].logical_or_(causal_mask) - attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf) + attn_mask = ~attn_mask return attn_mask.unflatten(0, (1, -1)) @@ -6114,7 +6113,7 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = ~mask attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) return attn_metadata @@ -6172,16 +6171,13 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) # causal_mask = causal_mask.logical_and(len_mask) - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.concat((past_mask, causal_mask), dim=-1) else: # CAUSAL MASK without removing padding (CAUSAL+sliding window) # removing padding cause accuracy issue for images input - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) mask = torch.tril(tensor, diagonal=shift) - mask = torch.triu(mask, diagonal=shift - window_size + 1) - attn_bias = torch.log(mask) + attn_bias = torch.triu(mask, diagonal=shift - window_size + 1) attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata @@ -6232,18 +6228,15 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada causal_mask = causal_mask & same_chunk_mask causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.concat((past_mask, causal_mask), dim=-1) else: - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) mask = torch.tril(tensor, diagonal=shift) idx = torch.arange(seq_len, device=device) chunk_id = idx // chunk_size same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) - mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) - attn_bias = torch.log(mask) + attn_bias = same_chunk & mask attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) return attn_metadata @@ -6282,8 +6275,7 @@ def _set_block_mapping(self, block_groups = metadata.block_groups mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = mask < block_usage.unsqueeze(-1) if not is_fake_hpu(): block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)