diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
index 344deed660e8..f634bce0ccab 100644
--- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
+++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
@@ -5,10 +5,14 @@
 import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.cache import cp_gather_indexer_k_quant_cache
 
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -163,8 +167,52 @@ def triton(
 class GetKAndS:
     @classmethod
     def execute(cls, *args, **kwargs):
+        if _use_aiter:
+            return cls.aiter(*args, **kwargs)
         return cls.triton(*args, **kwargs)
 
+    @classmethod
+    def aiter(
+        cls,
+        pool: "NSATokenToKVPool",
+        buf: torch.Tensor,
+        page_indices: torch.Tensor,
+        seq_len_tensor: torch.Tensor,
+        seq_len_sum: int,
+        max_seq_len: int,
+    ):
+        from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
+
+        page_size = pool.page_size
+        index_head_dim = pool.index_head_dim
+        quant_block_size = pool.quant_block_size
+        scale_elems = index_head_dim // quant_block_size
+
+        kv_cache = buf.view(-1, page_size, index_head_dim + scale_elems * 4).view(
+            fp8_dtype
+        )
+        dst_k = torch.empty(
+            (seq_len_sum, index_head_dim), dtype=torch.uint8, device=buf.device
+        )
+        dst_scale = torch.empty(
+            (seq_len_sum, scale_elems * 4), dtype=torch.uint8, device=buf.device
+        )
+
+        cu_seq_lens = torch.zeros(
+            seq_len_tensor.shape[0] + 1, dtype=torch.int32, device=buf.device
+        )
+        torch.cumsum(seq_len_tensor.to(torch.int32), dim=0, out=cu_seq_lens[1:])
+
+        cp_gather_indexer_k_quant_cache(
+            kv_cache,
+            dst_k.view(fp8_dtype),
+            dst_scale,
+            page_indices.to(torch.int32),
+            cu_seq_lens,
+            preshuffle=True,
+        )
+        return dst_k, dst_scale
+
     @classmethod
     def triton(
         cls,
@@ -364,15 +412,14 @@ def _set_k_and_s_triton(
         raise ValueError(
             f"index_k_scale must be 1D or 2D, got shape {index_k_scale.shape}"
         )
-    if _is_hip:
-        assert buf_numel_per_page == 1 * (128 + 4)
-    else:
-        assert buf_numel_per_page == 64 * (128 + 4)
+    assert buf_numel_per_page == page_size * (128 + 4)
     assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
     assert index_head_dim == 128
     assert scale_dim == 1
     if _is_hip:
-        assert page_size == 1
+        assert (
+            page_size % 16 == 0
+        ), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
     else:
         assert page_size == 64
 
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 6e92533b744e..9969e4a48ed9 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -430,12 +430,13 @@ def _get_topk_paged(
         page_size = forward_batch.token_to_kv_pool.page_size
         # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
         if _is_hip:
-            assert page_size == 1, "only support page size 1"
-            block_tables = metadata.get_page_table_1()
+            assert (
+                page_size % 16 == 0
+            ), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
         else:
             assert page_size == 64, "only support page size 64"
-            # NOTE(dark): this support extend/decode/decode+graph
-            block_tables = metadata.get_page_table_64()
+        # NOTE(dark): this support extend/decode/decode+graph
+        block_tables = metadata.get_page_table_64()
 
         max_seq_len = block_tables.shape[1] * page_size
         kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
@@ -462,17 +463,12 @@
         assert len(q_fp8.shape) == 3
         q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
         assert len(kv_cache_fp8.shape) == 2
-        block_kv = 1 if _is_hip else 64
+        block_kv = page_size
         num_heads_kv = 1
         head_dim_with_sf = 132
-        if _is_hip:
-            kv_cache_fp8 = kv_cache_fp8.view(
-                -1, block_kv, num_heads_kv, head_dim_with_sf
-            )
-        else:
-            kv_cache_fp8 = kv_cache_fp8.view(
-                kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
-            )
+        kv_cache_fp8 = kv_cache_fp8.view(
+            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
+        )
 
         assert len(weights.shape) == 3
         weights = weights.squeeze(2)
@@ -483,9 +479,8 @@
             from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits
 
             batch_size, next_n, heads, _ = q_fp8.shape
-            logits = torch.full(
+            logits = torch.empty(
                 (batch_size * next_n, max_seq_len),
-                float("-inf"),
                 device=q_fp8.device,
                 dtype=torch.float32,
             )
@@ -497,7 +492,7 @@
                 seqlens_32,
                 block_tables,
                 max_seq_len,
-                Preshuffle=False,
+                Preshuffle=True,
                 KVBlockSize=block_kv,
             )
         else:
@@ -561,7 +556,9 @@ def _get_topk_ragged(
         page_size = forward_batch.token_to_kv_pool.page_size
 
         if _is_hip:
-            assert page_size == 1, "only support page size 1"
+            assert (
+                page_size % 16 == 0
+            ), f"HIP preshuffle requires page_size to be a multiple of 16, got {page_size}"
         else:
             assert page_size == 64, "only support page size 64"
 
@@ -572,10 +569,7 @@
         )
         weights = weights.squeeze(-1)
 
-        if _is_hip:
-            block_tables = metadata.get_page_table_1()
-        else:
-            block_tables = metadata.get_page_table_64()
+        block_tables = metadata.get_page_table_64()
 
         assert (
             forward_batch.seq_lens_cpu is not None
@@ -1031,19 +1025,24 @@
             )
             return
 
-        # Fast path: AITER fused quant + cache store (HIP, page_size=1)
+        # Fast path: AITER fused quant + cache store (HIP, preshuffle)
         if _use_aiter:
+            page_size = forward_batch.token_to_kv_pool.page_size
             buf = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
                 layer_id=layer_id
             )
-            # Reshape from (num_pages, 132) uint8 to (num_pages, 1, 132) fp8
-            # to match kernel's (num_blocks, block_size, head_dim + scale_bytes) layout
-            kv_cache = buf.unsqueeze(1).view(fp8_dtype)
+            # Reshape from (num_pages, page_size*(128+4)) uint8 to (num_pages, page_size, 132) fp8
+            kv_cache = buf.view(-1, page_size, 132).view(fp8_dtype)
             out_loc = forward_batch.out_cache_loc
             if not out_loc.is_contiguous():
                 out_loc = out_loc.contiguous()
             indexer_k_quant_and_cache(
-                key, kv_cache, out_loc, self.block_size, self.scale_fmt
+                key,
+                kv_cache,
+                out_loc,
+                self.block_size,
+                self.scale_fmt,
+                preshuffle=True,
             )
             return
 
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index a0e0765f2da7..8ec89cf8c20f 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -1908,7 +1908,9 @@ def __init__(
         assert index_head_dim == 128
 
         if _is_hip:
-            assert self.page_size == 1
+            assert (
+                self.page_size % 16 == 0
+            ), f"HIP preshuffle requires page_size to be a multiple of 16, got {self.page_size}"
         else:
             assert self.page_size == 64
         with (
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 71afe2192d7a..9e316efcce01 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1741,15 +1741,8 @@ def _handle_model_specific_adjustments(self):
                     f"attn_tp_size={self.tp_size}, attention weights will be sharded across {self.tp_size} ranks."
                 )
 
-            if is_hip():
-                self.page_size = 1
-                logger.warning(
-                    "Setting page size to 1 for DeepSeek DSA on ROCm."
-                )
-            else:
-                # For CUDA GPU
-                self.page_size = 64
-                logger.warning("Setting page size to 64 for DeepSeek DSA.")
+            self.page_size = 64
+            logger.warning("Setting page size to 64 for DeepSeek DSA.")
 
         import torch
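
Reviewer note (not part of the patch): below is a minimal standalone Python sketch of the index-K cache page layout that the new `page_size * (128 + 4)` asserts guard, and of the `cu_seq_lens` construction used by `GetKAndS.aiter`. The concrete sizes here are illustrative assumptions; only the 132-bytes-per-token layout, the reshape, and the cumsum pattern come from the diff itself.

import torch

# Assumed constants, taken from the asserts in the patch:
# index_head_dim == 128 fp8 bytes per token, plus 4 scale bytes
# (one fp32 scale per 128-element quant block) = 132 bytes per token.
num_pages = 8
page_size = 64                  # 64 on CUDA; any multiple of 16 on HIP per the new asserts
bytes_per_token = 128 + 4

# Flat per-layer buffer, shaped like the output of get_index_k_with_scale_buffer().
buf = torch.zeros(num_pages, page_size * bytes_per_token, dtype=torch.uint8)

# The patch's reshape: (num_pages, page_size*(128+4)) -> (num_pages, page_size, 132).
# The patch then reinterprets the bytes via .view(fp8_dtype), omitted here.
kv_cache = buf.view(-1, page_size, bytes_per_token)
assert kv_cache.shape == (num_pages, page_size, 132)

# cu_seq_lens as built in GetKAndS.aiter(): a zero-prefixed cumulative sum,
# giving each sequence's start offset into the gathered dst_k/dst_scale rows.
seq_len_tensor = torch.tensor([3, 5, 2])
cu_seq_lens = torch.zeros(seq_len_tensor.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_len_tensor.to(torch.int32), dim=0, out=cu_seq_lens[1:])
assert cu_seq_lens.tolist() == [0, 3, 8, 10]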