Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,25 @@ struct CombinedTopKKernel {
.enable_pdl(true)(kernel, params);
} else {
// Some items may be large -- launch stage-1 + main
if (batch_size <= kNumClusters) {
// SM120 (consumer Blackwell, CC 12.0) has only ~99KB shared memory per block.
// kStage1SMEM (~144KB) exceeds this limit, so skip the cluster path on SM120
// and use only the Medium/Small (stage-2) paths which fit in ~84KB.
int device_cc_major = 0;
{
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
device_cc_major = prop.major;
}
const bool is_sm120 = (device_cc_major == 12);
if (!is_sm120 && batch_size <= kNumClusters) {
// can fuse into 1 stage
constexpr auto kernel = topk_fused_transform;
constexpr auto kSMEM = std::max(kStage1SMEM, kStage2SMEM);
setup_kernel_smem_once<kernel, kSMEM>();
LaunchKernel({batch_size, kClusterSize}, kBlockSize, device, kSMEM)
.enable_cluster({1, kClusterSize})
.enable_pdl(true)(kernel, params);
} else {
} else if (!is_sm120) {
// stage 1 + stage 2
constexpr auto kernel_stage_1 = topk_combine_preprocess;
setup_kernel_smem_once<kernel_stage_1, kStage1SMEM>();
Expand All @@ -485,6 +495,12 @@ struct CombinedTopKKernel {
setup_kernel_smem_once<kernel_stage_2, kStage2SMEM>();
LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) //
.enable_pdl(true)(kernel_stage_2, params);
} else {
// SM120 fallback: use only stage-2 (Small/Medium) path which fits in ~84KB
constexpr auto kernel = topk_short_transform;
setup_kernel_smem_once<kernel, kStage2SMEM>();
LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM)
.enable_pdl(true)(kernel, params);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def hash_topk(
(num_tokens, topk_fused), dtype=torch.float32, device=router_logits.device
)
module = _jit_hash_topk_module()
# tvm_ffi hash_topk kernel expects input_ids as int64
if input_ids.dtype != torch.int64:
input_ids = input_ids.to(torch.int64)
module.hash_topk(
router_logits,
input_ids,
Expand Down
8 changes: 7 additions & 1 deletion python/sglang/jit_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,13 @@ def wrapper(*args, **kwargs):

@cache_once
def is_arch_support_pdl() -> bool:
    """Return whether the current CUDA device supports PDL.

    PDL (Programmatic Dependent Launch) is available on all architectures
    with compute capability >= 9.0, including Hopper (SM90), Blackwell
    datacenter (SM100), and Blackwell consumer (SM120).

    Returns:
        True if the major compute capability of the current CUDA device is
        at least 9, False otherwise.
    """
    # torch is imported inside the function rather than at module level.
    import torch

    device = torch.cuda.current_device()
    # get_device_capability returns (major, minor); only the major version
    # decides PDL support.
    major = torch.cuda.get_device_capability(device)[0]
    return major >= 9
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,10 @@ class Envs:
SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False)

# SM120-optimized Triton kernels (auto-enabled on SM120, set to 0 to disable)
SGLANG_SM120_TRITON_MOE = EnvBool(True)
SGLANG_SM120_TRITON_FLASHMLA = EnvBool(True)

# Symmetric Memory
SGLANG_SYMM_MEM_PREALLOC_GB_SIZE = EnvInt(-1)

Expand Down
146 changes: 85 additions & 61 deletions python/sglang/srt/layers/attention/compressed/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,44 +45,64 @@ def fp8_paged_mqa_logits_torch(
max_seq_len: int,
clean_logits: bool = True,
) -> torch.Tensor:
"""CUDA-graph compatible FP8 paged MQA logits.

Retains the original per-batch loop structure for correctness, but replaces
.item() calls with GPU-only ops. For decode (bs=1), the loop runs once.
"""
_ = deep_gemm_metadata
batch_size, _, num_heads, head_dim = q_fp8.shape
block_size = kvcache_fp8.shape[1]

assert head_dim == 128, "TODO"
assert block_size == 64, "TODO"
assert q_fp8.shape == (batch_size, 1, num_heads, head_dim)
assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4)
assert weight.shape == (batch_size, num_heads)
assert seq_lens.shape == (batch_size,)
assert page_table.shape[0] == batch_size
assert clean_logits == False

logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32)
block_size = kvcache_fp8.shape[1] # 64

if seq_lens.dim() == 2 and seq_lens.shape[1] == 1:
seq_lens = seq_lens.squeeze(-1)

logits = page_table.new_zeros((batch_size, max_seq_len), dtype=torch.float32)
SCALE_OFFSET = block_size * head_dim

# Use max_pages from page_table columns
max_pages_in_table = page_table.shape[1]
# Pre-compute arange for valid token masking (avoid re-creation per batch)
token_arange = torch.arange(max_seq_len, device=seq_lens.device)

kvcache_flat = kvcache_fp8.reshape(-1, block_size * (head_dim + 4))
num_total_pages = kvcache_flat.shape[0]

