Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/sglang/jit_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ def get_jit_cuda_arch() -> ArchInfo:
def is_arch_support_pdl() -> bool:
if is_hip_runtime():
return False
return get_jit_cuda_arch().major >= 9
arch = get_jit_cuda_arch()
# PDL (griddepcontrol) instruction is supported on SM90+ (Hopper, Blackwell).
# SM120 (desktop Blackwell) supports PDL despite lacking TMEM/tcgen05 —
# PDL uses griddepcontrol for kernel scheduling, independent of TMEM.
return arch.major >= 9


def _find_package_root(package: str) -> Optional[pathlib.Path]:
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/arg_groups/deepseek_v4_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def apply_deepseek_v4_defaults(server_args: "ServerArgs", model_arch: str) -> No
f"Setting swa_full_tokens_ratio to {server_args.swa_full_tokens_ratio} for {model_arch}."
)

# SM120: auto-select marlin MoE backend (dispatches to SM120 Triton kernel)
from sglang.srt.utils.common import is_sm120_supported

if is_sm120_supported() and server_args.moe_runner_backend == "auto":
server_args.moe_runner_backend = "marlin"
logger.info("Use marlin as MoE runner backend on SM120 for DeepSeekV4")

if server_args.disaggregation_mode != "null" and server_args.pp_size > 1:
# get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized
# flat KV ptrs by PP layer range.
Expand Down
14 changes: 10 additions & 4 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
from sglang.srt.layers.attention.dsv4.quant_k_cache import (
quant_to_nope_fp8_rope_bf16_pack_triton,
)
from sglang.srt.layers.attention.flash_mla_sm120_fallback import (
_is_sm120,
flash_mla_with_kvcache_entrypoint,
)
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
Expand Down Expand Up @@ -81,6 +85,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T:


def _create_flashmla_metadata():
if _is_sm120:
return None
import flash_mla

return flash_mla.get_mla_metadata()[0]
Expand Down Expand Up @@ -1031,9 +1037,7 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"

import flash_mla

o = flash_mla.flash_mla_with_kvcache(
input_dict = dict(
q=q,
k_cache=swa_k_cache,
head_dim_v=self.head_dim_v,
Expand All @@ -1048,7 +1052,9 @@ def forward(
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_lengths,
)[0]
)

o = flash_mla_with_kvcache_entrypoint(**input_dict, backend="kernel")[0]

o = o.squeeze(1)
return o
Expand Down
100 changes: 71 additions & 29 deletions python/sglang/srt/layers/attention/dsv4/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config
from sglang.srt.environ import envs
from sglang.srt.layers.attention.dsv4.compressor import Compressor
from sglang.srt.layers.attention.dsv4.metadata import PagedIndexerMetadata
from sglang.srt.layers.attention.dsv4.metadata import (
PagedIndexerMetadata,
_is_sm120,
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer
from sglang.srt.utils import add_prefix, is_hip
Expand Down Expand Up @@ -49,44 +52,83 @@ def fp8_paged_mqa_logits_torch(
max_seq_len: int,
clean_logits: bool = True,
) -> torch.Tensor:
"""CUDA-graph-compatible FP8 paged MQA logits (vectorized, no .item()).

Vectorized across batches using batched gather + bmm instead of
per-batch Python loop with .item() calls.
"""
_ = deep_gemm_metadata
batch_size, _, num_heads, head_dim = q_fp8.shape
block_size = kvcache_fp8.shape[1]
device = q_fp8.device

assert head_dim == 128, "torch reference impl hardcodes DSV4 indexer head_dim=128"
assert block_size == 64, "torch reference impl hardcodes block_size=64 cache layout"
assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128"
assert (
block_size == 64
), "Vectorized torch impl hardcodes block_size=64 cache layout"
assert q_fp8.shape == (batch_size, 1, num_heads, head_dim)
assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4)
assert weight.shape == (batch_size, num_heads)
if seq_lens.dim() > 1:
seq_lens = seq_lens.squeeze(-1)
assert seq_lens.shape == (batch_size,)
assert page_table.shape[0] == batch_size
assert clean_logits == False

logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32)
for i in range(batch_size):
q = q_fp8[i, 0]
q = q.to(torch.float32)
q_scale = weight[i]
seq_len = int(seq_lens[i].item())
assert seq_len <= max_seq_len
num_pages = (seq_len + block_size - 1) // block_size
padded_seq_len = num_pages * block_size
pages = page_table[i, :num_pages]
kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4))
kvcache = kvcache_fp8[pages]
SCALE_OFFSET = block_size * head_dim
kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE)
kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32)
kvcache_value = kvcache_value.to(torch.float32)
kvcache_scale = kvcache_scale.contiguous()
kvcache_value = kvcache_value.view(padded_seq_len, head_dim)
kvcache_scale = kvcache_scale.view(padded_seq_len)
score = F.linear(kvcache_value, q)
score = F.relu(score)
score *= q_scale[None, :]
score = score.sum(dim=1)
score *= kvcache_scale
logits[i, :seq_len] = score[:seq_len]
# ── Vectorized: no .item(), no per-batch loop ──
max_pages = (max_seq_len + block_size - 1) // block_size
max_padded_seq = max_pages * block_size

