Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ class Envs:
SGLANG_OPT_USE_COMPRESSOR_V2 = EnvBool(True)
SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False)
SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
SGLANG_OPT_FLASHMLA_SPARSE_PREFILL = EnvBool(False)

# SWA radix cache
SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True)
Expand Down
136 changes: 134 additions & 2 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@
create_paged_compressor_data,
)

from sglang.srt.layers.attention.dsv4.dequant_k_cache import (
dequantize_k_cache_paged,
)
from sglang.srt.layers.attention.dsv4.indexer import C4IndexerBackendMixin
from sglang.srt.layers.attention.dsv4.metadata import (
_LARGE_INDEXER_QUERY_THRESHOLD,
PagedIndexerMetadata,
copy_metadata,
maybe_copy_inplace,
Expand All @@ -47,6 +51,9 @@
from sglang.srt.layers.attention.dsv4.quant_k_cache import (
quant_to_nope_fp8_rope_bf16_pack_triton,
)
from sglang.srt.layers.attention.dsv4.sparse_prefill_utils import (
SparsePrefillChunkCache,
)
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
Expand Down Expand Up @@ -109,6 +116,7 @@ class DSV4AttnMetadata:
c4_topk_lengths_clamp1: Optional[torch.Tensor] = None
c4_sparse_topk_lengths: torch.Tensor = field(init=False)
c4_sparse_page_indices: torch.Tensor = field(init=False)
c4_sparse_raw_indices: Optional[torch.Tensor] = field(init=False, default=None)

c128_out_loc: Optional[torch.Tensor] = None
c128_page_indices: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -240,7 +248,7 @@ def apply_cp_reindex(self) -> None:
f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)"
)

def init_flashmla_related(self):
def init_flashmla_related(self, is_prefill: bool = False):
# c4_sparse_topk is set from model_config.index_topk per-model
# (small model: 512, large model: 1024).
assert self.c4_sparse_topk in (512, 1024), (
Expand All @@ -258,6 +266,8 @@ def init_flashmla_related(self):
device=self.c4_topk_lengths_clamp1.device,
)
self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices)
if is_prefill:
self.c4_sparse_raw_indices = torch.empty_like(self.c4_sparse_page_indices)
self.c1_flashmla_metadata = _create_flashmla_metadata()
self.c4_flashmla_metadata = _create_flashmla_metadata()
self.c128_flashmla_metadata = _create_flashmla_metadata()
Expand All @@ -271,6 +281,11 @@ class DSV4Metadata:
c4_compress_metadata: Optional[FusedCompressMetadata] = None
c128_compress_metadata: Optional[FusedCompressMetadata] = None

# Lazily populated on the first call to ``_forward_prefill_sparse`` and
# reused across every layer in the chunk. Reset to ``None`` on copy_ so
# cuda-graph replay rebuilds it for the next forward.
sparse_prefill_cache: Optional[SparsePrefillChunkCache] = None

@property
def core_metadata(self) -> DSV4AttnMetadata:
return self.core_attn_metadata
Expand All @@ -282,6 +297,7 @@ def copy_(self, other: DSV4Metadata):
maybe_copy_inplace(
self.c128_compress_metadata, src=other.c128_compress_metadata
)
self.sparse_prefill_cache = None


@dataclass
Expand Down Expand Up @@ -1031,6 +1047,20 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"

if forward_batch.forward_mode.is_extend_without_speculative() and (
q.shape[0] > _LARGE_INDEXER_QUERY_THRESHOLD
or envs.SGLANG_OPT_FLASHMLA_SPARSE_PREFILL.get()
):
return self._forward_prefill_sparse(
q=q,
layer_id=layer_id,
compress_ratio=compress_ratio,
forward_batch=forward_batch,
token_to_kv_pool=token_to_kv_pool,
core_attn_metadata=core_attn_metadata,
attn_sink=attn_sink,
)

import flash_mla

