From 719c9b7152675e6ab23910ad7d2587b7a9ea0c12 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Mon, 23 Mar 2026 07:57:39 +0000 Subject: [PATCH 1/2] style refinement for hisparse --- python/sglang/jit_kernel/csrc/hisparse.cuh | 54 ++++++++++++++----- .../srt/managers/hisparse_coordinator.py | 10 ++-- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/managers/scheduler.py | 53 +++++++++--------- 4 files changed, 76 insertions(+), 43 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh index 3cf12178f243..5aa2fb75a341 100644 --- a/python/sglang/jit_kernel/csrc/hisparse.cuh +++ b/python/sglang/jit_kernel/csrc/hisparse.cuh @@ -53,6 +53,19 @@ __device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int return accumulator; } +// Shared memory size calculation for dynamic allocation. +// Layout: int32_t region (4-byte aligned) followed by int16_t region (2-byte aligned). +template +struct SmemLayout { + static constexpr int HASH_SIZE = NUM_TOP_K * 2; + static constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE; + // int32_t region: top_k_tokens + chunk_offset + evict_chunk_offset + hash_keys + total_hits + newest_hit + static constexpr int TOTAL_INT32 = NUM_TOP_K + (NUM_BUFFER_CHUNKS + 1) + (NUM_BUFFER_CHUNKS + 1) + HASH_SIZE + 2; + // int16_t region: lru_slots_out + hash_vals + static constexpr int TOTAL_INT16 = HOT_BUFFER_SIZE + HASH_SIZE; + static constexpr size_t BYTES = TOTAL_INT32 * sizeof(int32_t) + TOTAL_INT16 * sizeof(int16_t); +}; + // Each block processes one request // req_pool_indices are int64_t (pool indices can be large), seq_lens are int32_t // Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] @@ -118,21 +131,29 @@ __global__ void load_cache_to_device_buffer_kernel( return; } + // Dynamic shared memory layout: int32_t arrays first, then int16_t arrays. + extern __shared__ char smem_raw[]; + using Layout = SmemLayout; + constexpr int HASH_SIZE = Layout::HASH_SIZE; + + int32_t* smem_i32 = reinterpret_cast(smem_raw); // Top-k token positions; reused as miss-token scratch in the copy phase - __shared__ int32_t s_top_k_tokens[NUM_TOP_K]; + int32_t* s_top_k_tokens = smem_i32; // Prefix-sum offsets for hit counting and miss counting - __shared__ int32_t s_chunk_offset[NUM_BUFFER_CHUNKS + 1]; + int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K; // Prefix-sum offsets for evictable counting - __shared__ int32_t s_evict_chunk_offset[NUM_BUFFER_CHUNKS + 1]; + int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1); + // Open-addressing hash table: top-k token_id → top-k index (keys) + int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1); + // Scalar counters + int32_t& s_total_hits = s_hash_keys[HASH_SIZE]; + int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1]; + + int16_t* smem_i16 = reinterpret_cast(smem_i32 + Layout::TOTAL_INT32); // Compacted slot ordering: [hits fwd→ ... ←evictables bwd] - __shared__ int16_t s_lru_slots_out[HOT_BUFFER_SIZE]; - // Open-addressing hash table: top-k token_id → top-k index - constexpr int HASH_SIZE = NUM_TOP_K * 2; - __shared__ int32_t s_hash_keys[HASH_SIZE]; - __shared__ int16_t s_hash_vals[HASH_SIZE]; - - __shared__ int32_t s_total_hits; - __shared__ int32_t s_newest_hit; + int16_t* s_lru_slots_out = smem_i16; + // Open-addressing hash table: top-k token_id → top-k index (values) + int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE; // Initialize shared memory: counters, hash table, prefix-sum offsets. if (tid == 0) { @@ -363,8 +384,15 @@ void load_cache_to_device_buffer( const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); - LaunchKernel(bs, BLOCK_SIZE, device)( - load_cache_to_device_buffer_kernel, + constexpr size_t smem_bytes = SmemLayout::BYTES; + auto kernel_fn = load_cache_to_device_buffer_kernel; + // Opt in to dynamic shared memory beyond the default 48 KB limit. + if constexpr (smem_bytes > 48u * 1024u) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)( + kernel_fn, static_cast(top_k_tokens.data_ptr()), static_cast(device_buffer_tokens.data_ptr()), static_cast(host_cache_locs.data_ptr()), diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 92ef22f404cf..889585dca65e 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -123,7 +123,7 @@ def set_decode_producer_stream(self, stream) -> None: self.decode_producer_stream = stream def admit_request_into_staging(self, req: Req) -> None: - req.staging = True + req.hisparse_staging = True logical_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.fill_ids) ] @@ -224,7 +224,7 @@ def collect_ready_reqs(self) -> List[Req]: _, _, req = self.ack_staging_queue.pop(0) # prepare device buffer and update req self.alloc_device_buffer(req) - req.staging = False + req.hisparse_staging = False self._skip_first_backup[req.req_pool_idx] = True finish_count -= 1 ready_reqs.append(req) @@ -494,7 +494,7 @@ def abort_staging_request(self, req: Req) -> None: """Remove a request from the staging queue and free its host resources. Must be called when aborting a request that has been admitted into staging - but has not yet completed (i.e. req.staging is True). + but has not yet completed (i.e. req.hisparse_staging is True). """ # Remove from staging queue self.ack_staging_queue = [ @@ -510,10 +510,10 @@ def abort_staging_request(self, req: Req) -> None: self.mem_pool_host.free(host_indices) self.req_to_host_pool[req.req_pool_idx, :] = -1 self._skip_first_backup[req.req_pool_idx] = False - req.staging = False + req.hisparse_staging = False def retract_req(self, req: Req) -> None: - if req.staging: + if req.hisparse_staging: self.abort_staging_request(req) else: self.request_finished(req) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 670cc1c8595d..1c6ab57152c6 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -801,7 +801,7 @@ def __init__( self.init_diffusion_llm(dllm_config) # For hisparse - self.staging = False + self.hisparse_staging = False @property def seqlen(self) -> int: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f31303e8b18b..4eea3940b6ab 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2101,6 +2101,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: chunked_req_to_exclude.add(self.chunked_req) self.stash_chunked_request(self.chunked_req) + # HiSparse has its own prefill-to-decode transition; skip last_batch merge. if self.enable_hisparse: ready_reqs = self.hisparse_coordinator.collect_ready_reqs() if len(ready_reqs) > 0: @@ -2110,31 +2111,35 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: else: self.running_batch.merge_batch(new_batch) self.running_batch.hisparse_coordinator = self.hisparse_coordinator - else: - if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.last_batch.chunked_req is not None: - # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. - # We need to discard it. - chunked_req_to_exclude.add(self.last_batch.chunked_req) - - if self.dllm_config is not None and self.last_batch.reqs: - chunked_req_to_exclude.update(self.last_batch.reqs) - - # Filter batch - last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) - if self.last_batch.batch_size() < last_bs: - self.running_batch.batch_is_full = False - # Merge the new batch into the running batch. - if not self.last_batch.is_empty(): - if self.running_batch.is_empty(): - self.running_batch = self.last_batch - else: - # Merge running_batch with prefill batch - self.running_batch.merge_batch(self.last_batch) + if ( + not self.enable_hisparse + and self.last_batch + and self.last_batch.forward_mode.is_extend() + ): + if self.last_batch.chunked_req is not None: + # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. + # We need to discard it. + chunked_req_to_exclude.add(self.last_batch.chunked_req) + + if self.dllm_config is not None and self.last_batch.reqs: + chunked_req_to_exclude.update(self.last_batch.reqs) + + # Filter batch + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) + if self.last_batch.batch_size() < last_bs: + self.running_batch.batch_is_full = False + + # Merge the new batch into the running batch. + if not self.last_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.last_batch + else: + # Merge running_batch with prefill batch + self.running_batch.merge_batch(self.last_batch) # For prefill-only batch, filter out finished requests since they # won't go through the decode step. This keeps running_batch accurate From 8ed29bfe47512cf54e4f6ecef7548d426ab889ee Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Tue, 24 Mar 2026 06:17:17 +0000 Subject: [PATCH 2/2] data type fix and flashmla_sparse guard --- python/sglang/jit_kernel/csrc/hisparse.cuh | 74 +++++++++++-------- .../srt/managers/hisparse_coordinator.py | 6 +- python/sglang/srt/server_args.py | 23 ++++++ 3 files changed, 69 insertions(+), 34 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh index 5aa2fb75a341..2919b59ba6a4 100644 --- a/python/sglang/jit_kernel/csrc/hisparse.cuh +++ b/python/sglang/jit_kernel/csrc/hisparse.cuh @@ -67,10 +67,10 @@ struct SmemLayout { }; // Each block processes one request -// req_pool_indices are int64_t (pool indices can be large), seq_lens are int32_t +// req_pool_indices are int64_t (pool indices can be large), seq_lens can be int32_t or int64_t // Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] // newest_slot is at HOT_BUFFER_SIZE (first position of extra page) -template +template __global__ void load_cache_to_device_buffer_kernel( const int32_t* __restrict__ top_k_tokens, int32_t* __restrict__ device_buffer_tokens, @@ -82,7 +82,7 @@ __global__ void load_cache_to_device_buffer_kernel( void* __restrict__ device_buffer_v, int32_t* __restrict__ top_k_device_locs, const int64_t* __restrict__ req_pool_indices, - const int32_t* __restrict__ seq_lens, + const SeqLensT* __restrict__ seq_lens, int16_t* __restrict__ lru_slots, const int32_t* __restrict__ num_real_reqs, int64_t buffer_stride_0, @@ -384,35 +384,47 @@ void load_cache_to_device_buffer( const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); - constexpr size_t smem_bytes = SmemLayout::BYTES; - auto kernel_fn = load_cache_to_device_buffer_kernel; - // Opt in to dynamic shared memory beyond the default 48 KB limit. - if constexpr (smem_bytes > 48u * 1024u) { - cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + // Generic lambda: both int32 and int64 kernel variants are compiled; + // the correct one is selected at runtime based on seq_lens dtype. + auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr) { + constexpr size_t smem_bytes = SmemLayout::BYTES; + if constexpr (smem_bytes > 48u * 1024u) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)( + kernel_fn, + static_cast(top_k_tokens.data_ptr()), + static_cast(device_buffer_tokens.data_ptr()), + static_cast(host_cache_locs.data_ptr()), + static_cast(device_buffer_locs.data_ptr()), + host_cache_k.data_ptr(), + (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), + device_buffer_k.data_ptr(), + (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), + static_cast(top_k_device_locs.data_ptr()), + static_cast(req_pool_indices.data_ptr()), + seq_lens_ptr, + static_cast(lru_slots.data_ptr()), + static_cast(num_real_reqs.data_ptr()), + buffer_stride_0, + host_stride, + lru_slot_stride_0, + top_k_tokens_stride, + top_k_device_locs_stride, + page_size, + item_size_bytes); + }; + + const auto dtype = seq_lens.dtype(); + if (dtype.code == kDLInt && dtype.bits == 64) { + launch( + load_cache_to_device_buffer_kernel, + static_cast(seq_lens.data_ptr())); + } else { + launch( + load_cache_to_device_buffer_kernel, + static_cast(seq_lens.data_ptr())); } - - LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)( - kernel_fn, - static_cast(top_k_tokens.data_ptr()), - static_cast(device_buffer_tokens.data_ptr()), - static_cast(host_cache_locs.data_ptr()), - static_cast(device_buffer_locs.data_ptr()), - host_cache_k.data_ptr(), - (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), - device_buffer_k.data_ptr(), - (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), - static_cast(top_k_device_locs.data_ptr()), - static_cast(req_pool_indices.data_ptr()), - static_cast(seq_lens.data_ptr()), - static_cast(lru_slots.data_ptr()), - static_cast(num_real_reqs.data_ptr()), - buffer_stride_0, - host_stride, - lru_slot_stride_0, - top_k_tokens_stride, - top_k_device_locs_stride, - page_size, - item_size_bytes); } } // namespace diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 889585dca65e..eda985d4f80c 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -556,14 +556,14 @@ def swap_in_selected_pages( layer_id: int, ) -> torch.Tensor: """Swap selected top-k tokens into device memory and return their indices.""" - # The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32. + # The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32 or int64. if req_pool_indices.dtype != torch.int64: raise ValueError( f"req_pool_indices dtype {req_pool_indices.dtype} is not int64 as expected" ) - if seq_lens.dtype != torch.int32: + if seq_lens.dtype not in (torch.int32, torch.int64): raise ValueError( - f"seq_lens dtype {seq_lens.dtype} is not int32 as expected" + f"seq_lens dtype {seq_lens.dtype} is not int32 or int64 as expected" ) if top_k_result.dtype != torch.int32: raise ValueError( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a0f0704f3469..73c86ef993c8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1417,6 +1417,18 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None + # HiSparse requires flashmla_sparse for both prefill and decode + if self.enable_hisparse: + if not user_set_prefill: + self.nsa_prefill_backend = "flashmla_sparse" + if not user_set_decode: + self.nsa_decode_backend = "flashmla_sparse" + logger.warning( + f"HiSparse enabled: using flashmla_sparse NSA backends " + f"(prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend})." + ) + return + if not user_set_prefill and not user_set_decode and is_hip(): self.nsa_prefill_backend = "tilelang" self.nsa_decode_backend = "tilelang" @@ -6062,6 +6074,17 @@ def check_server_args(self): assert ( self.disable_radix_cache ), "Hierarchical sparse attention currently requires --disable-radix-cache." + for attr, label in [ + ("nsa_prefill_backend", "prefill"), + ("nsa_decode_backend", "decode"), + ]: + backend = getattr(self, attr) + if backend is not None and backend != "flashmla_sparse": + raise ValueError( + f"HiSparse requires flashmla_sparse NSA {label} backend, " + f"but got --nsa-{label}-backend={backend}. " + f"Please use --nsa-{label}-backend=flashmla_sparse or omit it." + ) assert ( self.schedule_conservativeness >= 0