@@ -512,7 +512,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
         metadata = self.graph_metadata[bs]
         max_len = seq_lens_cpu[:bs].max().item()
1 change: 0 additions & 1 deletion python/sglang/srt/layers/attention/aiter_backend.py
@@ -1661,7 +1661,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
 
         num_kv_splits = None
2 changes: 0 additions & 2 deletions python/sglang/srt/layers/attention/base_attn_backend.py
@@ -50,8 +50,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: Optional[torch.Tensor] = None,
-        actual_forward_mode: Optional[ForwardMode] = None,
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
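Reviewer note: with `**kwargs`, `out_cache_loc`, and `actual_forward_mode` dropped from the base class, every backend override of init_forward_metadata_replay_cuda_graph has to match the base signature exactly, which is why the same one-line removal repeats across the backends below. A minimal sketch of a conforming override; the class name is hypothetical and the parameters before forward_mode are assumed from the surrounding sglang code, not shown in these hunks:

# Sketch only: a backend override matching the reverted base signature.
# ExampleAttnBackend is hypothetical; parameters before forward_mode are
# assumed from context (they sit outside the visible hunks), and
# AttentionBackend / ForwardMode / SpecInput come from the module context.
from typing import Optional

import torch


class ExampleAttnBackend(AttentionBackend):
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        # Refresh only the buffers the captured graph reads during replay.
        self.cached_seq_lens[:bs].copy_(seq_lens[:bs])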
154 changes: 76 additions & 78 deletions python/sglang/srt/layers/attention/compressed/indexer.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -37,8 +37,6 @@
 FP8_DTYPE = torch.float8_e4m3fn
 FP8_MAX = torch.finfo(FP8_DTYPE).max
 
-_arange_cache: Dict[str, torch.Tensor] = {}
-
 
 def fp8_paged_mqa_logits_torch(
     q_fp8: torch.Tensor,
@@ -50,8 +48,6 @@ def fp8_paged_mqa_logits_torch(
     max_seq_len: int,
     clean_logits: bool = True,
 ) -> torch.Tensor:
-    """Vectorized implementation that avoids .item() and Python loops,
-    making it compatible with CUDA graph capture."""
     _ = deep_gemm_metadata
     batch_size, _, num_heads, head_dim = q_fp8.shape
     block_size = kvcache_fp8.shape[1]
@@ -65,48 +61,33 @@ def fp8_paged_mqa_logits_torch(
     assert page_table.shape[0] == batch_size
     assert clean_logits == False
 
-    max_num_pages = page_table.shape[1]
-    SCALE_OFFSET = block_size * head_dim
-    total_dim = block_size * (head_dim + 4)
-
-    kvcache_flat = kvcache_fp8.view(-1, total_dim)
-
-    pages_clamped = page_table.clamp(min=0)
-    kvcache_gathered = kvcache_flat[pages_clamped]  # (B, max_num_pages, total_dim)
-
-    kv_values_raw = kvcache_gathered[..., :SCALE_OFFSET].contiguous()
-    kv_values_fp8 = kv_values_raw.view(dtype=FP8_DTYPE)
-    kv_values = kv_values_fp8.to(torch.float32)
-    kv_values = kv_values.reshape(batch_size, max_num_pages * block_size, head_dim)
-
-    kv_scales_raw = kvcache_gathered[..., SCALE_OFFSET:].contiguous()
-    kv_scales = kv_scales_raw.view(dtype=torch.float32)
-    kv_scales = kv_scales.reshape(batch_size, max_num_pages * block_size)
-
-    q_float = q_fp8[:, 0].to(torch.float32)  # (B, num_heads, head_dim)
-    # (B, padded_seq_len, head_dim) @ (B, head_dim, num_heads) -> (B, padded_seq_len, num_heads)
-    scores = torch.bmm(kv_values, q_float.transpose(1, 2))
-    scores = F.relu(scores)
-    scores = scores * weight.unsqueeze(1)  # (B, padded_seq_len, num_heads)
-    scores = scores.sum(dim=2)  # (B, padded_seq_len)
-    scores = scores * kv_scales  # (B, padded_seq_len)
-
-    padded_seq_len = max_num_pages * block_size
-    cache = _arange_cache
-    arange_key = f"arange_{padded_seq_len}_{scores.device}"
-    if arange_key not in cache:
-        cache[arange_key] = torch.arange(padded_seq_len, device=scores.device)
-    positions = cache[arange_key].unsqueeze(0)
-    valid_mask = positions < seq_lens.unsqueeze(1)
-    scores = scores.masked_fill(~valid_mask, 0.0)
-
-    # Pad to max_seq_len if needed (padded_seq_len may be < max_seq_len)
-    if padded_seq_len < max_seq_len:
-        scores = F.pad(scores, (0, max_seq_len - padded_seq_len), value=0.0)
-    else:
-        scores = scores[:, :max_seq_len]
-
-    return scores
+    logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32)
+    for i in range(batch_size):
+        q = q_fp8[i, 0]  # (num_heads, head_dim)
+        q = q.to(torch.float32)
+        q_scale = weight[i]  # (num_heads)
+        seq_len = int(seq_lens[i].item())
+        assert seq_len <= max_seq_len
+        num_pages = (seq_len + block_size - 1) // block_size
+        padded_seq_len = num_pages * block_size
+        pages = page_table[i, :num_pages]  # (num_pages,)
+        kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4))
+        kvcache = kvcache_fp8[pages]  # (num_pages, block_size * (head_dim + 4))
+        SCALE_OFFSET = block_size * head_dim
+        kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE)
+        kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32)
+        kvcache_value = kvcache_value.to(torch.float32)
+        kvcache_scale = kvcache_scale.contiguous()
+        kvcache_value = kvcache_value.view(padded_seq_len, head_dim)
+        kvcache_scale = kvcache_scale.view(padded_seq_len)
+        score = F.linear(kvcache_value, q)
+        score = F.relu(score)
+        score *= q_scale[None, :]
+        score = score.sum(dim=1)  # (padded_seq_len,)
+        score *= kvcache_scale
+        logits[i, :seq_len] = score[:seq_len]
 
     return logits
 
 
 # def fp8_paged_mqa_logits_torch(
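Reviewer note: the restored loop implementation above depends on the packed page layout: each page row holds block_size * (head_dim + 4) bytes, the first block_size * head_dim being fp8 values and the remaining 4 bytes per token one fp32 scale, which is exactly what the SCALE_OFFSET split and the two view(dtype=...) reinterpretations decode. A toy round-trip of that layout, using uint8 as the byte container and illustrative sizes (not the production shapes):

# Toy round-trip of the packed fp8-values + fp32-scales page layout.
# Sizes are illustrative; uint8 stands in as the raw byte container.
import torch

FP8_DTYPE = torch.float8_e4m3fn
block_size, head_dim = 4, 8
SCALE_OFFSET = block_size * head_dim  # byte offset where the scales start

# Pack: fp8 values first, then one fp32 scale per token as raw bytes.
values = torch.randn(block_size, head_dim).to(FP8_DTYPE)
scales = torch.rand(block_size) + 0.5
page = torch.empty(block_size * (head_dim + 4), dtype=torch.uint8)
page[:SCALE_OFFSET] = values.view(torch.uint8).flatten()
page[SCALE_OFFSET:] = scales.view(torch.uint8)

# Unpack the same way fp8_paged_mqa_logits_torch does per gathered page.
vals = page[:SCALE_OFFSET].view(FP8_DTYPE).to(torch.float32)
vals = vals.view(block_size, head_dim)
scale = page[SCALE_OFFSET:].view(torch.float32)  # (block_size,)
dequantized = vals * scale[:, None]  # per-token dequantized rows

The loop body performs the same split per page; only the gather over page_table and the final shapes differ.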
@@ -224,6 +205,7 @@ def fp8_paged_mqa_logits_torch(
 # return logits
 
 
+# Vectorized version (faster but uses more memory) - for AMD/HIP
 def topk_transform_512_pytorch_vectorized(
     scores: torch.Tensor,
     seq_lens: torch.Tensor,
@@ -234,8 +216,7 @@ def topk_transform_512_pytorch_vectorized(
 ) -> None:
     """
     Vectorized PyTorch fallback for topk_transform_512.
-    All helper tensors (arange, zeros) are cached to avoid device-tensor
-    creation during HIP/CUDA graph capture.
+    Faster than the loop version but may use more memory.
     """
 
     TOPK = 512
@@ -246,54 +227,67 @@ def topk_transform_512_pytorch_vectorized(
     page_bits = (page_size - 1).bit_length() if page_size > 1 else 0
     page_mask = page_size - 1
 
-    # ---- cached helper tensors (allocated once, reused on replay) ----
-    cache = _arange_cache
-    key_seq = f"arange_{max_seq_len}_{device}"
-    key_topk = f"arange_{TOPK}_{device}"
-    key_bs = f"arange_{batch_size}_{device}"
-    if key_seq not in cache:
-        cache[key_seq] = torch.arange(max_seq_len, device=device)
-    if key_topk not in cache:
-        cache[key_topk] = torch.arange(TOPK, device=device, dtype=torch.int32)
-    if key_bs not in cache:
-        cache[key_bs] = torch.arange(batch_size, device=device)
-
-    positions = cache[key_seq].unsqueeze(0).expand(batch_size, -1)
+    # Create mask for valid positions based on seq_lens
+    positions = (
+        torch.arange(max_seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+    )
     valid_mask = positions < seq_lens.unsqueeze(1)
 
+    # Mask out invalid positions with -inf
     masked_scores = scores.clone()
-    masked_scores.masked_fill_(~valid_mask, float("-inf"))
+    masked_scores[~valid_mask] = float("-inf")
 
+    # Get top-k indices
     actual_k = min(TOPK, max_seq_len)
     _, raw_indices = torch.topk(
         masked_scores, k=actual_k, dim=1, largest=True, sorted=False
     )
     raw_indices = raw_indices.to(torch.int32)
 
+    # Pad raw_indices to TOPK size if needed
     if actual_k < TOPK:
-        raw_indices = F.pad(raw_indices, (0, TOPK - actual_k), value=0)
+        padding = torch.zeros(
+            (batch_size, TOPK - actual_k), dtype=torch.int32, device=device
+        )
+        raw_indices = torch.cat([raw_indices, padding], dim=1)
 
-    batch_indices = cache[key_bs].unsqueeze(1).expand(-1, TOPK)
+    # Check which indices are valid
+    batch_indices = (
+        torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, TOPK)
+    )
     gathered_scores = scores[
         batch_indices.flatten(), raw_indices.clamp(min=0).flatten()
     ].view(batch_size, TOPK)
 
     valid_topk = gathered_scores != float("-inf")
     if actual_k < TOPK:
-        pad_mask = cache[key_topk].unsqueeze(0) >= actual_k
+        pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k
         valid_topk = valid_topk & ~pad_mask
 
+    # For short sequences, use sequential indices
     needs_sequential = seq_lens <= TOPK
-    sequential_indices = cache[key_topk].unsqueeze(0).expand(batch_size, -1)
-    sequential_valid = sequential_indices < seq_lens.unsqueeze(1)
-
-    seq_indices_or_neg1 = sequential_indices.clone()
-    seq_indices_or_neg1.masked_fill_(~sequential_valid, -1)
-
-    needs_seq_mask = needs_sequential.unsqueeze(1).expand(-1, TOPK)
-    raw_indices = torch.where(needs_seq_mask, seq_indices_or_neg1, raw_indices)
-    valid_topk = torch.where(needs_seq_mask, sequential_valid, valid_topk)
+    if needs_sequential.any():
+        sequential_indices = (
+            torch.arange(TOPK, device=device, dtype=torch.int32)
+            .unsqueeze(0)
+            .expand(batch_size, -1)
+        )
+        sequential_valid = sequential_indices < seq_lens.unsqueeze(1)
+
+        raw_indices = torch.where(
+            needs_sequential.unsqueeze(1).expand(-1, TOPK),
+            torch.where(
+                sequential_valid,
+                sequential_indices,
+                torch.tensor(-1, device=device, dtype=torch.int32),
+            ),
+            raw_indices,
+        )
+        valid_topk = torch.where(
+            needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk
+        )
 
+    # Transform to page indices
     page_idx = raw_indices >> page_bits
     offset_in_page = raw_indices & page_mask
 
@@ -302,13 +296,17 @@ def topk_transform_512_pytorch_vectorized(
 
     page_indices = (physical_pages << page_bits) | offset_in_page
     page_indices = page_indices.to(torch.int32)
-    page_indices.masked_fill_(~valid_topk, -1)
 
+    page_indices = torch.where(
+        valid_topk, page_indices, torch.tensor(-1, device=device, dtype=torch.int32)
+    )
+
     out_page_indices.copy_(page_indices)
 
     if out_raw_indices is not None:
-        raw_indices = raw_indices.clone()
-        raw_indices.masked_fill_(~valid_topk, -1)
+        raw_indices = torch.where(
+            valid_topk, raw_indices, torch.tensor(-1, device=device, dtype=torch.int32)
+        )
         out_raw_indices.copy_(raw_indices)
 
 
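Reviewer note: the closing transform in topk_transform_512_pytorch_vectorized assumes page_size is a power of two, so a token index splits into (logical page, in-page offset) with a shift and a mask; the logical page is then swapped for the physical page from the page table (that gather sits in the collapsed lines) and the pair is re-packed. A standalone numeric sketch with invented page-table values:

# Standalone sketch of the shift/mask page-index transform.
# page_table values are invented; page_size must be a power of two.
import torch

page_size = 4
page_bits = (page_size - 1).bit_length()  # 2
page_mask = page_size - 1                 # 0b11

raw_indices = torch.tensor([0, 5, 11], dtype=torch.int32)  # token positions
page_table = torch.tensor([7, 3, 9], dtype=torch.int32)    # logical -> physical

page_idx = raw_indices >> page_bits           # logical pages: [0, 1, 2]
offset_in_page = raw_indices & page_mask      # offsets:       [0, 1, 3]
physical_pages = page_table[page_idx.long()]  # [7, 3, 9]

page_indices = (physical_pages << page_bits) | offset_in_page
print(page_indices)  # tensor([28, 13, 39]) -> slots in the physical KV pool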
23 changes: 11 additions & 12 deletions python/sglang/srt/layers/attention/compressed/metadata.py
@@ -169,19 +169,18 @@ def max_seq_len(self) -> int:
 
     def copy_(self, other: "PagedIndexerMetadata"):
         if is_hip():
-            copy_metadata(
-                src=other,
-                dst=self,
-                check_eq_fields=["page_size", "deep_gemm_metadata"],
-                copy_fields=["page_table", "c4_seq_lens"],
-            )
+            # HIP/ROCm: don't copy deep_gemm_metadata (it's None)
+            copy_fields = ["page_table", "c4_seq_lens"]
         else:
-            copy_metadata(
-                src=other,
-                dst=self,
-                check_eq_fields=["page_size"],
-                copy_fields=["page_table", "c4_seq_lens", "deep_gemm_metadata"],
-            )
+            # CUDA: original behavior
+            copy_fields = ["page_table", "c4_seq_lens", "deep_gemm_metadata"]
+
+        copy_metadata(
+            src=other,
+            dst=self,
+            check_eq_fields=["page_size"],
+            copy_fields=copy_fields,
+        )
 
 
 @dataclass
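Reviewer note: consolidating to a single copy_metadata call with a precomputed copy_fields list keeps the HIP and CUDA branches from duplicating the call site. The helper's implementation is not part of this diff; a hypothetical sketch of the contract it appears to satisfy:

# Hypothetical sketch of the copy_metadata contract; the real helper in
# sglang is not shown in this diff.
from typing import List


def copy_metadata(src, dst, check_eq_fields: List[str], copy_fields: List[str]):
    for name in check_eq_fields:
        # Fields that must already agree between src and dst.
        assert getattr(src, name) == getattr(dst, name), name
    for name in copy_fields:
        # In-place copy so CUDA-graph-captured buffers keep their storage.
        getattr(dst, name).copy_(getattr(src, name))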
@@ -192,7 +192,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
 
         if forward_mode.is_decode_or_idle():
@@ -13,10 +13,6 @@ def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs):
         # backend == "torch"
         import os
 
-        from sglang.srt.layers.attention.nsa.tilelang_kernel import (
-            dpsk_v4_bf16_sparse_attention_fwd,
-        )
-
         backend = os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND", "kernel")
     else:
         import flash_mla
@@ -36,9 +32,6 @@ def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs):
     if backend == "torch":
         return flash_mla_with_kvcache_torch(**kwargs)
 
-    if backend == "tilelang":
-        return dpsk_v4_bf16_sparse_attention_fwd(**kwargs)
-
     if backend == "kernel":
         return flash_mla.flash_mla_with_kvcache(**kwargs)
 
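Reviewer note: one subtlety this hunk preserves: passing backend="torch" only enables the override point, because the effective backend is re-read from SGLANG_HACK_FLASHMLA_BACKEND and still defaults to "kernel". Forcing the pure-torch fallback therefore looks like this (sketch; the entrypoint's kwargs are elided because the diff does not show them):

# Sketch: forcing the pure-torch fallback path. The entrypoint's kwargs
# are elided here; this diff does not show them.
import os

os.environ["SGLANG_HACK_FLASHMLA_BACKEND"] = "torch"
# flash_mla_with_kvcache_entrypoint("torch", **kwargs)  # now takes the torch path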
6 changes: 2 additions & 4 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -149,10 +149,8 @@ def init_forward_metadata_capture_cuda_graph(
             max_seq_len=self.max_seq_len_for_capture,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
-            # Dummy value (must be int64 to match real out_cache_loc dtype)
-            out_cache_loc=torch.zeros(
-                seq_lens.shape, dtype=torch.int64, device=seq_lens.device
-            ),
+            # Dummy value
+            out_cache_loc=torch.zeros_like(seq_lens),
         )
 
         self.decode_cuda_graph_metadata_of_bs[bs] = metadata
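Reviewer note: the reverted dummy relies on torch.zeros_like inheriting both dtype and device from seq_lens, so it is equivalent to the removed explicit construction exactly when seq_lens is already int64 (which the removed comment asserted). A minimal check:

# zeros_like inherits dtype and device, so the revert matches the explicit
# form whenever seq_lens is already int64.
import torch

seq_lens = torch.tensor([3, 7, 1], dtype=torch.int64)
dummy = torch.zeros_like(seq_lens)
explicit = torch.zeros(seq_lens.shape, dtype=torch.int64, device=seq_lens.device)
assert dummy.dtype == explicit.dtype == torch.int64
assert torch.equal(dummy, explicit)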
@@ -591,7 +591,6 @@ def init_forward_metadata_replay_cuda_graph(
         spec_info: Optional[None],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: torch.Tensor = None,
-        **kwargs,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         assert forward_mode.is_decode()
@@ -1882,7 +1882,6 @@ def init_forward_metadata_replay_cuda_graph(
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
-        **kwargs,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
1 change: 0 additions & 1 deletion python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -717,7 +717,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -460,7 +460,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
1 change: 0 additions & 1 deletion python/sglang/srt/layers/attention/flashmla_backend.py
@@ -292,7 +292,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
2 changes: 0 additions & 2 deletions python/sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -95,7 +95,6 @@ def init_forward_metadata_replay_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
-        **kwargs,
     ):
         backend = self._select_backend(forward_mode)
         backend.init_forward_metadata_replay_cuda_graph(
@@ -107,7 +106,6 @@ def init_forward_metadata_replay_cuda_graph(
             forward_mode,
             spec_info,
             seq_lens_cpu,
-            **kwargs,
         )
 
     def get_cuda_graph_seq_len_fill_value(self):