41 commits
ca14421
init framework
hzh0425 Nov 20, 2025
e02374e
[Sparse]: support sparse io schedule
huangtingwei9988 Nov 21, 2025
8cdf703
[NSA Sparse]: separate KV cache and indexer_k allocation
hzh0425 Nov 22, 2025
12d29f4
[Sparse NSA]: Support truncate mem and page table when prompt len > …
hzh0425 Nov 23, 2025
a3e5cdc
[Sparse]: Support HICache Integrate; fix some bugs;
huangtingwei9988 Nov 24, 2025
7448e15
[Sparse NSA]: Init NSA Integrate
hzh0425 Nov 24, 2025
bacec58
[nsa sparse]: Optimize Code style
hzh0425 Nov 24, 2025
67fba78
Successfully ran DSA for the first time.
huangtingwei9988 Nov 26, 2025
b3882d0
support SparseKVCacheManager & fuse sparse triton kernel
huangtingwei9988 Nov 28, 2025
6e0d18a
fix bug
huangtingwei9988 Nov 28, 2025
822db3e
Merge pull request #1 from hzh0425/sparse-framework_1128
huangtingwei9988 Nov 28, 2025
1457e3c
tmp optimized
hzh0425 Nov 29, 2025
ecb1623
Merge pull request #2 from hzh0425/sparse-framework-11-29
huangtingwei9988 Dec 1, 2025
9a10e34
[Sparse NSA Kernel]: Optimize nsa diff kernel's performance
hzh0425 Nov 29, 2025
c3cabfc
tmp optimized 1201 huangtingwei9988
huangtingwei9988 Dec 1, 2025
8abfe85
fix
huangtingwei9988 Dec 1, 2025
4cf64f1
resolve conflicts
huangtingwei9988 Dec 1, 2025
e15705e
Merge pull request #3 from hzh0425/sparse-framework-1201-huangtingwei…
huangtingwei9988 Dec 1, 2025
756feb1
[nsa sparse]: Using unified kernel to process all sparse requests to …
LingYeAI Dec 1, 2025
2265d76
[Sparse]: Support cuda graph
hzh0425 Dec 3, 2025
2ba37c6
Move Module into mem_cache
hzh0425 Dec 5, 2025
12b0a80
[Sparse]: Optimize code style
hzh0425 Dec 5, 2025
dc1807a
fix memory leak
hzh0425 Dec 5, 2025
367dfa2
[Sparse]: Fix triton accuracy bug
hzh0425 Dec 8, 2025
f506107
[io kernel]: optimize transfer kernel temporarily
LingYeAI Dec 5, 2025
d9dbd95
Optimize Code Style
hzh0425 Dec 8, 2025
8417746
Fix ack_id KeyError & distinguish ack_id between prefill and decode
huangtingwei9988 Dec 10, 2025
55a28e1
Fix memory leak when output_len = 1
hzh0425 Dec 9, 2025
ec6be78
[alloc]: fix token pool write
hzh0425 Dec 10, 2025
eb7ec6c
[Sparse]: Refactor algorithm layer
hzh0425 Dec 11, 2025
8912bf7
[Sparse]: Refactor page_wise_algorithm.py
hzh0425 Dec 13, 2025
2807ecc
[Sparse]: Optimize sparse kernel, fix cuda ima
hzh0425 Dec 12, 2025
4a9bd9d
[Sparse]: Refactor algorithm layer again.
hzh0425 Dec 13, 2025
5107404
tmp optimization for io_kernel block_quota
huangtingwei9988 Dec 14, 2025
ff76890
Tmp Support For PD
hzh0425 Dec 17, 2025
e56dfd3
Fix memory_pool_host nsa dim
hzh0425 Dec 17, 2025
bf209e3
Refactor mem alloc, only separate alloc index_k in decode mode
hzh0425 Dec 22, 2025
7b5a414
Remove NSAReqToTokenPool
hzh0425 Dec 23, 2025
9b27093
[1/N][Sparse With Hicache]: Add Sparse Interface (#14741)
hzh0425 Dec 25, 2025
a8f5ffe
Refactoring sparse diff triton kernel
huangtingwei9988 Dec 26, 2025
55faee3
support lru evict for sparse kv cache
huangtingwei9988 Dec 29, 2025
87 changes: 79 additions & 8 deletions python/sglang/srt/disaggregation/decode.py
@@ -49,9 +49,15 @@
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
is_enable_hierarchical_nsa,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.mem_cache.common import (
release_kv_cache,
truncate_kv_cache_after_prefill,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
HybridReqToTokenPool,
@@ -60,6 +66,7 @@
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.mem_cache.sparsity import get_sparse_coordinator
from sglang.srt.tracing.trace import trace_event_batch, trace_slice_end
from sglang.srt.utils import get_int_env_var
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -134,6 +141,40 @@ def clear(self):
self.free_slots = list(range(self.size + self.pre_alloc_size))


class NSADecodeReqToTokenPool(DecodeReqToTokenPool):
"""NSA DecodeReqToTokenPool: separate mapping for KV cache and nsa indexer_k"""

def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
pre_alloc_size: int,
):
super().__init__(
size, max_context_len, device, enable_memory_saver, pre_alloc_size
)

memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_nsa_index_k = torch.zeros(
(size + pre_alloc_size, max_context_len),
dtype=torch.int32,
device=device,
)

def write_index_token(self, indices, values):
"""Write indexer_k mapping"""
self.req_to_nsa_index_k[indices] = values

def clear(self):
super().clear()
self.req_to_nsa_index_k.zero_()
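
Review note: NSADecodeReqToTokenPool keeps two parallel [request, position] tables: the inherited req_to_token for KV cache slots and the new req_to_nsa_index_k for NSA indexer_k slots. A minimal usage sketch; the sizes, CPU device, and slot values are illustrative assumptions, and `write` comes from the base ReqToTokenPool:

```python
import torch

from sglang.srt.disaggregation.decode import NSADecodeReqToTokenPool

# Illustrative instantiation; a real deployment uses device="cuda" and
# sizes derived from the server args.
pool = NSADecodeReqToTokenPool(
    size=8,
    max_context_len=1024,
    device="cpu",
    enable_memory_saver=False,
    pre_alloc_size=0,
)

req_pool_idx, seq_len = 0, 16
kv_loc = torch.arange(0, seq_len, dtype=torch.int32)               # KV cache slots
index_k_loc = torch.arange(100, 100 + seq_len, dtype=torch.int32)  # indexer_k slots

# The two mappings are written independently, mirroring _pre_alloc below.
pool.write((req_pool_idx, slice(0, seq_len)), kv_loc)
pool.write_index_token((req_pool_idx, slice(0, seq_len)), index_k_loc)
```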


class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):

def __init__(
@@ -510,9 +551,14 @@ def pop_preallocated(self) -> List[DecodeRequest]:
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
if isinstance(self.req_to_token_pool, NSADecodeReqToTokenPool):
kv_indices_full = self.req_to_token_pool.req_to_nsa_index_k[
decode_req.req.req_pool_idx, :seq_len
]
else:
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
else:
@@ -624,10 +670,10 @@ def _pre_alloc(self, req: Req) -> torch.Tensor:
req.kv_allocated_len = fill_len
req.kv_committed_len = fill_len
if self.token_to_kv_pool_allocator.page_size == 1:
kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len)
alloc_result = self.token_to_kv_pool_allocator.alloc(fill_len)
else:
device = self.token_to_kv_pool_allocator.device
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
alloc_result = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens=torch.tensor([0], dtype=torch.int64, device=device),
prefix_lens_cpu=torch.tensor([0], dtype=torch.int64),
seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device),
@@ -637,11 +683,25 @@ def _pre_alloc(self, req: Req) -> torch.Tensor:
)

assert (
kv_loc is not None
alloc_result is not None
), "KV cache is full! There is a bug in memory estimation."

if is_enable_hierarchical_nsa(self.token_to_kv_pool_allocator):
kv_loc, index_k_loc = alloc_result
else:
kv_loc = alloc_result
index_k_loc = None

# Write KV indices to req_to_token
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

# Write index_k indices for NSA
if index_k_loc is not None:
self.req_to_token_pool.write_index_token(
(req.req_pool_idx, slice(0, len(index_k_loc))),
index_k_loc.to(torch.int32),
)

# populate metadata
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = len(req.origin_input_ids)
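
Review note: the contract change in this hunk is that, when is_enable_hierarchical_nsa(...) holds, the allocator returns a (kv_loc, index_k_loc) pair instead of a single tensor. A hedged sketch of the unpacking pattern; unpack_alloc_result is an invented helper, not part of this PR:

```python
import torch


def unpack_alloc_result(alloc_result, hierarchical_nsa: bool):
    """Hypothetical helper: normalize the allocator's return value to a
    (kv_loc, index_k_loc) pair. index_k_loc is None whenever the separate
    indexer_k pool is not in use."""
    if hierarchical_nsa:
        kv_loc, index_k_loc = alloc_result          # two parallel slot tensors
        return kv_loc, index_k_loc.to(torch.int32)  # req_to_nsa_index_k is int32
    return alloc_result, None


# Example: a 4-token allocation under hierarchical NSA.
result = (torch.tensor([3, 4, 5, 6]), torch.tensor([10, 11, 12, 13]))
kv_loc, index_k_loc = unpack_alloc_result(result, hierarchical_nsa=True)
assert index_k_loc is not None and index_k_loc.dtype == torch.int32
```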
@@ -959,4 +1019,15 @@ def process_decode_queue(self: Scheduler):
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived

# NSA: Register, Offload and Truncate after KV transfer completes
sparse_coordinator = get_sparse_coordinator()
if sparse_coordinator is not None:
for req in alloc_reqs:
sparse_coordinator.on_request_begin(req)
sparse_coordinator.on_request_prefill_end(req)
truncate_kv_cache_after_prefill(
req, self.req_to_token_pool, self.tree_cache
)

self.waiting_queue.extend(alloc_reqs)
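
Review note: on the decode side the coordinator is driven purely by lifecycle hooks, in the order shown above: on_request_begin registers the request, on_request_prefill_end marks prefill KV ready to offload, and only then is the on-device KV/page table truncated. A runnable sketch of that ordering; the stub coordinator and request objects are invented for illustration:

```python
class StubSparseCoordinator:
    """Invented stand-in for get_sparse_coordinator()'s return value,
    implementing only the two hooks this diff calls."""

    def on_request_begin(self, req):
        print(f"register {req.rid} with the sparse KV manager")

    def on_request_prefill_end(self, req):
        print(f"offload prefill KV of {req.rid}; truncation is now safe")


class FakeReq:
    def __init__(self, rid: str):
        self.rid = rid


coordinator = StubSparseCoordinator()
for req in (FakeReq("r0"), FakeReq("r1")):
    # Mirrors process_decode_queue: both hooks fire per transferred
    # request before the KV cache is truncated to the sparse working set.
    coordinator.on_request_begin(req)
    coordinator.on_request_prefill_end(req)
```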
33 changes: 33 additions & 0 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -11,6 +11,7 @@
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.sparsity import DeepSeekNSAAlgorithm, get_sparse_coordinator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
@@ -362,6 +363,12 @@ def __init__(
1 if model_runner.server_args.enable_deterministic_inference else 0
)

# Sparse attention coordinator
self.sparse_coordinator = get_sparse_coordinator()
if self.sparse_coordinator is not None:
if isinstance(self.sparse_coordinator.algorithm, DeepSeekNSAAlgorithm):
self.sparse_coordinator = None

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
Expand Down Expand Up @@ -961,6 +968,13 @@ def forward_extend(
else:
o = result

if self.sparse_coordinator is not None:
self.sparse_coordinator.attention_end(
output=o,
layer=layer,
forward_batch=forward_batch,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

def forward_decode(
Expand Down Expand Up @@ -999,6 +1013,18 @@ def forward_decode(

# Use precomputed metadata across all layers
metadata = self.forward_metadata

# Apply sparse attention: modify metadata based on query
if self.sparse_coordinator is not None:
self.sparse_coordinator.attention_begin(
query=q,
key=k,
value=v,
layer=layer,
forward_batch=forward_batch,
attn_metadata=metadata,
)

local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
use_local_attn = (
self.attention_chunk_size is not None
@@ -1231,6 +1257,13 @@ def forward_decode(
else:
o = result

if self.sparse_coordinator is not None:
self.sparse_coordinator.attention_end(
output=o,
layer=layer,
forward_batch=forward_batch,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
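
Review note: every call site in this file follows one bracket pattern: attention_begin may rewrite the decode metadata in place (e.g. restricting the page table to the selected pages) before the dense kernel runs, and attention_end observes the output. A sketch of that bracketing; SparseHooks, its method bodies, and decode_with_hooks are invented for illustration, and only the hook signatures match the call sites above:

```python
from typing import Any, Callable, Optional


class SparseHooks:
    """Invented protocol mirroring the coordinator calls in this diff."""

    def attention_begin(self, query, key, value, layer, forward_batch,
                        attn_metadata) -> None:
        # A real coordinator would run top-k page selection here and
        # patch attn_metadata (page table, seq lens) in place.
        pass

    def attention_end(self, output, layer, forward_batch) -> None:
        # Observation point after the kernel, e.g. for cache bookkeeping.
        pass


def decode_with_hooks(run_kernel: Callable, q, k, v, layer, forward_batch,
                      metadata, hooks: Optional[SparseHooks]) -> Any:
    # Mirrors forward_decode: begin hook, kernel on (possibly sparsified)
    # metadata, end hook, in that order.
    if hooks is not None:
        hooks.attention_begin(q, k, v, layer, forward_batch, metadata)
    o = run_kernel(q, k, v, metadata)
    if hooks is not None:
        hooks.attention_end(o, layer, forward_batch)
    return o
```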
27 changes: 21 additions & 6 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -440,11 +440,10 @@ def _forward_cuda_k_only(
key = self._get_k_bf16(x, positions, enable_dual_stream)
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)

if not forward_batch.out_cache_loc.is_contiguous():
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
index_loc = self._get_index_cache_loc(forward_batch)
forward_batch.token_to_kv_pool.set_index_k_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
loc=index_loc,
index_k=k_fp8,
index_k_scale=k_scale,
)
@@ -621,11 +620,10 @@ def forward_cuda(
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
# k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
# k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
if not forward_batch.out_cache_loc.is_contiguous():
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
index_loc = self._get_index_cache_loc(forward_batch)
forward_batch.token_to_kv_pool.set_index_k_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
loc=index_loc,
index_k=k_fp8,
index_k_scale=k_scale,
)
@@ -667,6 +665,23 @@
)
return topk_result

def _get_index_cache_loc(self, forward_batch: ForwardBatch) -> torch.Tensor:
pool = forward_batch.req_to_token_pool

if (
forward_batch.forward_mode.is_decode()
and hasattr(pool, "req_to_nsa_index_k")
):
index_loc = pool.req_to_nsa_index_k[
forward_batch.req_pool_indices, forward_batch.seq_lens - 1
].to(torch.int64)
else:
index_loc = forward_batch.out_cache_loc

if not index_loc.is_contiguous():
index_loc = index_loc.contiguous()
return index_loc
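
Review note: in decode mode each request appends exactly one new indexer_k entry, so its slot lives at req_to_nsa_index_k[req_pool_idx, seq_len - 1]. A standalone sketch of that gather (all tensors and values are illustrative):

```python
import torch

# Illustrative: pool of 8 request rows, 64-token page tables for indexer_k.
req_to_nsa_index_k = torch.zeros((8, 64), dtype=torch.int32)
req_to_nsa_index_k[0, :5] = torch.tensor([7, 8, 9, 10, 11], dtype=torch.int32)
req_to_nsa_index_k[2, :3] = torch.tensor([40, 41, 42], dtype=torch.int32)

req_pool_indices = torch.tensor([0, 2])  # which pool rows this batch uses
seq_lens = torch.tensor([5, 3])          # current lengths incl. the new token

# Same gather as _get_index_cache_loc on the decode path:
index_loc = req_to_nsa_index_k[req_pool_indices, seq_lens - 1].to(torch.int64)
assert index_loc.tolist() == [11, 42]  # slot of each request's newest token
```

This refactor also removes the old in-place mutation of forward_batch.out_cache_loc; any contiguity fix-up now stays local to the indexer.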

def forward_npu(
self,
x: torch.Tensor,