Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 76 additions & 36 deletions python/sglang/jit_kernel/csrc/hisparse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,24 @@ __device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int
return accumulator;
}

// Shared memory size calculation for dynamic allocation.
// Layout: int32_t region (4-byte aligned) followed by int16_t region (2-byte aligned).
template <int NUM_TOP_K, int HOT_BUFFER_SIZE>
struct SmemLayout {
static constexpr int HASH_SIZE = NUM_TOP_K * 2;
static constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE;
// int32_t region: top_k_tokens + chunk_offset + evict_chunk_offset + hash_keys + total_hits + newest_hit
static constexpr int TOTAL_INT32 = NUM_TOP_K + (NUM_BUFFER_CHUNKS + 1) + (NUM_BUFFER_CHUNKS + 1) + HASH_SIZE + 2;
// int16_t region: lru_slots_out + hash_vals
static constexpr int TOTAL_INT16 = HOT_BUFFER_SIZE + HASH_SIZE;
static constexpr size_t BYTES = TOTAL_INT32 * sizeof(int32_t) + TOTAL_INT16 * sizeof(int16_t);
Comment on lines +56 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The introduction of the SmemLayout struct for dynamic shared memory calculation is a good practice. It centralizes the shared memory layout definition, making it easier to manage and reason about the memory usage, especially with different data types and alignment requirements. This approach enhances flexibility and maintainability for CUDA kernels.

};

// Each block processes one request
// req_pool_indices are int64_t (pool indices can be large), seq_lens are int32_t
// req_pool_indices are int64_t (pool indices can be large), seq_lens can be int32_t or int64_t
// Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token]
// newest_slot is at HOT_BUFFER_SIZE (first position of extra page)
template <int BLOCK_SIZE, int NUM_TOP_K, int HOT_BUFFER_SIZE, bool IsMLA>
template <int BLOCK_SIZE, int NUM_TOP_K, int HOT_BUFFER_SIZE, bool IsMLA, typename SeqLensT>
__global__ void load_cache_to_device_buffer_kernel(
const int32_t* __restrict__ top_k_tokens,
int32_t* __restrict__ device_buffer_tokens,
Expand All @@ -69,7 +82,7 @@ __global__ void load_cache_to_device_buffer_kernel(
void* __restrict__ device_buffer_v,
int32_t* __restrict__ top_k_device_locs,
const int64_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ seq_lens,
const SeqLensT* __restrict__ seq_lens,
int16_t* __restrict__ lru_slots,
const int32_t* __restrict__ num_real_reqs,
int64_t buffer_stride_0,
Expand Down Expand Up @@ -118,21 +131,29 @@ __global__ void load_cache_to_device_buffer_kernel(
return;
}

// Dynamic shared memory layout: int32_t arrays first, then int16_t arrays.
extern __shared__ char smem_raw[];
using Layout = SmemLayout<NUM_TOP_K, HOT_BUFFER_SIZE>;
constexpr int HASH_SIZE = Layout::HASH_SIZE;

int32_t* smem_i32 = reinterpret_cast<int32_t*>(smem_raw);
// Top-k token positions; reused as miss-token scratch in the copy phase
__shared__ int32_t s_top_k_tokens[NUM_TOP_K];
int32_t* s_top_k_tokens = smem_i32;
// Prefix-sum offsets for hit counting and miss counting
__shared__ int32_t s_chunk_offset[NUM_BUFFER_CHUNKS + 1];
int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K;
// Prefix-sum offsets for evictable counting
__shared__ int32_t s_evict_chunk_offset[NUM_BUFFER_CHUNKS + 1];
int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1);
// Open-addressing hash table: top-k token_id → top-k index (keys)
int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1);
// Scalar counters
int32_t& s_total_hits = s_hash_keys[HASH_SIZE];
int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1];

int16_t* smem_i16 = reinterpret_cast<int16_t*>(smem_i32 + Layout::TOTAL_INT32);
// Compacted slot ordering: [hits fwd→ ... ←evictables bwd]
__shared__ int16_t s_lru_slots_out[HOT_BUFFER_SIZE];
// Open-addressing hash table: top-k token_id → top-k index
constexpr int HASH_SIZE = NUM_TOP_K * 2;
__shared__ int32_t s_hash_keys[HASH_SIZE];
__shared__ int16_t s_hash_vals[HASH_SIZE];

