[ROCm][DeepSeek-V3.2][Perf] Enable gluon preshuffle indexer (block_size=64 + SHUFFLE layout) #41008
Changes from all commits: d979599, da48589, e49be8d, 48f4c2a
@@ -7,12 +7,55 @@
 import torch

 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import LayerNameType
 from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
 from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton

+logger = init_logger(__name__)
+
+_AITER_MQA_SMALL_HEADS_WARNED = False
+
+_cached_paged_logits: torch.Tensor | None = None
+
+# Over-allocate each logits row by this many float32 columns to absorb
+# out-of-bounds stores from the AITER preshuffle kernel's unmasked
+# buffer_store (up to ~190 elements past context_length). The downstream
+# top_k_per_row_decode op takes stride(0)/stride(1) explicitly, so the
+# widened stride is transparent. Credit: maeehart (vllm-project/vllm#40643).
+_PAGED_LOGITS_COL_PADDING = 256
+
+
+def _get_paged_logits_buffer(
+    rows: int, cols: int, device: torch.device
+) -> torch.Tensor:
+    """Return a (rows, cols) float32 view pre-filled with -inf.
+
+    Within a decode step every layer sees the same (batch*next_n,
+    actual_max_seq_len) shape, so the expensive torch.full call only
+    happens once per step (or when the shape changes). The backing
+    storage is wider by _PAGED_LOGITS_COL_PADDING columns to guard
+    against preshuffle kernel OOB writes.
+    """
+    global _cached_paged_logits
+    padded_cols = cols + _PAGED_LOGITS_COL_PADDING
+    if (
+        _cached_paged_logits is not None
+        and _cached_paged_logits.shape[0] >= rows
+        and _cached_paged_logits.shape[1] >= padded_cols
+        and _cached_paged_logits.device == device
+    ):
+        buf = _cached_paged_logits[:rows, :cols]
+        buf.fill_(float("-inf"))
+        return buf
+    _cached_paged_logits = torch.full(
+        (rows, padded_cols), float("-inf"), device=device, dtype=torch.float32
+    )
+    return _cached_paged_logits[:rows, :cols]
+
+
 if current_platform.is_cuda_alike():
     from vllm import _custom_ops as ops
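As context for the padding comment above: the widened backing storage stays invisible to any consumer that is handed explicit strides. A minimal CPU-only sketch, with made-up sizes and not part of the diff, of why a sliced view keeps row addressing intact while the extra columns absorb stray stores:

import torch

rows, cols, padding = 4, 1000, 256
backing = torch.full((rows, cols + padding), float("-inf"))
view = backing[:, :cols]  # what callers of the helper receive
assert view.shape == (rows, cols)
# stride(0) still spans the padded width, so a kernel handed
# stride(0)/stride(1) explicitly addresses rows correctly, and stores
# that run past `cols` land in the padding instead of the next row.
assert view.stride() == (cols + padding, 1)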
@@ -111,7 +154,7 @@ def indexer_k_quant_and_cache_triton(
        block_size,
        num_tokens,
        head_dim,
-        "NHD",
+        "SHUFFLE",
        block_tile_size,
        head_tile_size,
        IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
@@ -212,7 +255,7 @@ def cp_gather_indexer_k_quant_cache_triton(
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
-        "NHD",
+        "SHUFFLE",
        head_dim,
        block_tile_size,
        head_tile_size,
@@ -300,6 +343,7 @@ def rocm_fp8_paged_mqa_logits(
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
+    block_size: int = 1,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.
@@ -317,41 +361,80 @@
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.
+        block_size: KV cache block size (default 1).

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
+    global _AITER_MQA_SMALL_HEADS_WARNED
    from vllm._aiter_ops import rocm_aiter_ops

+    batch_size, next_n, heads, _ = q_fp8.shape
+
    aiter_paged_mqa_logits_module = None
-    if rocm_aiter_ops.is_enabled():
+    if rocm_aiter_ops.is_enabled() and heads >= 16:
        aiter_paged_mqa_logits_module = paged_mqa_logits_module()
+    elif rocm_aiter_ops.is_enabled() and not _AITER_MQA_SMALL_HEADS_WARNED:
+        logger.warning(
+            "AITER paged MQA logits kernel does not support %d heads "
+            "(requires >= 16). Falling back to PyTorch reference.",
+            heads,
+        )
+        _AITER_MQA_SMALL_HEADS_WARNED = True

    if aiter_paged_mqa_logits_module is not None:
-        deepgemm_fp8_paged_mqa_logits_stage1 = (
-            aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
+        _deepgemm_fp8_paged_mqa_logits = getattr(
+            aiter_paged_mqa_logits_module,
+            "deepgemm_fp8_paged_mqa_logits",
+            None,
        )
-        batch_size, next_n, heads, _ = q_fp8.shape
-        out_qk = torch.full(
-            (heads, batch_size * next_n, max_model_len),
-            float("-inf"),
-            device="cuda",
-            dtype=torch.float32,
+        use_new_api = (
+            _deepgemm_fp8_paged_mqa_logits is not None and block_size > 1
        )
-        # TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
-        # 2. Remove ChunkQ when AITER PR #2891 merged
-        deepgemm_fp8_paged_mqa_logits_stage1(
-            q_fp8,
-            kv_cache_fp8,
-            weights,
-            out_qk,
-            context_lens,
-            block_tables,
-            max_model_len,
-            ChunkQ=heads,
-        )
-        return out_qk.sum(dim=0)
+        if use_new_api:
+            out_logits = _get_paged_logits_buffer(
+                batch_size * next_n, max_model_len, q_fp8.device
+            )
+            _deepgemm_fp8_paged_mqa_logits(
+                q_fp8,
+                kv_cache_fp8,
+                weights,
+                out_logits,
+                context_lens,
+                block_tables,
+                max_model_len,
+                ChunkK=256,
+                Preshuffle=block_size == 64,
+                KVBlockSize=block_size,
+                WavePerEU=2,
+            )
+            return out_logits
+        else:
+            _stage1 = (
+                aiter_paged_mqa_logits_module
+                .deepgemm_fp8_paged_mqa_logits_stage1
+            )
+            out_qk = torch.full(
+                (heads, batch_size * next_n, max_model_len),
+                float("-inf"),
+                device="cuda",
+                dtype=torch.float32,
+            )
+            # TODO: 1. Replace _stage1 and out_qk.sum with another fused
+            # variant;
+            # 2. Remove ChunkQ when AITER PR #2891 merged
+            _stage1(
+                q_fp8,
+                kv_cache_fp8,
+                weights,
+                out_qk,
+                context_lens,
+                block_tables,
+                max_model_len,
+                ChunkQ=heads,
+            )
+            return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
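For review convenience, a condensed restatement of the dispatch added above (illustrative Python, not part of the diff): which logits path runs for a given AITER availability, head count, new-kernel availability, and KV block size.

def mqa_logits_path(
    aiter_enabled: bool, heads: int, has_new_kernel: bool, block_size: int
) -> str:
    """Mirror the kernel-selection logic introduced in this hunk."""
    if not (aiter_enabled and heads >= 16):
        return "fp8_paged_mqa_logits_torch"  # PyTorch reference fallback
    if has_new_kernel and block_size > 1:
        # gluon kernel; the preshuffle layout only applies to 64-token blocks
        return "gluon (preshuffle)" if block_size == 64 else "gluon"
    return "stage1 + out_qk.sum"  # legacy two-stage path


assert mqa_logits_path(True, 64, True, 64) == "gluon (preshuffle)"
assert mqa_logits_path(True, 64, True, 1) == "stage1 + out_qk.sum"
assert mqa_logits_path(True, 8, True, 64) == "fp8_paged_mqa_logits_torch"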
@@ -540,7 +623,7 @@ def rocm_aiter_sparse_attn_indexer(
    num_tokens = slot_mapping.shape[0]
    k = k[:num_tokens]

-    ops.indexer_k_quant_and_cache(
+    indexer_k_quant_and_cache_triton(
        k,
        kv_cache,
        slot_mapping,
@@ -598,6 +681,7 @@ def rocm_aiter_sparse_attn_indexer(
    if has_decode:
        decode_metadata = layer_attn_metadata.decode
        assert decode_metadata is not None
+        kv_block_size = kv_cache.shape[1]
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
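A quick shape illustration of the fix-up above (example sizes are made up; the indexer cache carries a single head):

import torch

kv_cache = torch.empty(10, 64, 128)  # [num_block, block_size, head_dim]
kv_block_size = kv_cache.shape[1]    # 64, later forwarded as block_size=
kv_cache = kv_cache.unsqueeze(-2)    # insert the missing n_head axis
assert kv_cache.shape == (10, 64, 1, 128)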
@@ -620,14 +704,17 @@ def rocm_aiter_sparse_attn_indexer(
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

+        actual_max_seq_len = layer_attn_metadata.max_seq_len
+
        logits = rocm_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
-            max_model_len=max_model_len,
+            max_model_len=actual_max_seq_len,
+            block_size=kv_block_size,
        )

        num_rows = logits.shape[0]
The quantization kernel is now hardcoded to use the `"SHUFFLE"` layout. However, the attention logic in `rocm_fp8_paged_mqa_logits` only enables `Preshuffle=True` when `block_size == 64` (line 408) and falls back to the stage1 kernel for `block_size == 1` (line 413). The stage1 kernel and the gluon kernel with `Preshuffle=False` expect the standard `"NHD"` layout. This inconsistency will lead to incorrect results for any `block_size` other than 64. The layout should be conditional on the block size.