o = flash_mla.flash_mla_with_kvcache(
Expand All @@ -1055,6 +1085,107 @@ def forward(

raise NotImplementedError("ragged attention")

def _forward_prefill_sparse(
self,
q: torch.Tensor,
layer_id: int,
compress_ratio: Literal[0, 4, 128],
forward_batch: ForwardBatch,
token_to_kv_pool: DeepSeekV4TokenToKVPool,
core_attn_metadata: DSV4AttnMetadata,
attn_sink: torch.Tensor,
) -> torch.Tensor:
"""Unified prefill via flash_mla_sparse_fwd. Replaces the
flash_mla_with_kvcache call on the extend path. Per request,
positionally gathers the SWA window (always) and the compressed
cache (c4/c128) into a flat bf16 workspace, then lets
flash_mla_sparse_fwd consume the workspace via per-query rebased
indices. Chunk-invariant scaffolding lives in
``self.forward_metadata.sparse_prefill_cache``.
"""
from flash_mla import flash_mla_sparse_fwd

# q is (b, 1, h_q, d_qk); flash_mla_sparse_fwd takes (s_q, h_q, d_qk).
q_flat = q.squeeze(1)

cache = self.forward_metadata.sparse_prefill_cache
if cache is None:
# ``swa_window_size`` on the pool is its storage page size, not
# the model's SWA window — pass both explicitly.
cache = SparsePrefillChunkCache.build(
seq_lens=forward_batch.seq_lens.to(torch.int32),
extend_seq_lens=forward_batch.extend_seq_lens.to(torch.int32),
req_pool_indices=forward_batch.req_pool_indices.to(torch.int32),
req_to_token=self.req_to_token,
full_to_swa=token_to_kv_pool.full_to_swa_index_mapping,
swa_window_size=SWA_WINDOW,
swa_page_size=token_to_kv_pool.swa_window_size,
num_qo_tokens=q_flat.shape[0],
)
self.forward_metadata.sparse_prefill_cache = cache

# Resolve the workspace + indices for this ratio, then dequant
# SWA + compressed regions directly into the workspace (no torch.cat).
compressed_slice = None
extra_k_cache = None
extra_page_size = None
flat_token_ids = None
if compress_ratio == 0:
workspace = cache.c0_workspace
combined_indices = cache.c0_combined_indices
combined_lens = cache.c0_combined_lens
swa_slice = workspace
else:
extra_page_size = token_to_kv_pool.get_extra_key_page_size(layer_id)
extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id)
if compress_ratio == 128:
assert core_attn_metadata.c128_page_indices is not None
cache.ensure_c128(core_attn_metadata.c128_page_indices)
flat_token_ids = cache.c128_flat_token_ids
workspace = cache.c128_workspace
combined_indices = cache.c128_combined_indices
combined_lens = cache.c128_combined_lens
else:
assert core_attn_metadata.c4_sparse_raw_indices is not None, (
"sparse-prefill c4 path requires c4_sparse_raw_indices "
"(allocated in init_flashmla_related when is_prefill=True)"
)
cache.ensure_c4(core_attn_metadata.page_table, extra_page_size)
flat_token_ids = cache.c4_flat_token_ids
workspace = cache.c4_workspace
combined_indices, combined_lens = cache.combine_c4_layer(
c4_sparse_raw_indices=core_attn_metadata.c4_sparse_raw_indices,
)
n_compressed = flat_token_ids.shape[0]
compressed_slice = workspace[:n_compressed]
swa_slice = workspace[n_compressed:]

if compressed_slice is not None:
dequantize_k_cache_paged(
extra_k_cache,
flat_token_ids,
page_size=extra_page_size,
out=compressed_slice,
)
Comment on lines +1163 to +1169
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this dequant all the c4 cache, or only the selected c4 cache?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like all the c4 cache. Since selected c4 cache changes between different layers, but the sparse prefill cache here is only computed once at the first layer (so the value of flat_token_ids doesn't change). Need @zcnrex to confirm this

dequantize_k_cache_paged(
token_to_kv_pool.get_swa_key_buffer_radix(layer_id),
cache.swa_token_ids,
page_size=cache.swa_page_size,
out=swa_slice,
)
kv = workspace

o, _, _ = flash_mla_sparse_fwd(
q=q_flat,
kv=kv,
indices=combined_indices.unsqueeze(1),
sm_scale=self.softmax_scale,
d_v=self.head_dim_v,
attn_sink=attn_sink,
topk_length=combined_lens,
)
return o

def expand_prefill_casually(
self,
num_tokens: int,
Expand Down Expand Up @@ -1150,10 +1281,11 @@ def make_core_attn_metadata(

if need_compress:
core_attn_metadata.init_compression_metadata()
core_attn_metadata.init_flashmla_related()
core_attn_metadata.init_flashmla_related(is_prefill=is_prefill)
else:
core_attn_metadata.c4_sparse_topk_lengths = None
core_attn_metadata.c4_sparse_page_indices = None
core_attn_metadata.c4_sparse_raw_indices = None
core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata()
core_attn_metadata.c4_flashmla_metadata = None
core_attn_metadata.c128_flashmla_metadata = None
Expand Down
136 changes: 136 additions & 0 deletions python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Optional
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a torch ref impl here, and add an output comparison test between ref & applied kerenl (maybe under __main__)


import torch
import triton
import triton.language as tl

from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz

fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn

# v4 KV cache layout (see dsv4.index_buf_accessor._set_k_and_s_triton_kernel):
# per-token: 448 fp8 nope + 64 bf16 rope (= 576 contiguous bytes) +
# 7 ue8m0 scales padded to 8 bytes.
# per-page: [token 0..P-1 nope+rope (P*576 bytes)] [token 0..P-1 scale (P*8 bytes)]
# padded up to a multiple of 576.
DIM_NOPE = 448
DIM_ROPE = 64
TILE_SIZE = 64 # one nope scale tile = 64 fp8 values
NUM_SCALE_TILES = DIM_NOPE // TILE_SIZE # 7
NOPE_ROPE_BYTES = DIM_NOPE + DIM_ROPE * 2 # 576
PADDED_SCALE_PER_TOKEN = NUM_SCALE_TILES + 1 # 8


def dequantize_k_cache_paged(
quant_k_cache: torch.Tensor,
page_table_1_flattened: torch.Tensor,
page_size: int,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Dequantize the DeepSeek v4 paged KV cache for a list of token IDs.

Args:
quant_k_cache: (num_pages, bytes_per_page_padded) uint8.
page_table_1_flattened: (num_tokens,) int — token IDs into the cache.
page_size: number of tokens per page.
out: optional (num_tokens, 1, DIM_NOPE + DIM_ROPE) bf16 destination.
May be a slice of a larger workspace; the kernel uses out.stride(0)
so contiguous-along-dim-0 slices work.

Returns:
(num_tokens, 1, DIM_NOPE + DIM_ROPE) bfloat16.
"""
assert quant_k_cache.is_contiguous()
assert page_table_1_flattened.dtype in (torch.int32, torch.int64)

# The buffer's dtype is whatever the pool exposes (often bf16); the
# underlying storage is uint8. Reinterpret to byte-space first.
quant_k_cache_u8 = quant_k_cache.view(torch.uint8)
num_tokens = page_table_1_flattened.shape[0]
bytes_per_page = quant_k_cache_u8.shape[-1]
s_offset_bytes = page_size * NOPE_ROPE_BYTES

# Three typed views over the same underlying bytes.
buf_fp8 = quant_k_cache_u8.view(fp8_dtype).reshape(-1)
buf_bf16 = quant_k_cache_u8.view(torch.bfloat16).reshape(-1)
buf_uint8 = quant_k_cache_u8.reshape(-1)

if out is None:
out = torch.empty(
(num_tokens, 1, DIM_NOPE + DIM_ROPE),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
else:
assert out.shape == (num_tokens, 1, DIM_NOPE + DIM_ROPE)
assert out.dtype == torch.bfloat16

_dequantize_k_cache_paged_kernel[(num_tokens,)](
out,
buf_fp8,
buf_bf16,
buf_uint8,
page_table_1_flattened,
out.stride(0),
BYTES_PER_PAGE=bytes_per_page,
PAGE_SIZE=page_size,
DIM_NOPE=DIM_NOPE,
DIM_ROPE=DIM_ROPE,
TILE_SIZE=TILE_SIZE,
NUM_SCALE_TILES=NUM_SCALE_TILES,
NOPE_ROPE_BYTES=NOPE_ROPE_BYTES,
PADDED_SCALE_PER_TOKEN=PADDED_SCALE_PER_TOKEN,
S_OFFSET_BYTES=s_offset_bytes,
)
return out


@triton.jit
def _dequantize_k_cache_paged_kernel(
output_ptr,
buf_fp8_ptr,
buf_bf16_ptr,
buf_uint8_ptr,
page_table_ptr,
output_stride_0,
BYTES_PER_PAGE: tl.constexpr,
PAGE_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
TILE_SIZE: tl.constexpr,
NUM_SCALE_TILES: tl.constexpr,
NOPE_ROPE_BYTES: tl.constexpr,
PADDED_SCALE_PER_TOKEN: tl.constexpr,
S_OFFSET_BYTES: tl.constexpr,
):
# One program per token: load page_table[token_id] once and emit all
# NUM_SCALE_TILES nope tiles + rope tail via tl.static_range.
token_id = tl.program_id(0)
loc = tl.load(page_table_ptr + token_id).to(tl.int64)
page_idx = loc // PAGE_SIZE
in_page = loc % PAGE_SIZE
page_byte_base = page_idx * BYTES_PER_PAGE
token_data_base = page_byte_base + in_page * NOPE_ROPE_BYTES
token_scale_base = (
page_byte_base + S_OFFSET_BYTES + in_page * PADDED_SCALE_PER_TOKEN
)
out_row_base = token_id * output_stride_0

nope_offs = tl.arange(0, TILE_SIZE)
for tile_id in tl.static_range(NUM_SCALE_TILES):
fp8_off = token_data_base + tile_id * TILE_SIZE + nope_offs
fp8_vals = tl.load(buf_fp8_ptr + fp8_off).to(tl.float32)

scale_u8 = tl.load(buf_uint8_ptr + token_scale_base + tile_id).to(tl.int32)
scale_pow2 = tl.exp2((scale_u8 - 127).to(tl.float32))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will tl.int32 and tl.float32 cause overflow here?


out_off = out_row_base + tile_id * TILE_SIZE + nope_offs
tl.store(
output_ptr + out_off,
(fp8_vals * scale_pow2).to(output_ptr.dtype.element_ty),
)

rope_offs = tl.arange(0, DIM_ROPE)
bf16_off = (token_data_base + DIM_NOPE) // 2 + rope_offs
rope_data = tl.load(buf_bf16_ptr + bf16_off)
tl.store(output_ptr + out_row_base + DIM_NOPE + rope_offs, rope_data)
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/attention/dsv4/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ def forward_c4_indexer(
raw_indices = hisparse_coordinator.raw_indices_buffer[
: core_metadata.c4_sparse_page_indices.size(0)
]
elif core_metadata.c4_sparse_raw_indices is not None:
raw_indices = core_metadata.c4_sparse_raw_indices

if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get():
topk_transform_512_pytorch_vectorized(
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/layers/attention/dsv4/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
c4_sparse: means "compressed by 4" but only attend to top-512 tokens.
all related length will be clipped to 512.
"""
_LARGE_INDEXER_QUERY_THRESHOLD = 11673
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why hardcode to this value. Can we avoid hardcoding



def copy_metadata(
Expand Down Expand Up @@ -108,7 +109,11 @@ def __post_init__(self):
else:
import deep_gemm

if envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get():
use_jit_indexer = (
envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get()
or self.c4_seq_lens.numel() > _LARGE_INDEXER_QUERY_THRESHOLD
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it numel here? I think the numel of c4_seq_lens should be a small value like batch size?

)
if use_jit_indexer:
from sglang.jit_kernel.deepseek_v4 import get_paged_mqa_logits_metadata
else:
from deep_gemm import get_paged_mqa_logits_metadata
Expand Down
Loading
Loading