__shared__ int32_t s_total_hits;
__shared__ int32_t s_newest_hit;
int16_t* s_lru_slots_out = smem_i16;
// Open-addressing hash table: top-k token_id → top-k index (values)
int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE;
Comment on lines +134 to +156
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Switching from static shared memory arrays to dynamic allocation using extern __shared__ char smem_raw[]; and pointer arithmetic is a robust way to handle variable shared memory requirements. This allows for more efficient use of shared memory resources, especially when NUM_TOP_K or HOT_BUFFER_SIZE can vary significantly. The explicit casting and offsetting for int32_t and int16_t regions are correctly implemented.


// Initialize shared memory: counters, hash table, prefix-sum offsets.
if (tid == 0) {
Expand Down Expand Up @@ -363,28 +384,47 @@ void load_cache_to_device_buffer(
const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0];
const auto device = LaunchKernel::resolve_device(top_k_tokens.device());

LaunchKernel(bs, BLOCK_SIZE, device)(
load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA>,
static_cast<const int32_t*>(top_k_tokens.data_ptr()),
static_cast<int32_t*>(device_buffer_tokens.data_ptr()),
static_cast<const int64_t*>(host_cache_locs.data_ptr()),
static_cast<const int32_t*>(device_buffer_locs.data_ptr()),
host_cache_k.data_ptr(),
(IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(),
device_buffer_k.data_ptr(),
(IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(),
static_cast<int32_t*>(top_k_device_locs.data_ptr()),
static_cast<const int64_t*>(req_pool_indices.data_ptr()),
static_cast<const int32_t*>(seq_lens.data_ptr()),
static_cast<int16_t*>(lru_slots.data_ptr()),
static_cast<const int32_t*>(num_real_reqs.data_ptr()),
buffer_stride_0,
host_stride,
lru_slot_stride_0,
top_k_tokens_stride,
top_k_device_locs_stride,
page_size,
item_size_bytes);
// Generic lambda: both int32 and int64 kernel variants are compiled;
// the correct one is selected at runtime based on seq_lens dtype.
auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr) {
constexpr size_t smem_bytes = SmemLayout<NUM_TOP_K, HOT_BUFFER_SIZE>::BYTES;
if constexpr (smem_bytes > 48u * 1024u) {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)(
kernel_fn,
static_cast<const int32_t*>(top_k_tokens.data_ptr()),
static_cast<int32_t*>(device_buffer_tokens.data_ptr()),
static_cast<const int64_t*>(host_cache_locs.data_ptr()),
static_cast<const int32_t*>(device_buffer_locs.data_ptr()),
host_cache_k.data_ptr(),
(IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(),
device_buffer_k.data_ptr(),
(IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(),
static_cast<int32_t*>(top_k_device_locs.data_ptr()),
static_cast<const int64_t*>(req_pool_indices.data_ptr()),
seq_lens_ptr,
static_cast<int16_t*>(lru_slots.data_ptr()),
static_cast<const int32_t*>(num_real_reqs.data_ptr()),
buffer_stride_0,
host_stride,
lru_slot_stride_0,
top_k_tokens_stride,
top_k_device_locs_stride,
page_size,
item_size_bytes);
};

const auto dtype = seq_lens.dtype();
if (dtype.code == kDLInt && dtype.bits == 64) {
launch(
load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA, int64_t>,
static_cast<const int64_t*>(seq_lens.data_ptr()));
} else {
launch(
load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA, int32_t>,
static_cast<const int32_t*>(seq_lens.data_ptr()));
}
}

} // namespace
16 changes: 8 additions & 8 deletions python/sglang/srt/managers/hisparse_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def set_decode_producer_stream(self, stream) -> None:
self.decode_producer_stream = stream

def admit_request_into_staging(self, req: Req) -> None:
req.staging = True
req.hisparse_staging = True
logical_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.fill_ids)
]
Expand Down Expand Up @@ -224,7 +224,7 @@ def collect_ready_reqs(self) -> List[Req]:
_, _, req = self.ack_staging_queue.pop(0)
# prepare device buffer and update req
self.alloc_device_buffer(req)
req.staging = False
req.hisparse_staging = False
self._skip_first_backup[req.req_pool_idx] = True
finish_count -= 1
ready_reqs.append(req)
Expand Down Expand Up @@ -494,7 +494,7 @@ def abort_staging_request(self, req: Req) -> None:
"""Remove a request from the staging queue and free its host resources.

Must be called when aborting a request that has been admitted into staging
but has not yet completed (i.e. req.staging is True).
but has not yet completed (i.e. req.hisparse_staging is True).
"""
# Remove from staging queue
self.ack_staging_queue = [
Expand All @@ -510,10 +510,10 @@ def abort_staging_request(self, req: Req) -> None:
self.mem_pool_host.free(host_indices)
self.req_to_host_pool[req.req_pool_idx, :] = -1
self._skip_first_backup[req.req_pool_idx] = False
req.staging = False
req.hisparse_staging = False

def retract_req(self, req: Req) -> None:
if req.staging:
if req.hisparse_staging:
self.abort_staging_request(req)
else:
self.request_finished(req)
Expand Down Expand Up @@ -556,14 +556,14 @@ def swap_in_selected_pages(
layer_id: int,
) -> torch.Tensor:
"""Swap selected top-k tokens into device memory and return their indices."""
# The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32.
# The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32 or int64.
if req_pool_indices.dtype != torch.int64:
raise ValueError(
f"req_pool_indices dtype {req_pool_indices.dtype} is not int64 as expected"
)
if seq_lens.dtype != torch.int32:
if seq_lens.dtype not in (torch.int32, torch.int64):
raise ValueError(
f"seq_lens dtype {seq_lens.dtype} is not int32 as expected"
f"seq_lens dtype {seq_lens.dtype} is not int32 or int64 as expected"
)
if top_k_result.dtype != torch.int32:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def __init__(
self.init_diffusion_llm(dllm_config)

# For hisparse
self.staging = False
self.hisparse_staging = False

@property
def seqlen(self) -> int:
Expand Down
53 changes: 29 additions & 24 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
chunked_req_to_exclude.add(self.chunked_req)
self.stash_chunked_request(self.chunked_req)

# HiSparse has its own prefill-to-decode transition; skip last_batch merge.
if self.enable_hisparse:
ready_reqs = self.hisparse_coordinator.collect_ready_reqs()
if len(ready_reqs) > 0:
Expand All @@ -2155,31 +2156,35 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
else:
self.running_batch.merge_batch(new_batch)
self.running_batch.hisparse_coordinator = self.hisparse_coordinator
else:
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.last_batch.chunked_req is not None:
# In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
# We need to discard it.
chunked_req_to_exclude.add(self.last_batch.chunked_req)

if self.dllm_config is not None and self.last_batch.reqs:
chunked_req_to_exclude.update(self.last_batch.reqs)

# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False

# Merge the new batch into the running batch.
if not self.last_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
# Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch)
if (
not self.enable_hisparse
and self.last_batch
and self.last_batch.forward_mode.is_extend()
Comment on lines +2160 to +2163
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Wrapping the last_batch processing logic within an if not self.enable_hisparse condition correctly isolates the behavior. This ensures that the standard last_batch merge only occurs when hisparse is not active, preventing conflicts with hisparse's specialized transition handling.

):
if self.last_batch.chunked_req is not None:
# In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
# We need to discard it.
chunked_req_to_exclude.add(self.last_batch.chunked_req)

if self.dllm_config is not None and self.last_batch.reqs:
chunked_req_to_exclude.update(self.last_batch.reqs)

# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False

# Merge the new batch into the running batch.
if not self.last_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
# Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch)

# For prefill-only batch, filter out finished requests since they
# won't go through the decode step. This keeps running_batch accurate
Expand Down
23 changes: 23 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,18 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str:
user_set_prefill = self.nsa_prefill_backend is not None
user_set_decode = self.nsa_decode_backend is not None

# HiSparse requires flashmla_sparse for both prefill and decode
if self.enable_hisparse:
if not user_set_prefill:
self.nsa_prefill_backend = "flashmla_sparse"
if not user_set_decode:
self.nsa_decode_backend = "flashmla_sparse"
logger.warning(
f"HiSparse enabled: using flashmla_sparse NSA backends "
f"(prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend})."
)
return

if not user_set_prefill and not user_set_decode and is_hip():
self.nsa_prefill_backend = "tilelang"
self.nsa_decode_backend = "tilelang"
Expand Down Expand Up @@ -6088,6 +6100,17 @@ def check_server_args(self):
assert (
self.disable_radix_cache
), "Hierarchical sparse attention currently requires --disable-radix-cache."
for attr, label in [
("nsa_prefill_backend", "prefill"),
("nsa_decode_backend", "decode"),
]:
backend = getattr(self, attr)
if backend is not None and backend != "flashmla_sparse":
raise ValueError(
f"HiSparse requires flashmla_sparse NSA {label} backend, "
f"but got --nsa-{label}-backend={backend}. "
f"Please use --nsa-{label}-backend=flashmla_sparse or omit it."
)

assert (
self.schedule_conservativeness >= 0
Expand Down
Loading