for i in range(batch_size):
q = q_fp8[i, 0]
q = q.to(torch.float32)
q_scale = weight[i]
seq_len = int(seq_lens[i].item())
assert seq_len <= max_seq_len
num_pages = (seq_len + block_size - 1) // block_size
padded_seq_len = num_pages * block_size
pages = page_table[i, :num_pages]
kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4))
kvcache = kvcache_fp8[pages]
SCALE_OFFSET = block_size * head_dim
kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE)
kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32)
kvcache_value = kvcache_value.to(torch.float32)
kvcache_scale = kvcache_scale.contiguous()
kvcache_value = kvcache_value.view(padded_seq_len, head_dim)
kvcache_scale = kvcache_scale.view(padded_seq_len)
score = F.linear(kvcache_value, q)
score = F.relu(score)
score *= q_scale[None, :]
score = score.sum(dim=1)
score *= kvcache_scale
logits[i, :seq_len] = score[:seq_len]
q = q_fp8[i, 0].to(torch.float32) # (num_heads, head_dim)
q_scale = weight[i] # (num_heads,)

# Gather pages for this batch item (GPU-only, no .item())
pages = page_table[i].clamp(0, num_total_pages - 1) # (max_pages,)
kvcache = kvcache_flat[pages] # (max_pages, block_size * (head_dim + 4))

# Split value and scale
kvcache_value = kvcache[..., :SCALE_OFFSET].reshape(-1, head_dim)
kvcache_value = kvcache_value.view(FP8_DTYPE).to(torch.float32)
kvcache_scale = (
kvcache[..., SCALE_OFFSET:].contiguous().view(torch.float32).reshape(-1)
)

padded_len = kvcache_value.shape[0]

# Score: F.linear(K, Q) = K @ Q^T -> (padded_len, num_heads)
score = F.linear(kvcache_value, q) # (padded_len, num_heads)
score = torch.relu(score)
score = score * q_scale.unsqueeze(0) # broadcast per-head weight
score = score.sum(dim=1) # (padded_len,)
score = score * kvcache_scale # apply KV scale

# Mask invalid tokens using GPU comparison (no .item())
# seq_len is on GPU, arange is on GPU — all GPU ops
valid_len = seq_lens[i] # GPU scalar tensor, no sync
valid_mask = token_arange[:padded_len] < valid_len
# Truncate to max_seq_len if needed
store_len = min(padded_len, max_seq_len)
score_valid = score[:store_len]
mask_valid = valid_mask[:store_len]
logits[i, :store_len] = torch.where(
mask_valid, score_valid, torch.zeros_like(score_valid)
)

return logits

Expand Down Expand Up @@ -136,27 +156,27 @@ def topk_transform_512_pytorch_vectorized(
pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k
valid_topk = valid_topk & ~pad_mask

# CUDA graph compatible: compute sequential path unconditionally,
# select with torch.where (no .any() GPU->CPU sync)
needs_sequential = seq_lens <= TOPK
if needs_sequential.any():
sequential_indices = (
torch.arange(TOPK, device=device, dtype=torch.int32)
.unsqueeze(0)
.expand(batch_size, -1)
)
sequential_valid = sequential_indices < seq_lens.unsqueeze(1)

raw_indices = torch.where(
needs_sequential.unsqueeze(1).expand(-1, TOPK),
torch.where(
sequential_valid,
sequential_indices,
torch.tensor(-1, device=device, dtype=torch.int32),
),
raw_indices,
)
valid_topk = torch.where(
needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk
)
sequential_indices = (
torch.arange(TOPK, device=device, dtype=torch.int32)
.unsqueeze(0)
.expand(batch_size, -1)
)
sequential_valid = sequential_indices < seq_lens.unsqueeze(1)

needs_seq_expand = needs_sequential.unsqueeze(1).expand(-1, TOPK)
raw_indices = torch.where(
needs_seq_expand,
torch.where(
sequential_valid,
sequential_indices,
torch.tensor(-1, device=device, dtype=torch.int32),
),
raw_indices,
)
valid_topk = torch.where(needs_seq_expand, sequential_valid, valid_topk)

page_idx = raw_indices >> page_bits
offset_in_page = raw_indices & page_mask
Expand Down Expand Up @@ -379,12 +399,16 @@ def forward_c4_indexer(
elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
fn = fp8_paged_mqa_logits_torch
else:
if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1:
from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import (
fp8_paged_mqa_logits_chunked as fn,
)
else:
from deep_gemm import fp8_paged_mqa_logits as fn
try:
if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1:
from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import (
fp8_paged_mqa_logits_chunked as fn,
)
else:
from deep_gemm import fp8_paged_mqa_logits as fn
except (ImportError, RuntimeError, FileNotFoundError):
# DeepGEMM not available or SM120 unsupported, use PyTorch fallback
fn = fp8_paged_mqa_logits_torch

_c4sl = indexer_metadata.c4_seq_lens
if _c4sl.dim() == 1:
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/attention/compressed/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import is_hip

if TYPE_CHECKING:
Expand Down Expand Up @@ -122,7 +123,7 @@ class PagedIndexerMetadata(IndexerMetadata):
topk_metadata: torch.Tensor = field(init=False, repr=False)

def __post_init__(self):
if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or not ENABLE_JIT_DEEPGEMM:
self.deep_gemm_metadata = None
else:
import deep_gemm
Expand Down
Loading
Loading