diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 238d8d176e30..ba86de9b8978 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -643,6 +643,7 @@ class Envs: SGLANG_OPT_USE_COMPRESSOR_V2 = EnvBool(True) SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) + SGLANG_OPT_FLASHMLA_SPARSE_PREFILL = EnvBool(False) # SWA radix cache SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index f9f396428557..8a9c06b0a101 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -35,8 +35,12 @@ create_paged_compressor_data, ) +from sglang.srt.layers.attention.dsv4.dequant_k_cache import ( + dequantize_k_cache_paged, +) from sglang.srt.layers.attention.dsv4.indexer import C4IndexerBackendMixin from sglang.srt.layers.attention.dsv4.metadata import ( + _LARGE_INDEXER_QUERY_THRESHOLD, PagedIndexerMetadata, copy_metadata, maybe_copy_inplace, @@ -47,6 +51,9 @@ from sglang.srt.layers.attention.dsv4.quant_k_cache import ( quant_to_nope_fp8_rope_bf16_pack_triton, ) +from sglang.srt.layers.attention.dsv4.sparse_prefill_utils import ( + SparsePrefillChunkCache, +) from sglang.srt.layers.dp_attention import ( get_attention_cp_rank, get_attention_cp_size, @@ -109,6 +116,7 @@ class DSV4AttnMetadata: c4_topk_lengths_clamp1: Optional[torch.Tensor] = None c4_sparse_topk_lengths: torch.Tensor = field(init=False) c4_sparse_page_indices: torch.Tensor = field(init=False) + c4_sparse_raw_indices: Optional[torch.Tensor] = field(init=False, default=None) c128_out_loc: Optional[torch.Tensor] = None c128_page_indices: Optional[torch.Tensor] = None @@ -240,7 +248,7 @@ def apply_cp_reindex(self) -> None: f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)" ) - def 
init_flashmla_related(self): + def init_flashmla_related(self, is_prefill: bool = False): # c4_sparse_topk is set from model_config.index_topk per-model # (small model: 512, large model: 1024). assert self.c4_sparse_topk in (512, 1024), ( @@ -258,6 +266,8 @@ def init_flashmla_related(self): device=self.c4_topk_lengths_clamp1.device, ) self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices) + if is_prefill: + self.c4_sparse_raw_indices = torch.empty_like(self.c4_sparse_page_indices) self.c1_flashmla_metadata = _create_flashmla_metadata() self.c4_flashmla_metadata = _create_flashmla_metadata() self.c128_flashmla_metadata = _create_flashmla_metadata() @@ -271,6 +281,11 @@ class DSV4Metadata: c4_compress_metadata: Optional[FusedCompressMetadata] = None c128_compress_metadata: Optional[FusedCompressMetadata] = None + # Lazily populated on the first call to ``_forward_prefill_sparse`` and + # reused across every layer in the chunk. Reset to ``None`` on copy_ so + # cuda-graph replay rebuilds it for the next forward. 
+ sparse_prefill_cache: Optional[SparsePrefillChunkCache] = None + @property def core_metadata(self) -> DSV4AttnMetadata: return self.core_attn_metadata @@ -282,6 +297,7 @@ def copy_(self, other: DSV4Metadata): maybe_copy_inplace( self.c128_compress_metadata, src=other.c128_compress_metadata ) + self.sparse_prefill_cache = None @dataclass @@ -1031,6 +1047,20 @@ def forward( extra_indices.shape[-1] % 64 == 0 ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" + if forward_batch.forward_mode.is_extend_without_speculative() and ( + q.shape[0] > _LARGE_INDEXER_QUERY_THRESHOLD + or envs.SGLANG_OPT_FLASHMLA_SPARSE_PREFILL.get() + ): + return self._forward_prefill_sparse( + q=q, + layer_id=layer_id, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + token_to_kv_pool=token_to_kv_pool, + core_attn_metadata=core_attn_metadata, + attn_sink=attn_sink, + ) + import flash_mla o = flash_mla.flash_mla_with_kvcache( @@ -1055,6 +1085,107 @@ def forward( raise NotImplementedError("ragged attention") + def _forward_prefill_sparse( + self, + q: torch.Tensor, + layer_id: int, + compress_ratio: Literal[0, 4, 128], + forward_batch: ForwardBatch, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + core_attn_metadata: DSV4AttnMetadata, + attn_sink: torch.Tensor, + ) -> torch.Tensor: + """Unified prefill via flash_mla_sparse_fwd. Replaces the + flash_mla_with_kvcache call on the extend path. Per request, + positionally gathers the SWA window (always) and the compressed + cache (c4/c128) into a flat bf16 workspace, then lets + flash_mla_sparse_fwd consume the workspace via per-query rebased + indices. Chunk-invariant scaffolding lives in + ``self.forward_metadata.sparse_prefill_cache``. + """ + from flash_mla import flash_mla_sparse_fwd + + # q is (b, 1, h_q, d_qk); flash_mla_sparse_fwd takes (s_q, h_q, d_qk). 
+ q_flat = q.squeeze(1) + + cache = self.forward_metadata.sparse_prefill_cache + if cache is None: + # ``swa_window_size`` on the pool is its storage page size, not + # the model's SWA window — pass both explicitly. + cache = SparsePrefillChunkCache.build( + seq_lens=forward_batch.seq_lens.to(torch.int32), + extend_seq_lens=forward_batch.extend_seq_lens.to(torch.int32), + req_pool_indices=forward_batch.req_pool_indices.to(torch.int32), + req_to_token=self.req_to_token, + full_to_swa=token_to_kv_pool.full_to_swa_index_mapping, + swa_window_size=SWA_WINDOW, + swa_page_size=token_to_kv_pool.swa_window_size, + num_qo_tokens=q_flat.shape[0], + ) + self.forward_metadata.sparse_prefill_cache = cache + + # Resolve the workspace + indices for this ratio, then dequant + # SWA + compressed regions directly into the workspace (no torch.cat). + compressed_slice = None + extra_k_cache = None + extra_page_size = None + flat_token_ids = None + if compress_ratio == 0: + workspace = cache.c0_workspace + combined_indices = cache.c0_combined_indices + combined_lens = cache.c0_combined_lens + swa_slice = workspace + else: + extra_page_size = token_to_kv_pool.get_extra_key_page_size(layer_id) + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + if compress_ratio == 128: + assert core_attn_metadata.c128_page_indices is not None + cache.ensure_c128(core_attn_metadata.c128_page_indices) + flat_token_ids = cache.c128_flat_token_ids + workspace = cache.c128_workspace + combined_indices = cache.c128_combined_indices + combined_lens = cache.c128_combined_lens + else: + assert core_attn_metadata.c4_sparse_raw_indices is not None, ( + "sparse-prefill c4 path requires c4_sparse_raw_indices " + "(allocated in init_flashmla_related when is_prefill=True)" + ) + cache.ensure_c4(core_attn_metadata.page_table, extra_page_size) + flat_token_ids = cache.c4_flat_token_ids + workspace = cache.c4_workspace + combined_indices, combined_lens = cache.combine_c4_layer( + 
c4_sparse_raw_indices=core_attn_metadata.c4_sparse_raw_indices, + ) + n_compressed = flat_token_ids.shape[0] + compressed_slice = workspace[:n_compressed] + swa_slice = workspace[n_compressed:] + + if compressed_slice is not None: + dequantize_k_cache_paged( + extra_k_cache, + flat_token_ids, + page_size=extra_page_size, + out=compressed_slice, + ) + dequantize_k_cache_paged( + token_to_kv_pool.get_swa_key_buffer_radix(layer_id), + cache.swa_token_ids, + page_size=cache.swa_page_size, + out=swa_slice, + ) + kv = workspace + + o, _, _ = flash_mla_sparse_fwd( + q=q_flat, + kv=kv, + indices=combined_indices.unsqueeze(1), + sm_scale=self.softmax_scale, + d_v=self.head_dim_v, + attn_sink=attn_sink, + topk_length=combined_lens, + ) + return o + def expand_prefill_casually( self, num_tokens: int, @@ -1150,10 +1281,11 @@ def make_core_attn_metadata( if need_compress: core_attn_metadata.init_compression_metadata() - core_attn_metadata.init_flashmla_related() + core_attn_metadata.init_flashmla_related(is_prefill=is_prefill) else: core_attn_metadata.c4_sparse_topk_lengths = None core_attn_metadata.c4_sparse_page_indices = None + core_attn_metadata.c4_sparse_raw_indices = None core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata() core_attn_metadata.c4_flashmla_metadata = None core_attn_metadata.c128_flashmla_metadata = None diff --git a/python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py b/python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py new file mode 100644 index 000000000000..254fd9ed9698 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py @@ -0,0 +1,136 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + +# v4 KV cache layout (see dsv4.index_buf_accessor._set_k_and_s_triton_kernel): +# per-token: 448 fp8 nope + 64 bf16 rope (= 
576 contiguous bytes) + +# 7 ue8m0 scales padded to 8 bytes. +# per-page: [token 0..P-1 nope+rope (P*576 bytes)] [token 0..P-1 scale (P*8 bytes)] +# padded up to a multiple of 576. +DIM_NOPE = 448 +DIM_ROPE = 64 +TILE_SIZE = 64 # one nope scale tile = 64 fp8 values +NUM_SCALE_TILES = DIM_NOPE // TILE_SIZE # 7 +NOPE_ROPE_BYTES = DIM_NOPE + DIM_ROPE * 2 # 576 +PADDED_SCALE_PER_TOKEN = NUM_SCALE_TILES + 1 # 8 + + +def dequantize_k_cache_paged( + quant_k_cache: torch.Tensor, + page_table_1_flattened: torch.Tensor, + page_size: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Dequantize the DeepSeek v4 paged KV cache for a list of token IDs. + + Args: + quant_k_cache: (num_pages, ...) contiguous buffer of any dtype; reinterpreted as uint8 bytes. + page_table_1_flattened: (num_tokens,) int — token IDs into the cache. + page_size: number of tokens per page. + out: optional (num_tokens, 1, DIM_NOPE + DIM_ROPE) bf16 destination. + May be a slice of a larger workspace; the kernel uses out.stride(0) + so contiguous-along-dim-0 slices work. + + Returns: + (num_tokens, 1, DIM_NOPE + DIM_ROPE) bfloat16. + """ + assert quant_k_cache.is_contiguous() + assert page_table_1_flattened.dtype in (torch.int32, torch.int64) + + # The buffer's dtype is whatever the pool exposes (often bf16); the + # underlying storage is uint8. Reinterpret to byte-space first. + quant_k_cache_u8 = quant_k_cache.view(torch.uint8) + num_tokens = page_table_1_flattened.shape[0] + bytes_per_page = quant_k_cache_u8.shape[-1] + s_offset_bytes = page_size * NOPE_ROPE_BYTES + + # Three typed views over the same underlying bytes. 
+ buf_fp8 = quant_k_cache_u8.view(fp8_dtype).reshape(-1) + buf_bf16 = quant_k_cache_u8.view(torch.bfloat16).reshape(-1) + buf_uint8 = quant_k_cache_u8.reshape(-1) + + if out is None: + out = torch.empty( + (num_tokens, 1, DIM_NOPE + DIM_ROPE), + dtype=torch.bfloat16, + device=quant_k_cache.device, + ) + else: + assert out.shape == (num_tokens, 1, DIM_NOPE + DIM_ROPE) + assert out.dtype == torch.bfloat16 + + _dequantize_k_cache_paged_kernel[(num_tokens,)]( + out, + buf_fp8, + buf_bf16, + buf_uint8, + page_table_1_flattened, + out.stride(0), + BYTES_PER_PAGE=bytes_per_page, + PAGE_SIZE=page_size, + DIM_NOPE=DIM_NOPE, + DIM_ROPE=DIM_ROPE, + TILE_SIZE=TILE_SIZE, + NUM_SCALE_TILES=NUM_SCALE_TILES, + NOPE_ROPE_BYTES=NOPE_ROPE_BYTES, + PADDED_SCALE_PER_TOKEN=PADDED_SCALE_PER_TOKEN, + S_OFFSET_BYTES=s_offset_bytes, + ) + return out + + +@triton.jit +def _dequantize_k_cache_paged_kernel( + output_ptr, + buf_fp8_ptr, + buf_bf16_ptr, + buf_uint8_ptr, + page_table_ptr, + output_stride_0, + BYTES_PER_PAGE: tl.constexpr, + PAGE_SIZE: tl.constexpr, + DIM_NOPE: tl.constexpr, + DIM_ROPE: tl.constexpr, + TILE_SIZE: tl.constexpr, + NUM_SCALE_TILES: tl.constexpr, + NOPE_ROPE_BYTES: tl.constexpr, + PADDED_SCALE_PER_TOKEN: tl.constexpr, + S_OFFSET_BYTES: tl.constexpr, +): + # One program per token: load page_table[token_id] once and emit all + # NUM_SCALE_TILES nope tiles + rope tail via tl.static_range. 
+ token_id = tl.program_id(0) + loc = tl.load(page_table_ptr + token_id).to(tl.int64) + page_idx = loc // PAGE_SIZE + in_page = loc % PAGE_SIZE + page_byte_base = page_idx * BYTES_PER_PAGE + token_data_base = page_byte_base + in_page * NOPE_ROPE_BYTES + token_scale_base = ( + page_byte_base + S_OFFSET_BYTES + in_page * PADDED_SCALE_PER_TOKEN + ) + out_row_base = token_id * output_stride_0 + + nope_offs = tl.arange(0, TILE_SIZE) + for tile_id in tl.static_range(NUM_SCALE_TILES): + fp8_off = token_data_base + tile_id * TILE_SIZE + nope_offs + fp8_vals = tl.load(buf_fp8_ptr + fp8_off).to(tl.float32) + + scale_u8 = tl.load(buf_uint8_ptr + token_scale_base + tile_id).to(tl.int32) + scale_pow2 = tl.exp2((scale_u8 - 127).to(tl.float32)) + + out_off = out_row_base + tile_id * TILE_SIZE + nope_offs + tl.store( + output_ptr + out_off, + (fp8_vals * scale_pow2).to(output_ptr.dtype.element_ty), + ) + + rope_offs = tl.arange(0, DIM_ROPE) + bf16_off = (token_data_base + DIM_NOPE) // 2 + rope_offs + rope_data = tl.load(buf_bf16_ptr + bf16_off) + tl.store(output_ptr + out_row_base + DIM_NOPE + rope_offs, rope_data) diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index facf5ef3da40..4fa00d8c599c 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -410,6 +410,8 @@ def forward_c4_indexer( raw_indices = hisparse_coordinator.raw_indices_buffer[ : core_metadata.c4_sparse_page_indices.size(0) ] + elif core_metadata.c4_sparse_raw_indices is not None: + raw_indices = core_metadata.c4_sparse_raw_indices if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get(): topk_transform_512_pytorch_vectorized( diff --git a/python/sglang/srt/layers/attention/dsv4/metadata.py b/python/sglang/srt/layers/attention/dsv4/metadata.py index 7995dbd959cb..7c1fe82c7441 100644 --- a/python/sglang/srt/layers/attention/dsv4/metadata.py +++ 
b/python/sglang/srt/layers/attention/dsv4/metadata.py @@ -48,6 +48,7 @@ c4_sparse: means "compressed by 4" but only attend to top-512 tokens. all related length will be clipped to 512. """ +_LARGE_INDEXER_QUERY_THRESHOLD = 11673 def copy_metadata( @@ -108,7 +109,11 @@ def __post_init__(self): else: import deep_gemm - if envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get(): + use_jit_indexer = ( + envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get() + or self.c4_seq_lens.numel() > _LARGE_INDEXER_QUERY_THRESHOLD + ) + if use_jit_indexer: from sglang.jit_kernel.deepseek_v4 import get_paged_mqa_logits_metadata else: from deep_gemm import get_paged_mqa_logits_metadata diff --git a/python/sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py b/python/sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py new file mode 100644 index 000000000000..114459ff3064 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py @@ -0,0 +1,584 @@ +"""Per-query sparse-index combiner for the FlashMLA sparse prefill path. + +Adapts vllm's ``combine_topk_swa_indices`` (vllm/v1/attention/ops/ +deepseek_v4_ops/cache_utils.py) to sglang's flat-workspace layout. For each +query token in a prefill chunk, emits one row of combined indices into the +chunk's bf16 KV workspace: + + [ topk indices into compressed cache (rebased) ] + [ swa positional indices (rebased) ] + [ -1 padding up to a multiple of 128 ] + +The workspace is a single flat ``(total_workspace_tokens, 1, 512)`` tensor +formed by concatenating every request's compressed-region gather (in request +order) followed by every request's SWA-region gather. Two per-request +offset tensors describe the layout: + + * ``compressed_base[r]`` — flat index where request r's compressed + region begins. Topk indices ``topk_indices[token, j]`` are local to + request r's compressed region (in ``[0, compressed_gather_len[r])``) + and get rebased to flat space by adding ``compressed_base[r]``. 
+ * ``swa_base[r]`` — flat index where request r's SWA region begins. + Per-query SWA indices are computed positionally as + ``swa_base[r] + (pos - swa_len + 1 - gather_start) + j``. + +This is the natural layout for ``flash_mla_sparse_fwd``'s ``kv: (s_kv, 1, +d_qk)`` argument, where ``s_kv`` is the total flat workspace length. + +For SWA-only layers callers pass ``topk=0``, ``compressed_base = 0`` (the +compressed branch becomes a no-op) and any ``compress_ratio >= 1``. +""" + +from dataclasses import dataclass, field +from typing import Optional + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.dsv4.dequant_k_cache import DIM_NOPE, DIM_ROPE +from sglang.srt.utils import ceil_align + +# FlashMLA sparse prefill asserts ``params.topk % B_TOPK == 0``. B_TOPK is 64 +# for the h_q=64 kernel and 128 for h_q=128; pad to 128 to satisfy both. +SPARSE_PREFILL_TOPK_ALIGNMENT = 128 +# Bf16 workspace per-token width, matching ``dequantize_k_cache_paged``'s +# output: 448 fp8 nope (dequanted) + 64 bf16 rope = 512. +WORKSPACE_DIM = DIM_NOPE + DIM_ROPE + + +def combined_topk_width(topk: int, window_size: int) -> int: + """Width of the padded combined_indices last dim that + ``combine_topk_swa_indices`` would produce for these args.""" + return ceil_align(topk + window_size, SPARSE_PREFILL_TOPK_ALIGNMENT) + + +def combine_topk_swa_indices( + topk_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + compressed_base: torch.Tensor, + swa_base: torch.Tensor, + window_size: int, + compress_ratio: int, + topk: int, + out_indices: Optional[torch.Tensor] = None, + out_lens: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Combine topk + SWA indices into a single ``flash_mla_sparse_fwd`` row. + + Args: + topk_indices: (num_tokens, K) int32. Per-query indices into the + compressed-cache region, **already in request-local space** — + i.e. 
in ``[0, compressed_gather_len[r])`` for the request that + owns each token. Pad entries can be any value; they are ignored + beyond ``topk_len``. + query_start_loc: (num_reqs+1,) int32. Cumulative query lengths; may + be in global (cross-chunk) space — kernel rebases by subtracting + ``query_start_loc[0]``. + seq_lens: (num_reqs,) int32. Each request's full sequence length. + gather_lens: (num_reqs,) int32. Trailing tokens dequanted into the + SWA region for that request. + compressed_base: (num_reqs,) int32. Flat workspace offset where + request r's compressed region begins. Pass all-zeros (or any + value) for SWA-only layers since topk=0 disables this branch. + swa_base: (num_reqs,) int32. Flat workspace offset where request + r's SWA region begins. + window_size: SWA window size. + compress_ratio: must be ``>= 1`` even when topk==0. + topk: configured topk; pass 0 for SWA-only layers. + out_indices: optional preallocated ``(num_tokens, combined_topk)`` + int32 buffer. If provided, the kernel writes the per-query prefix + ``[0, topk_len + swa_len)``; positions beyond are not touched. + Caller must pre-fill with ``-1`` sentinels (and the chunk-invariant + valid-prefix length must hold across reuses). + out_lens: optional preallocated ``(num_tokens,)`` int32 buffer; the + kernel fully overwrites it, so any dtype-correct buffer works. + + Returns: + combined_indices: (num_tokens, padded_topk_swa) int32, padded to a + multiple of 128 with -1 sentinels. + combined_lens: (num_tokens,) int32, valid prefix length per token. 
+ """ + assert topk_indices.dtype == torch.int32 + assert query_start_loc.dtype == torch.int32 + assert seq_lens.dtype == torch.int32 + assert gather_lens.dtype == torch.int32 + assert compressed_base.dtype == torch.int32 + assert swa_base.dtype == torch.int32 + assert compress_ratio >= 1, "COMPRESS_RATIO must be >= 1 (use TOP_K=0 for SWA-only)" + + num_tokens = topk_indices.shape[0] + num_reqs = seq_lens.shape[0] + combined_topk = combined_topk_width(topk, window_size) + if out_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + -1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert out_indices.shape == (num_tokens, combined_topk) + assert out_indices.dtype == torch.int32 + combined_indices = out_indices + if out_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert out_lens.shape == (num_tokens,) + assert out_lens.dtype == torch.int32 + combined_lens = out_lens + + NUM_WORKERS = 128 + _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( + combined_indices, + combined_indices.stride(0), + combined_lens, + topk_indices, + topk_indices.stride(0), + query_start_loc, + seq_lens, + gather_lens, + compressed_base, + swa_base, + TOP_K=topk, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=window_size, + PADDED_TOP_K=triton.next_power_of_2(topk_indices.shape[-1]), + ) + return combined_indices, combined_lens + + +def build_swa_token_ids( + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + req_to_token: torch.Tensor, + full_to_swa: torch.Tensor, + swa_window: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build a flat list of physical SWA-cache token IDs covering each + request's positional union of every query's SWA window. + + Per request, the union spans seq positions + ``[max(0, seq_len - extend - W + 1), seq_len)``, of length + ``min(seq_len, extend + W - 1)``. 
Each position is translated through + ``req_to_token`` (full kv-cache id) and then ``full_to_swa`` (SWA + cache id) to land in the SWA-cache token-id space that + ``dequantize_k_cache_paged`` consumes. + + Args: + seq_lens: (num_reqs,) int32, per-request total sequence length. + extend_seq_lens: (num_reqs,) int32, per-request query length. + req_pool_indices: (num_reqs,) int32, per-request row in + ``req_to_token``. + req_to_token: (num_reqs_max, max_seq_len) int32. Full kv-cache id + per (request, seq position). + full_to_swa: (full_pool_size + extra,) int64. Maps full kv id to + SWA-cache id. + swa_window: int. SWA window size. + + Returns: + swa_token_ids: (total_swa,) int32, flat physical SWA-cache token IDs. + swa_first_pos: (num_reqs,) int32, first seq position covered per req. + swa_gather_lens: (num_reqs,) int32, gather length per request. + swa_offsets: (num_reqs+1,) int32, exclusive cumsum of swa_gather_lens. + """ + assert seq_lens.dtype == torch.int32 + assert extend_seq_lens.dtype == torch.int32 + assert req_pool_indices.dtype == torch.int32 + assert req_to_token.dtype == torch.int32 + assert full_to_swa.dtype == torch.int64 + + num_reqs = seq_lens.shape[0] + device = seq_lens.device + + swa_gather_lens = torch.minimum(seq_lens, extend_seq_lens + (swa_window - 1)).to( + torch.int32 + ) + swa_first_pos = (seq_lens - swa_gather_lens).to(torch.int32) + swa_offsets = torch.zeros(num_reqs + 1, dtype=torch.int32, device=device) + swa_offsets[1:] = torch.cumsum(swa_gather_lens, dim=0).to(torch.int32) + total_swa = int(swa_offsets[-1].item()) # one CPU sync per chunk + + swa_token_ids = torch.empty(total_swa, dtype=torch.int32, device=device) + if total_swa == 0: + return swa_token_ids, swa_first_pos, swa_gather_lens, swa_offsets + + NUM_WORKERS = 128 + _build_swa_token_ids_kernel[(num_reqs, NUM_WORKERS)]( + swa_token_ids, + swa_first_pos, + swa_gather_lens, + swa_offsets, + req_pool_indices, + req_to_token, + req_to_token.stride(0), + full_to_swa, + ) + 
return swa_token_ids, swa_first_pos, swa_gather_lens, swa_offsets + + +@triton.jit +def _build_swa_token_ids_kernel( + out_ptr, + swa_first_pos_ptr, + swa_gather_lens_ptr, + swa_offsets_ptr, + req_pool_indices_ptr, + req_to_token_ptr, + req_to_token_stride, + full_to_swa_ptr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + first_pos = tl.load(swa_first_pos_ptr + batch_idx) + gather_len = tl.load(swa_gather_lens_ptr + batch_idx) + out_off = tl.load(swa_offsets_ptr + batch_idx) + req_pool_idx = tl.load(req_pool_indices_ptr + batch_idx).to(tl.int64) + + for i in range(worker_id, gather_len, num_workers): + pos = first_pos + i + full_id = tl.load( + req_to_token_ptr + req_pool_idx * req_to_token_stride + pos + ).to(tl.int64) + swa_id = tl.load(full_to_swa_ptr + full_id).to(tl.int32) + tl.store(out_ptr + out_off + i, swa_id) + + +@triton.jit +def _combine_topk_swa_indices_kernel( + combined_indices_ptr, + combined_indices_stride, + combined_lens_ptr, + topk_indices_ptr, + topk_indices_stride, + query_start_loc_ptr, + seq_lens_ptr, + gather_lens_ptr, + compressed_base_ptr, + swa_base_ptr, + TOP_K: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + PADDED_TOP_K: tl.constexpr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + # query_start_loc may be a global tensor; rebase to chunk-local offsets + # by subtracting the chunk's starting value. 
+ base = tl.load(query_start_loc_ptr) + query_start = tl.load(query_start_loc_ptr + batch_idx) - base + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - base + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + batch_idx) + gather_len = tl.load(gather_lens_ptr + batch_idx) + compressed_base = tl.load(compressed_base_ptr + batch_idx) + swa_base = tl.load(swa_base_ptr + batch_idx) + start_pos = seq_len - query_len + # SWA portion of the gathered buffer starts from position + # (seq_len - gather_len), not 0. The +pos-gather_start formula maps a + # query's window back into the workspace's SWA region. + gather_start = seq_len - gather_len + + for token_idx in range(query_start + worker_id, query_end, num_workers): + token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + # Both the C4 indexer and the C128 metadata builder emit + # min((pos+1)//compress_ratio, topk_tokens) valid entries. Caller + # passes TOP_K=0 for SWA-only layers to zero this out. + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + + offset = tl.arange(0, PADDED_TOP_K) + mask = offset < topk_len + topk_vals = tl.load( + topk_indices_ptr + token_idx * topk_indices_stride + offset, + mask=mask, + ) + tl.store( + combined_indices_ptr + token_idx * combined_indices_stride + offset, + topk_vals + compressed_base, + mask=mask, + ) + + offset = tl.arange(0, WINDOW_SIZE) + # Workspace SWA index: swa_base[r] + (gather_offset_in_buffer). + # For positions [pos - swa_len + 1, pos], the buffer offsets are + # [pos - swa_len + 1 - gather_start, pos - gather_start]. 
+ tl.store( + combined_indices_ptr + + token_idx * combined_indices_stride + + topk_len + + offset, + swa_base + offset + pos - swa_len + 1 - gather_start, + mask=offset < swa_len, + ) + + tl.store(combined_lens_ptr + token_idx, topk_len + swa_len) + + +@dataclass +class SparsePrefillChunkCache: + """Chunk-invariant scaffolding for ``_forward_prefill_sparse``. + + The fields here depend only on the prefill chunk (forward_batch, + req_to_token, full_to_swa_index_mapping, and the c4/c128 page tables) + and not on the per-layer k_cache. Reused across every layer in the + chunk to avoid rebuilding tiny tensors 61 times per forward pass. + """ + + # Geometry computed once per chunk. + num_reqs: int + num_qo_tokens: int + # Model's SWA window — the per-query attention range. Used by + # combine_topk_swa_indices' WINDOW_SIZE and by build_swa_token_ids's + # gather_lens. Must match SWA_WINDOW from the backend (e.g. 128), NOT + # the SWA pool's storage page size (often 256). + swa_window_size: int + # SWA cache pool's storage page size — used as the dequant kernel's + # ``page_size`` so that ``slot // page_size`` recovers the right page. + swa_page_size: int + seq_lens: torch.Tensor # (num_reqs,) int32 + query_start_loc: torch.Tensor # (num_reqs+1,) int32 + + # SWA-side (every layer needs these, all chunk-invariant). + swa_token_ids: torch.Tensor # (total_swa,) int32 + swa_first_pos: torch.Tensor # (num_reqs,) int32 + swa_gather_lens: torch.Tensor # (num_reqs,) int32 + swa_offsets: torch.Tensor # (num_reqs+1,) int32 + + # c0 pre-computed combine output (entire input set is chunk-invariant). + c0_combined_indices: torch.Tensor = field(default=None) + c0_combined_lens: torch.Tensor = field(default=None) + # Preallocated workspace reused across layers — avoids per-layer + # ``torch.cat`` and bf16 allocations. Shape (total_swa, 1, 512) bf16 for + # c0, (total_compressed + total_swa, 1, 512) for c4/c128. Dequant kernels + # write directly via ``out=workspace[slice]``. 
+    c0_workspace: torch.Tensor = field(default=None)
+
+    # c128: positional layout of the c128 cache + pre-computed combine.
+    # All four fields stay None until ``ensure_c128`` populates them.
+    c128_flat_token_ids: Optional[torch.Tensor] = None  # (num_reqs * c128_max,) int32
+    c128_combined_indices: Optional[torch.Tensor] = None
+    c128_combined_lens: Optional[torch.Tensor] = None
+    c128_workspace: Optional[torch.Tensor] = None
+
+    # c4: positional layout of the c4 cache (combine output is per-layer).
+    # ``c4_flat_token_ids`` .. ``c4_workspace`` are populated by ``ensure_c4``;
+    # ``c4_combined_indices`` / ``c4_combined_lens`` are allocated lazily by
+    # ``combine_c4_layer``.
+    c4_flat_token_ids: Optional[torch.Tensor] = None  # (num_reqs * c4_max,) int32
+    c4_page_size: Optional[int] = None
+    c4_compressed_base: Optional[torch.Tensor] = None  # (num_reqs,) int32
+    c4_swa_base: Optional[torch.Tensor] = None  # (num_reqs,) int32
+    c4_workspace: Optional[torch.Tensor] = None
+    # Tail stays at the -1 sentinel because the valid prefix length is
+    # chunk-invariant per request — subsequent layers only overwrite that prefix.
+    c4_combined_indices: Optional[torch.Tensor] = None
+    c4_combined_lens: Optional[torch.Tensor] = None
+
+    @classmethod
+    def build(
+        cls,
+        seq_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        req_to_token: torch.Tensor,
+        full_to_swa: torch.Tensor,
+        swa_window_size: int,
+        swa_page_size: int,
+        num_qo_tokens: int,
+    ) -> "SparsePrefillChunkCache":
+        """Create the per-chunk cache and pre-compute everything c0 needs.
+
+        Derives ``query_start_loc`` (a leading zero followed by the running
+        sum of ``extend_seq_lens``, stored as int32), obtains the SWA gather
+        layout from ``build_swa_token_ids`` and runs
+        ``combine_topk_swa_indices`` once with ``topk=0`` /
+        ``compress_ratio=1`` to fill ``c0_combined_indices`` and
+        ``c0_combined_lens``. The c128 and c4 fields are not passed to the
+        constructor here, so they keep their ``None`` defaults until
+        ``ensure_c128`` / ``ensure_c4`` run.
+
+        Args:
+            seq_lens: Per-request tensor; its first dimension defines
+                ``num_reqs`` and its device is used for every allocation.
+            extend_seq_lens: Per-request tensor whose cumulative sum becomes
+                ``query_start_loc[1:]``.
+            req_pool_indices: Forwarded unchanged to ``build_swa_token_ids``.
+            req_to_token: Forwarded unchanged to ``build_swa_token_ids``.
+            full_to_swa: Forwarded unchanged to ``build_swa_token_ids``.
+            swa_window_size: Passed to ``build_swa_token_ids`` as
+                ``swa_window`` and to the combine call as ``window_size``.
+            swa_page_size: Only stored on the returned cache by this method.
+            num_qo_tokens: Row count of the placeholder top-k tensor; also
+                stored on the returned cache.
+
+        Returns:
+            A new cache whose ``c0_workspace`` is an uninitialized
+            (``torch.empty``) bfloat16 tensor of shape
+            ``(swa_token_ids.shape[0], 1, WORKSPACE_DIM)``.
+        """
+        device = seq_lens.device
+        num_reqs = seq_lens.shape[0]
+
+        # Exclusive prefix sum: entry r is the first query row of request r,
+        # entry r + 1 is one past its last query row.
+        query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32, device=device)
+        query_start_loc[1:] = torch.cumsum(extend_seq_lens, dim=0).to(torch.int32)
+
+        swa_token_ids, swa_first_pos, swa_gather_lens, swa_offsets = (
+            build_swa_token_ids(
+                seq_lens=seq_lens,
+                extend_seq_lens=extend_seq_lens,
+                req_pool_indices=req_pool_indices,
+                req_to_token=req_to_token,
+                full_to_swa=full_to_swa,
+                swa_window=swa_window_size,
+            )
+        )
+
+        cache = cls(
+            num_reqs=num_reqs,
+            num_qo_tokens=num_qo_tokens,
+            swa_window_size=swa_window_size,
+            swa_page_size=swa_page_size,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            swa_token_ids=swa_token_ids,
+            swa_first_pos=swa_first_pos,
+            swa_gather_lens=swa_gather_lens,
+            swa_offsets=swa_offsets,
+        )
+
+        # Pre-compute the c0 combine output: TOPK=0, compressed_base=0,
+        # swa_base = swa_offsets[:-1]. All inputs are chunk-invariant.
+        # NOTE(review): ``zero_topk`` is a (num_qo_tokens, 1) placeholder;
+        # presumably the kernel never reads it when topk=0 — confirm.
+        zero_topk = torch.zeros((num_qo_tokens, 1), dtype=torch.int32, device=device)
+        zero_compressed_base = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+        c0_swa_base = swa_offsets[:-1].to(torch.int32)
+        cache.c0_combined_indices, cache.c0_combined_lens = combine_topk_swa_indices(
+            topk_indices=zero_topk,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            gather_lens=swa_gather_lens,
+            compressed_base=zero_compressed_base,
+            swa_base=c0_swa_base,
+            window_size=swa_window_size,
+            compress_ratio=1,
+            topk=0,
+        )
+        # One workspace row per gathered SWA token; contents are uninitialized.
+        cache.c0_workspace = torch.empty(
+            (swa_token_ids.shape[0], 1, WORKSPACE_DIM),
+            dtype=torch.bfloat16,
+            device=device,
+        )
+        return cache
+
+    def ensure_c128(self, c128_page_indices: torch.Tensor) -> None:
+        """Populate c128-side fields from per-query c128 page indices.
+
+        ``c128_page_indices[q, j]`` carries slot ids derived from
+        ``page_table[q]`` (request-keyed; same across queries of a request)
+        but masked per-token by ``j < seq_lens_casual[q] // 128`` — entries
+        beyond that are -1. We need a row whose mask covers every j the
+        combine kernel might reference, i.e. up to ``seq_lens[r] // 128``;
+        that's the *last* query's mask. Pulling the first query in a fresh
+        prefill (``seq_lens_casual = 1``) yields an all-`-1` row that
+        clamp_min(0) collapses to slot 0, sending dequant to a polluted
+        slot and producing garbage c128 entries.
+
+        The method is a no-op once ``c128_flat_token_ids`` has been set, so
+        only the first call per cache instance does any work.
+
+        Args:
+            c128_page_indices: Indexed by query row; its last dimension is
+                taken as ``c128_max``.
+        """
+        if self.c128_flat_token_ids is not None:
+            return
+        device = self.seq_lens.device
+        c128_max = c128_page_indices.shape[-1]
+        # Row index of the last query of each request.
+        # NOTE(review): assumes every request contributes at least one query
+        # token; with a zero-length request this index would point at the
+        # previous request's row (or wrap to -1 for request 0) — TODO confirm.
+        last_q_per_req = (self.query_start_loc[1:] - 1).long()
+        per_req_c128 = c128_page_indices[last_q_per_req]
+        # Clamp -1 -> 0 so dequant doesn't OOB; combine masks the invalid
+        # tail via topk_len.
+        flat_c128_ids = per_req_c128.reshape(-1).clamp_min(0).to(torch.int32)
+        # Request r owns the slice [r * c128_max, (r + 1) * c128_max) of the
+        # flattened compressed region.
+        compressed_base = (
+            torch.arange(self.num_reqs, dtype=torch.int32, device=device) * c128_max
+        ).to(torch.int32)
+        total_compressed = self.num_reqs * c128_max
+        # Pre-compute the c128 combine output. topk_indices[q, j] = j is the
+        # arange-broadcast pattern; we materialize it once here so the
+        # combine kernel can read it like any other topk tensor.
+        topk_indices = (
+            torch.arange(c128_max, dtype=torch.int32, device=device)[None, :]
+            .expand(self.num_qo_tokens, -1)
+            .contiguous()
+        )
+        # SWA offsets are shifted past the whole compressed region, matching
+        # the workspace allocation below (compressed rows, then SWA rows).
+        swa_base = (total_compressed + self.swa_offsets[:-1]).to(torch.int32)
+        combined_indices, combined_lens = combine_topk_swa_indices(
+            topk_indices=topk_indices,
+            query_start_loc=self.query_start_loc,
+            seq_lens=self.seq_lens,
+            gather_lens=self.swa_gather_lens,
+            compressed_base=compressed_base,
+            swa_base=swa_base,
+            window_size=self.swa_window_size,
+            compress_ratio=128,
+            topk=c128_max,
+        )
+
+        self.c128_flat_token_ids = flat_c128_ids
+        self.c128_combined_indices = combined_indices
+        self.c128_combined_lens = combined_lens
+        self.c128_workspace = torch.empty(
+            (total_compressed + self.swa_token_ids.shape[0], 1, WORKSPACE_DIM),
+            dtype=torch.bfloat16,
+            device=device,
+        )
+
+    def ensure_c4(
+        self,
+        page_table: torch.Tensor,
+        c4_page_size: int,
+    ) -> None:
+        """Populate c4-side fields from the per-query page table.
+
+        ``page_table`` is (num_qo_tokens, max_blocks); rows within a request
+        are duplicates. The combine output is per-layer (depends on the
+        layer's remapped topk_indices), so we only cache the gather-side
+        scaffolding plus compressed/swa bases.
+
+        The method is a no-op once ``c4_flat_token_ids`` has been set.
+
+        NOTE(review): the "remapped topk_indices" wording above disagrees
+        with ``combine_c4_layer``, whose docstring says no remap is needed —
+        confirm which description is current.
+
+        Args:
+            page_table: Indexed by query row; its last dimension is taken as
+                ``max_blocks``.
+            c4_page_size: Number of c4 slots per page; stored on the cache
+                and used to expand page ids into slot ids.
+        """
+        if self.c4_flat_token_ids is not None:
+            return
+        device = self.seq_lens.device
+        max_blocks = page_table.shape[-1]
+        c4_max = max_blocks * c4_page_size
+        # The first query row of each request is used as that request's
+        # page-table row.
+        first_q_per_req = self.query_start_loc[:-1].long()
+        per_req_page_table = page_table[first_q_per_req]
+
+        # For every positional slot k in [0, c4_max): which page it lives in
+        # and its offset inside that page.
+        k_arange = torch.arange(c4_max, dtype=torch.int32, device=device)
+        block_idx = (k_arange // c4_page_size).long()
+        in_page = (k_arange % c4_page_size).to(torch.int32)
+        # slot id = page id * page size + offset within the page.
+        c4_token_ids_2d = (
+            per_req_page_table.index_select(1, block_idx) * c4_page_size + in_page
+        ).to(torch.int32)
+        # Negative slot ids are clamped to 0 (presumably they come from -1
+        # padding entries in the page table — confirm against the caller).
+        flat_c4_ids = c4_token_ids_2d.reshape(-1).clamp_min(0)
+        total_compressed = self.num_reqs * c4_max
+        compressed_base = (
+            torch.arange(self.num_reqs, dtype=torch.int32, device=device) * c4_max
+        ).to(torch.int32)
+        # SWA offsets are shifted past the whole compressed region, matching
+        # the workspace allocation below (compressed rows, then SWA rows).
+        swa_base = (total_compressed + self.swa_offsets[:-1]).to(torch.int32)
+
+        self.c4_flat_token_ids = flat_c4_ids
+        self.c4_page_size = c4_page_size
+        self.c4_compressed_base = compressed_base
+        self.c4_swa_base = swa_base
+        self.c4_workspace = torch.empty(
+            (total_compressed + self.swa_token_ids.shape[0], 1, WORKSPACE_DIM),
+            dtype=torch.bfloat16,
+            device=device,
+        )
+
+    def combine_c4_layer(
+        self,
+        c4_sparse_raw_indices: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Per-layer combine for c4. ``c4_sparse_raw_indices`` is the topk
+        kernel's positional output (``block_in_seq * c_page_size + in_page``)
+        — already in the request-local workspace coordinate that
+        ``combine_topk_swa_indices`` expects, so no remap is needed.
+
+        Reuses preallocated ``c4_combined_indices`` / ``c4_combined_lens``
+        buffers across layers — the kernel only overwrites the valid prefix.
+
+        ``c4_compressed_base`` and ``c4_swa_base`` are passed straight to the
+        kernel; they are ``None`` until ``ensure_c4`` has run, so callers
+        must invoke ``ensure_c4`` first.
+
+        Args:
+            c4_sparse_raw_indices: Per-query top-k indices; its last
+                dimension is taken as ``topk``.
+
+        Returns:
+            Whatever ``combine_topk_swa_indices`` returns (annotated as a
+            pair of tensors; presumably the ``out_indices`` / ``out_lens``
+            buffers passed in — confirm against the kernel wrapper).
+        """
+        topk = c4_sparse_raw_indices.shape[-1]
+        if self.c4_combined_indices is None:
+            # First layer only: allocate the shared output buffers.
+            # NOTE(review): the width is fixed by this first call's ``topk``;
+            # assumes every later layer uses the same topk — confirm.
+            device = self.seq_lens.device
+            self.c4_combined_indices = torch.full(
+                (self.num_qo_tokens, combined_topk_width(topk, self.swa_window_size)),
+                -1,
+                dtype=torch.int32,
+                device=device,
+            )
+            # Uninitialized; presumably the kernel writes every entry — confirm.
+            self.c4_combined_lens = torch.empty(
+                self.num_qo_tokens, dtype=torch.int32, device=device
+            )
+        return combine_topk_swa_indices(
+            topk_indices=c4_sparse_raw_indices,
+            query_start_loc=self.query_start_loc,
+            seq_lens=self.seq_lens,
+            gather_lens=self.swa_gather_lens,
+            compressed_base=self.c4_compressed_base,
+            swa_base=self.c4_swa_base,
+            window_size=self.swa_window_size,
+            compress_ratio=4,
+            topk=topk,
+            out_indices=self.c4_combined_indices,
+            out_lens=self.c4_combined_lens,
+        )
diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py
index ead96a4a6c21..89021672ebfc 100644
--- a/python/sglang/srt/models/deepseek_v4.py
+++ b/python/sglang/srt/models/deepseek_v4.py
@@ -1171,7 +1171,7 @@ def forward(
         metadata = forward_batch.attn_backend.forward_metadata
         core_meta = metadata.core_attn_metadata
         core_meta.apply_cp_reindex()
-        core_meta.init_flashmla_related()
+        core_meta.init_flashmla_related(is_prefill=True)
         if metadata.indexer_metadata is not None:
             metadata.indexer_metadata = (
                 forward_batch.attn_backend.init_forward_metadata_indexer(
diff --git a/python/sglang/srt/models/deepseek_v4_nextn.py b/python/sglang/srt/models/deepseek_v4_nextn.py
index bd116c29a8e2..148074ec3927 100644
--- a/python/sglang/srt/models/deepseek_v4_nextn.py
+++ b/python/sglang/srt/models/deepseek_v4_nextn.py
@@ -248,7 +248,7 @@ def forward(
         metadata = forward_batch.attn_backend.forward_metadata
         core_meta = metadata.core_attn_metadata
         core_meta.apply_cp_reindex()
-        core_meta.init_flashmla_related()
+        core_meta.init_flashmla_related(is_prefill=True)
         if metadata.indexer_metadata is not None:
             metadata.indexer_metadata = (
                 forward_batch.attn_backend.init_forward_metadata_indexer(