# Flatten KV cache for indexing: [total_pages, block_size * (head_dim + 4)]
kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4))
SCALE_OFFSET = block_size * head_dim

# Gather pages for all batches: [batch, max_pages]
page_ids = page_table[:, :max_pages]
# Gather KV data: [batch, max_pages, block_size * (head_dim + 4)]
kvcache_gathered = kvcache_flat[page_ids]

# Split value and scale
kv_value_raw = kvcache_gathered[
..., :SCALE_OFFSET
] # [batch, max_pages, block_size * head_dim]
kv_scale_raw = kvcache_gathered[
..., SCALE_OFFSET:
] # [batch, max_pages, block_size * 4]

# Dequant value: view as FP8, convert to float32
kv_value = kv_value_raw.contiguous().view(dtype=FP8_DTYPE).to(torch.float32)
kv_value = kv_value.view(batch_size, max_padded_seq, head_dim)

# Dequant scale
kv_scale = kv_scale_raw.contiguous().view(dtype=torch.float32)
kv_scale = kv_scale.view(batch_size, max_padded_seq)

# Q: [batch, num_heads, head_dim]
q = q_fp8[:, 0].to(torch.float32)

# Batched matmul: [batch, max_padded_seq, head_dim] @ [batch, head_dim, num_heads]
score = torch.bmm(kv_value, q.transpose(1, 2)) # [batch, max_padded_seq, num_heads]

# ReLU + scale by weight + sum across heads
score = F.relu(score)
score = score * weight.unsqueeze(1) # [batch, max_padded_seq, num_heads]
score = score.sum(dim=2) # [batch, max_padded_seq]

# Apply KV scale
score = score * kv_scale # [batch, max_padded_seq]

# Create validity mask and write output — graph-safe (no torch.tensor() calls)
out_width = min(max_padded_seq, max_seq_len)
logits = score.new_full((batch_size, max_seq_len), float("-inf"))
logits[:, :out_width] = score[:, :out_width]

# Mask invalid positions to -inf
positions = torch.arange(max_seq_len, device=device)
invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(
1
) # [batch, max_seq_len]
logits.masked_fill_(invalid_mask, float("-inf"))

return logits

Expand Down Expand Up @@ -377,7 +419,7 @@ def forward_c4_indexer(
from sglang.srt.layers.attention.dsv4.tilelang_kernel import (
tilelang_fp8_paged_mqa_logits as fn,
)
elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or _is_sm120:
fn = fp8_paged_mqa_logits_torch
else:
from deep_gemm import fp8_paged_mqa_logits as fn
Expand Down
8 changes: 7 additions & 1 deletion python/sglang/srt/layers/attention/dsv4/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from sglang.srt.environ import envs
from sglang.srt.utils import is_hip
from sglang.srt.utils.common import is_sm120_supported

_is_cuda = torch.cuda.is_available() and not is_hip()
_is_sm120 = _is_cuda and is_sm120_supported()

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -103,7 +107,9 @@ class PagedIndexerMetadata:
topk_metadata: torch.Tensor = field(init=False, repr=False)

def __post_init__(self):
if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or _is_sm120:
# SM120: DeepGEMM get_paged_mqa_logits_metadata asserts
# "Unsupported architecture" on SM120. Use None (torch fallback path).
self.deep_gemm_metadata = None
else:
import deep_gemm
Expand Down
Loading
Loading