-
-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[XPU] fp8_mqa_logits and fp8_paged_mqa_logits torch fallbacks for XPU #39156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ff71e9c
ad6f2ef
5b70ae7
7aa3017
f7ffcfd
815d781
08ec43a
8b480ab
7515ad5
741ff70
f508b76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -429,6 +429,92 @@ def cp_gather_indexer_k_quant_cache( | |
| ) | ||
| dst_scale[:] = kv_cache_flat[scale_indices] | ||
|
|
||
| @staticmethod | ||
| def fp8_mqa_logits_torch( | ||
| q: torch.Tensor, | ||
| kv: tuple[torch.Tensor, torch.Tensor], | ||
| weights: torch.Tensor, | ||
| cu_seqlen_ks: torch.Tensor, | ||
| cu_seqlen_ke: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| kv_fp8, scale = kv | ||
| seq_len_kv = kv_fp8.shape[0] | ||
| k = kv_fp8.to(torch.bfloat16) | ||
| q = q.to(torch.bfloat16) | ||
|
|
||
| mask_lo = ( | ||
| torch.arange(0, seq_len_kv, device=q.device)[None, :] | ||
| >= cu_seqlen_ks[:, None] | ||
| ) | ||
| mask_hi = ( | ||
| torch.arange(0, seq_len_kv, device=q.device)[None, :] | ||
| < cu_seqlen_ke[:, None] | ||
| ) | ||
| mask = mask_lo & mask_hi | ||
|
|
||
| score = torch.einsum("mhd,nd->hmn", q, k).float() * scale | ||
| logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) | ||
| logits = logits.masked_fill(~mask, float("-inf")) | ||
| return logits | ||
|
|
||
@staticmethod
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Pure-torch fallback for the fp8 paged MQA logits kernel.

    Args:
        q: ``[batch_size, next_n, num_heads, dim]`` query tensor.
        kv_cache: ``[num_blocks, block_size, 1, dim + 4]`` paged cache; the
            first ``dim`` bytes of the last axis hold fp8 key values and the
            trailing 4 bytes hold a float32 per-token scale (layout implied
            by the ``view(torch.float)`` reinterpretation below).
        weights: ``[batch_size * next_n, num_heads]`` head-combination weights.
        context_lens: ``[batch_size]`` valid cache length per request.
        block_tables: per-request logical-to-physical block id map.
        max_model_len: column width of the returned logits tensor.

    Returns:
        ``[batch_size * next_n, max_model_len]`` float32 logits, ``-inf``
        outside each request's causal window.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Keep the cache quantized here; each block is dequantized lazily inside
    # the loop so peak memory is bounded by one block rather than a float32
    # copy of the entire paged cache.
    kv_cache_fp8, kv_scale_bytes = kv_cache[..., :dim], kv_cache[..., dim:]
    q = q.float()

    _, block_size, _, _ = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    for i in range(batch_size):
        context_len = context_lens[i].item()
        # Positions of this request's next_n query tokens within the context.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_idx in range(cdiv(context_len, block_size)):
            block_id = block_tables[i][block_idx]
            # Dequantize only this referenced block: reinterpret the trailing
            # 4 bytes per token as the float32 scale, then scale the fp8 keys.
            block_scale = kv_scale_bytes[block_id].contiguous().view(torch.float)
            kx = kv_cache_fp8[block_id].view(fp8_dtype).float() * block_scale
            qx = q[i]
            k_offsets = torch.arange(
                block_idx * block_size,
                (block_idx + 1) * block_size,
                device=q.device,
            )
            # A key is visible only if it lies inside the context and does not
            # come after the query position (causal).
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            # relu turns the masked -inf entries into 0 before the
            # head-weighted sum; the final torch.where restores -inf outside
            # the causal window.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_idx * block_size : (block_idx + 1) * block_size,
            ] = torch.where(
                k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")
            )
    return logits
|
|
||
| @staticmethod | ||
| def top_k_per_row_prefill( | ||
| logits: torch.Tensor, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of `torch.einsum` to compute an `[H, M, N]` tensor poses a critical risk of out-of-memory (OOM) errors for large sequence lengths. For instance, with 128 heads and a 1GB logits budget, this intermediate tensor would require 128GB of memory. It is highly recommended to iterate over the heads and accumulate the results to keep memory usage bounded by `[M, N]`.