diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index abe712978a92..7bc7cb2b1889 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -512,7 +512,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): metadata = self.graph_metadata[bs] max_len = seq_lens_cpu[:bs].max().item() diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index eb95dec0ee27..e52ae698428b 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -1661,7 +1661,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): num_kv_splits = None diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 309cf78095a3..ccfca17e6043 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -50,8 +50,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - out_cache_loc: Optional[torch.Tensor] = None, - actual_forward_mode: Optional[ForwardMode] = None, ): """Init the metadata for a forward pass for replaying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/compressed/indexer.py b/python/sglang/srt/layers/attention/compressed/indexer.py index 2300f8dd605b..63c419c7df87 100644 --- a/python/sglang/srt/layers/attention/compressed/indexer.py +++ b/python/sglang/srt/layers/attention/compressed/indexer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple import torch import torch.nn.functional as F @@ -37,8 +37,6 @@ FP8_DTYPE = torch.float8_e4m3fn FP8_MAX = torch.finfo(FP8_DTYPE).max -_arange_cache: Dict[str, torch.Tensor] = {} - def fp8_paged_mqa_logits_torch( q_fp8: torch.Tensor, @@ -50,8 +48,6 @@ def fp8_paged_mqa_logits_torch( max_seq_len: int, clean_logits: bool = True, ) -> torch.Tensor: - """Vectorized implementation that avoids .item() and Python loops, - making it compatible with CUDA graph capture.""" _ = deep_gemm_metadata batch_size, _, num_heads, head_dim = q_fp8.shape block_size = kvcache_fp8.shape[1] @@ -65,48 +61,33 @@ def fp8_paged_mqa_logits_torch( assert page_table.shape[0] == batch_size assert clean_logits == False - max_num_pages = page_table.shape[1] - SCALE_OFFSET = block_size * head_dim - total_dim = block_size * (head_dim + 4) - - kvcache_flat = kvcache_fp8.view(-1, total_dim) - - pages_clamped = page_table.clamp(min=0) - kvcache_gathered = kvcache_flat[pages_clamped] # (B, max_num_pages, total_dim) - - kv_values_raw = kvcache_gathered[..., :SCALE_OFFSET].contiguous() - kv_values_fp8 = kv_values_raw.view(dtype=FP8_DTYPE) - kv_values = kv_values_fp8.to(torch.float32) - kv_values = kv_values.reshape(batch_size, max_num_pages * block_size, head_dim) - - kv_scales_raw = kvcache_gathered[..., SCALE_OFFSET:].contiguous() - kv_scales = 
kv_scales_raw.view(dtype=torch.float32) - kv_scales = kv_scales.reshape(batch_size, max_num_pages * block_size) - - q_float = q_fp8[:, 0].to(torch.float32) # (B, num_heads, head_dim) - # (B, padded_seq_len, head_dim) @ (B, head_dim, num_heads) -> (B, padded_seq_len, num_heads) - scores = torch.bmm(kv_values, q_float.transpose(1, 2)) - scores = F.relu(scores) - scores = scores * weight.unsqueeze(1) # (B, padded_seq_len, num_heads) - scores = scores.sum(dim=2) # (B, padded_seq_len) - scores = scores * kv_scales # (B, padded_seq_len) - - padded_seq_len = max_num_pages * block_size - cache = _arange_cache - arange_key = f"arange_{padded_seq_len}_{scores.device}" - if arange_key not in cache: - cache[arange_key] = torch.arange(padded_seq_len, device=scores.device) - positions = cache[arange_key].unsqueeze(0) - valid_mask = positions < seq_lens.unsqueeze(1) - scores = scores.masked_fill(~valid_mask, 0.0) - - # Pad to max_seq_len if needed (padded_seq_len may be < max_seq_len) - if padded_seq_len < max_seq_len: - scores = F.pad(scores, (0, max_seq_len - padded_seq_len), value=0.0) - else: - scores = scores[:, :max_seq_len] - - return scores + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + for i in range(batch_size): + q = q_fp8[i, 0] # (num_heads, head_dim) + q = q.to(torch.float32) + q_scale = weight[i] # (num_heads) + seq_len = int(seq_lens[i].item()) + assert seq_len <= max_seq_len + num_pages = (seq_len + block_size - 1) // block_size + padded_seq_len = num_pages * block_size + pages = page_table[i, :num_pages] # (num_pages,) + kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + kvcache = kvcache_fp8[pages] # (num_pages, block_size * (head_dim + 4)) + SCALE_OFFSET = block_size * head_dim + kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE) + kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32) + kvcache_value = kvcache_value.to(torch.float32) + kvcache_scale = kvcache_scale.contiguous() + kvcache_value = kvcache_value.view(padded_seq_len, head_dim) + kvcache_scale = kvcache_scale.view(padded_seq_len) + score = F.linear(kvcache_value, q) + score = F.relu(score) + score *= q_scale[None, :] + score = score.sum(dim=1) # (padded_seq_len,) + score *= kvcache_scale + logits[i, :seq_len] = score[:seq_len] + + return logits # def fp8_paged_mqa_logits_torch( @@ -224,6 +205,7 @@ def fp8_paged_mqa_logits_torch( # return logits +# Vectorized version (faster but uses more memory) - for AMD/HIP def topk_transform_512_pytorch_vectorized( scores: torch.Tensor, seq_lens: torch.Tensor, @@ -234,8 +216,7 @@ def topk_transform_512_pytorch_vectorized( ) -> None: """ Vectorized PyTorch fallback for topk_transform_512. - All helper tensors (arange, zeros) are cached to avoid device-tensor - creation during HIP/CUDA graph capture. + Faster than the loop version but may use more memory. 
""" TOPK = 512 @@ -246,54 +227,67 @@ def topk_transform_512_pytorch_vectorized( page_bits = (page_size - 1).bit_length() if page_size > 1 else 0 page_mask = page_size - 1 - # ---- cached helper tensors (allocated once, reused on replay) ---- - cache = _arange_cache - key_seq = f"arange_{max_seq_len}_{device}" - key_topk = f"arange_{TOPK}_{device}" - key_bs = f"arange_{batch_size}_{device}" - if key_seq not in cache: - cache[key_seq] = torch.arange(max_seq_len, device=device) - if key_topk not in cache: - cache[key_topk] = torch.arange(TOPK, device=device, dtype=torch.int32) - if key_bs not in cache: - cache[key_bs] = torch.arange(batch_size, device=device) - - positions = cache[key_seq].unsqueeze(0).expand(batch_size, -1) + # Create mask for valid positions based on seq_lens + positions = ( + torch.arange(max_seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + ) valid_mask = positions < seq_lens.unsqueeze(1) + # Mask out invalid positions with -inf masked_scores = scores.clone() - masked_scores.masked_fill_(~valid_mask, float("-inf")) + masked_scores[~valid_mask] = float("-inf") + # Get top-k indices actual_k = min(TOPK, max_seq_len) _, raw_indices = torch.topk( masked_scores, k=actual_k, dim=1, largest=True, sorted=False ) raw_indices = raw_indices.to(torch.int32) + # Pad raw_indices to TOPK size if needed if actual_k < TOPK: - raw_indices = F.pad(raw_indices, (0, TOPK - actual_k), value=0) + padding = torch.zeros( + (batch_size, TOPK - actual_k), dtype=torch.int32, device=device + ) + raw_indices = torch.cat([raw_indices, padding], dim=1) - batch_indices = cache[key_bs].unsqueeze(1).expand(-1, TOPK) + # Check which indices are valid + batch_indices = ( + torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, TOPK) + ) gathered_scores = scores[ batch_indices.flatten(), raw_indices.clamp(min=0).flatten() ].view(batch_size, TOPK) valid_topk = gathered_scores != float("-inf") if actual_k < TOPK: - pad_mask = cache[key_topk].unsqueeze(0) >= actual_k + pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k valid_topk = valid_topk & ~pad_mask + # For short sequences, use sequential indices needs_sequential = seq_lens <= TOPK - sequential_indices = cache[key_topk].unsqueeze(0).expand(batch_size, -1) - sequential_valid = sequential_indices < seq_lens.unsqueeze(1) - - seq_indices_or_neg1 = sequential_indices.clone() - seq_indices_or_neg1.masked_fill_(~sequential_valid, -1) - - needs_seq_mask = needs_sequential.unsqueeze(1).expand(-1, TOPK) - raw_indices = torch.where(needs_seq_mask, seq_indices_or_neg1, raw_indices) - valid_topk = torch.where(needs_seq_mask, sequential_valid, valid_topk) + if needs_sequential.any(): + sequential_indices = ( + torch.arange(TOPK, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(batch_size, -1) + ) + sequential_valid = sequential_indices < seq_lens.unsqueeze(1) + + raw_indices = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), + torch.where( + sequential_valid, + sequential_indices, + torch.tensor(-1, device=device, dtype=torch.int32), + ), + raw_indices, + ) + valid_topk = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk + ) + # Transform to page indices page_idx = raw_indices >> page_bits offset_in_page = raw_indices & page_mask @@ -302,13 +296,17 @@ def topk_transform_512_pytorch_vectorized( page_indices = (physical_pages << page_bits) | offset_in_page page_indices = page_indices.to(torch.int32) - page_indices.masked_fill_(~valid_topk, -1) + + page_indices = 
torch.where( + valid_topk, page_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) out_page_indices.copy_(page_indices) if out_raw_indices is not None: - raw_indices = raw_indices.clone() - raw_indices.masked_fill_(~valid_topk, -1) + raw_indices = torch.where( + valid_topk, raw_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) out_raw_indices.copy_(raw_indices) diff --git a/python/sglang/srt/layers/attention/compressed/metadata.py b/python/sglang/srt/layers/attention/compressed/metadata.py index 0868689bf42d..e307e3b4316c 100644 --- a/python/sglang/srt/layers/attention/compressed/metadata.py +++ b/python/sglang/srt/layers/attention/compressed/metadata.py @@ -169,19 +169,18 @@ def max_seq_len(self) -> int: def copy_(self, other: "PagedIndexerMetadata"): if is_hip(): - copy_metadata( - src=other, - dst=self, - check_eq_fields=["page_size", "deep_gemm_metadata"], - copy_fields=["page_table", "c4_seq_lens"], - ) + # HIP/ROCm: don't copy deep_gemm_metadata (it's None) + copy_fields = ["page_table", "c4_seq_lens"] else: - copy_metadata( - src=other, - dst=self, - check_eq_fields=["page_size"], - copy_fields=["page_table", "c4_seq_lens", "deep_gemm_metadata"], - ) + # CUDA: original behavior + copy_fields = ["page_table", "c4_seq_lens", "deep_gemm_metadata"] + + copy_metadata( + src=other, + dst=self, + check_eq_fields=["page_size"], + copy_fields=copy_fields, + ) @dataclass diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index 2b0b1c800578..e81e761bcefd 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -192,7 +192,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py b/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py index 6a37b166a743..75f6c520877f 100644 --- a/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py +++ b/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py @@ -13,10 +13,6 @@ def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): # backend == "torch" import os - from sglang.srt.layers.attention.nsa.tilelang_kernel import ( - dpsk_v4_bf16_sparse_attention_fwd, - ) - backend = os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND", "kernel") else: import flash_mla @@ -36,9 +32,6 @@ def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): if backend == "torch": return flash_mla_with_kvcache_torch(**kwargs) - if backend == "tilelang": - return dpsk_v4_bf16_sparse_attention_fwd(**kwargs) - if backend == "kernel": return flash_mla.flash_mla_with_kvcache(**kwargs) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index c8e99405db5c..3df43d464ac6 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -149,10 +149,8 @@ def init_forward_metadata_capture_cuda_graph( max_seq_len=self.max_seq_len_for_capture, req_pool_indices=req_pool_indices, seq_lens=seq_lens, - # Dummy value (must be int64 to match real out_cache_loc dtype) - out_cache_loc=torch.zeros( - seq_lens.shape, dtype=torch.int64, device=seq_lens.device - ), + # Dummy value + out_cache_loc=torch.zeros_like(seq_lens), ) 
self.decode_cuda_graph_metadata_of_bs[bs] = metadata diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py index fb451d230492..a84015a803f8 100644 --- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py @@ -591,7 +591,6 @@ def init_forward_metadata_replay_cuda_graph( spec_info: Optional[None], seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: torch.Tensor = None, - **kwargs, ): """Initialize forward metadata for replaying CUDA graph.""" assert forward_mode.is_decode() diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 8910a6e4d4c0..193ce338c221 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1882,7 +1882,6 @@ def init_forward_metadata_replay_cuda_graph( spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: Optional[torch.Tensor] = None, - **kwargs, ): """Initialize forward metadata for replaying CUDA graph.""" seq_lens = seq_lens[:bs] diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index fd3d1138aabc..064e8adc76b1 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -717,7 +717,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): if forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 4b9f57639b21..601a80cea52b 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -460,7 +460,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): if forward_mode.is_decode_or_idle(): assert seq_lens_cpu is not None diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index a33086c9c258..0693d072dfbd 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -292,7 +292,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): if forward_mode.is_decode_or_idle(): assert seq_lens_cpu is not None diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index 6722505df8b4..57e10daa6057 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -95,7 +95,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): backend = self._select_backend(forward_mode) backend.init_forward_metadata_replay_cuda_graph( @@ -107,7 +106,6 @@ def init_forward_metadata_replay_cuda_graph( 
forward_mode, spec_info, seq_lens_cpu, - **kwargs, ) def get_cuda_graph_seq_len_fill_value(self): diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 242b6f363b6c..45a9a4c9985c 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -388,7 +388,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): self.forward_metadata = self._replay_metadata( bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu @@ -682,7 +681,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): metadata = self._replay_metadata( bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu @@ -804,7 +802,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): for attn_backend in self.attn_backend_list: attn_backend.init_forward_metadata_replay_cuda_graph( @@ -816,7 +813,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode, spec_info, seq_lens_cpu, - **kwargs, ) def get_cuda_graph_seq_len_fill_value(self): diff --git a/python/sglang/srt/layers/attention/linear/lightning_backend.py b/python/sglang/srt/layers/attention/linear/lightning_backend.py index d2b322f434a5..b34fefbfd230 100644 --- a/python/sglang/srt/layers/attention/linear/lightning_backend.py +++ b/python/sglang/srt/layers/attention/linear/lightning_backend.py @@ -100,7 +100,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): metadata = self._replay_metadata( bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py index 9f5d21a26f9a..bfc62d7f0b19 100644 --- a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py +++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import tilelang import tilelang.language as T @@ -27,7 +27,6 @@ BF16 = "bfloat16" FP8 = "float8_e4m3fnuz" if _is_fp8_fnuz else "float8_e4m3" FP32 = "float32" -INT32 = "int32" def fast_log2_ceil(x): @@ -1376,396 +1375,3 @@ def tilelang_sparse_fwd( ) out = kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore return out - - -def _next_power_of_2(x: int) -> int: - p = 1 - while p < x: - p *= 2 - return p - - -def _padded_H(head_kv: int) -> int: - if hasattr(tilelang, "math") and hasattr(tilelang.math, "next_power_of_2"): - return max(tilelang.math.next_power_of_2(head_kv), 16) - return max(_next_power_of_2(head_kv), 16) - - -def _cdiv(topk: int, block_I: int) -> int: - if hasattr(tilelang, "math") and hasattr(tilelang.math, "cdiv"): - return tilelang.math.cdiv(topk, block_I) - return (topk + block_I - 1) // block_I - - -@tilelang.jit( - out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, -) -def dpsk_v4_bf16_sparse_attention_kernel( - num_heads: int, - topk: int, - *, - dim: int = 448, - tail_dim: int = 64, - sm_scale: float = 0.0, - block_I: int = 64, - num_stages: int = 1, - threads: int = 256, - use_attn_sink: bool = False, - use_swizzle: bool = False, - use_q_shared: bool = False, -) -> Any: - """DeepSeek V4 MLA sparse attention kernel (dim=448 nope + tail_dim=64 rope). - - Same structure as `sparse_attention_fwd_kernel_v1`, but the head is split - into a main `dim` and a `tail_dim` (NoPE + RoPE), and output includes - the full head (dim+tail_dim) plus LSE. - """ - ln2: float = 0.69314718 - log2e: float = 1.44269504 - if sm_scale <= 0.0: - sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * log2e - else: - sm_scale = sm_scale * log2e - assert dim == 448 and tail_dim == 64 - assert topk % block_I == 0 - kv_group = 1 - - batch = T.symbolic("batch") - seq_len = T.symbolic("seq_len") - seq_len_kv = T.symbolic("seq_len_kv") - - head_kv = num_heads // kv_group - D = dim - D_tail = tail_dim - BI = block_I - H = head_kv - padded_H = _padded_H(head_kv) - if padded_H != H: - assert kv_group == 1 - REPLICATE_H = (head_kv + 63) // 64 if head_kv > 64 else 1 - if head_kv > 64: - assert head_kv % 64 == 0 - H_per_block = 64 if REPLICATE_H > 1 else padded_H - NI = _cdiv(topk, BI) - - q_shape = [batch, seq_len, num_heads, D + D_tail] - kv_shape = [batch, seq_len_kv, kv_group, D + D_tail] - o_shape = [batch, seq_len, num_heads, D + D_tail] - indices_shape = [batch, seq_len, kv_group, topk] - lse_shape = [batch, seq_len, num_heads] - attn_sink_shape = [H] - dtype = BF16 - accum_dtype = "float" - indices_dtype = INT32 - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Attn_sink: T.Tensor(attn_sink_shape, FP32), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - LSE: T.Tensor(lse_shape, accum_dtype), # type: ignore - ) -> None: - with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): - if use_q_shared: - Q_shared = T.alloc_shared([H_per_block, D], dtype) - Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) - else: - Q_shared = T.alloc_fragment([H_per_block, D], dtype) - Q_tail_shared = T.alloc_fragment([H_per_block, D_tail], dtype) - KV_shared = T.alloc_shared([BI, D], dtype) - K_tail_shared = T.alloc_shared([BI, D_tail], dtype) - S_shared = T.alloc_shared([H_per_block, BI], dtype) - mask = T.alloc_fragment([BI], "bool") - - acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) - acc_o_tail = T.alloc_fragment([H_per_block, D_tail], accum_dtype) - acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) - sumexp = T.alloc_fragment([H_per_block], accum_dtype) - sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) - alpha = T.alloc_fragment([H_per_block], accum_dtype) - m_i = T.alloc_fragment([H_per_block], accum_dtype) - m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - attn_sink_frag = T.alloc_fragment([H_per_block], FP32) - o_scale_frag = T.alloc_fragment([H_per_block], accum_dtype) - - T.fill(acc_o, 0) - T.fill(acc_o_tail, 0) - T.fill(sumexp, 0) - T.fill(m_i, -(2**30)) - - b_i, g_i = by, bz - s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) - H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) - H1 = H0 + H_per_block - - if use_swizzle and use_q_shared: - T.use_swizzle(10) - - T.copy(Q[b_i, s_i, 
H0:H1, :D], Q_shared) - T.copy(Q[b_i, s_i, H0:H1, D : D + D_tail], Q_tail_shared) - - for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): - mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0 - - for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[ - 0, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i - ] - for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[ - 0, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i - ] - - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else( - mask[bi_i], 0, -T.infinity(acc_s.dtype) - ) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - ) - T.gemm( - Q_tail_shared, - K_tail_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - ) - T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) - for h_i in T.Parallel(H_per_block): - m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) - for h_i in T.Parallel(H_per_block): - alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.exp2( - acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale - ) - T.reduce_sum(acc_s, sumexp_i, dim=1) - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] - for h_i, d_i in T.Parallel(H_per_block, D): - acc_o[h_i, d_i] *= alpha[h_i] - for h_i, d_i in T.Parallel(H_per_block, D_tail): - acc_o_tail[h_i, d_i] *= alpha[h_i] - - T.copy(acc_s, S_shared) - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - S_shared, K_tail_shared, acc_o_tail, policy=T.GemmWarpPolicy.FullRow - ) - - # sumexp==0: output=0, LSE=+inf - for h_i, d_i in T.Parallel(H_per_block, D): - acc_o[h_i, d_i] = T.if_then_else( - sumexp[h_i] == 0, 0.0, acc_o[h_i, d_i] / sumexp[h_i] - ) - for h_i, d_i in T.Parallel(H_per_block, D_tail): - acc_o_tail[h_i, d_i] = T.if_then_else( - sumexp[h_i] == 0, 0.0, acc_o_tail[h_i, d_i] / sumexp[h_i] - ) - for h_i in T.Parallel(H_per_block): - m_i[h_i] = T.if_then_else( - sumexp[h_i] == 0, - T.infinity(accum_dtype), - (T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale) * ln2, - ) - - # attn_sink: o_scale = 1/(1+exp(attn_sink - lse)) - if use_attn_sink: - for h_i in T.Parallel(H_per_block): - attn_sink_frag[h_i] = Attn_sink[H0 + h_i] - for h_i in T.Parallel(H_per_block): - o_scale_frag[h_i] = T.if_then_else( - sumexp[h_i] == 0, - 0.0, - 1.0 / (1.0 + T.exp2((attn_sink_frag[h_i] - m_i[h_i]) * log2e)), - ) - for h_i, d_i in T.Parallel(H_per_block, D): - acc_o[h_i, d_i] = acc_o[h_i, d_i] * o_scale_frag[h_i] - for h_i, d_i in T.Parallel(H_per_block, D_tail): - acc_o_tail[h_i, d_i] = acc_o_tail[h_i, d_i] * o_scale_frag[h_i] - - T.copy(acc_o, Output[b_i, s_i, H0:H1, :D]) - T.copy(acc_o_tail, Output[b_i, s_i, H0:H1, D : D + D_tail]) - T.copy(m_i, LSE[b_i, s_i, H0:H1]) - - return main - - -def dpsk_v4_bf16_sparse_attention_fwd( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: Optional[torch.Tensor], - cache_seqlens: Optional[torch.Tensor], - head_dim_v: int, - tile_scheduler_metadata: Any, - num_splits: None = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - is_fp8_kvcache: bool = False, - indices: Optional[torch.Tensor] = None, - attn_sink: Optional[torch.Tensor] = None, - extra_k_cache: Optional[torch.Tensor] = None, - extra_indices_in_kvcache: Optional[torch.Tensor] = None, - topk_length: Optional[torch.Tensor] = None, - extra_topk_length: Optional[torch.Tensor] = 
None, -) -> Tuple[torch.Tensor, torch.Tensor]: - from sglang.srt.flashmla_tests import quant - - assert head_dim_v == 512 - - # select hyperconfig - if _is_gfx95_supported: - block_I = 64 - threads = 256 - num_stages = 0 - use_swizzle = True # MI355 has 160KB LDS, can afford swizzle - use_q_shared = True - elif _is_fp8_fnuz: # gfx94 (MI300) - block_I = 32 - threads = 128 - num_stages = 1 - use_swizzle = False # MI300 has only 64KB LDS, be conservative - use_q_shared = False - else: - raise Exception("Only support gf94x, gf95x") - - num_heads = q.shape[2] - batch, seq_len, _, _ = q.shape - - # q - q = q.contiguous() - - # k cache - k_bf16 = quant.dequantize_k_cache( - k_cache.view(FP8), quant.FP8KVCacheLayout.MODEL1_FP8Sparse - ) - num_blocks, block_size = k_bf16.shape[0], k_bf16.shape[1] - seq_len_kv = num_blocks * block_size - k_bf16 = k_bf16.reshape(seq_len_kv, 1, 512).unsqueeze(0).contiguous() - - # indices - indices = indices.unsqueeze(2).contiguous() - topk = indices.shape[-1] - if topk_length is not None: - for bi in range(batch): - valid = int(topk_length[bi].item()) - if valid < topk: - indices[bi, :, :, valid:] = -1 - - # attn_sink - if attn_sink is None: - attn_sink = torch.full( - (num_heads,), float("-inf"), dtype=torch.float32, device=q.device - ) - not_use_extra_k = extra_k_cache is None - - kernel = dpsk_v4_bf16_sparse_attention_kernel( - num_heads, - topk, - dim=448, - tail_dim=64, - sm_scale=softmax_scale, - block_I=block_I, - num_stages=num_stages, - threads=threads, - use_attn_sink=not_use_extra_k, - use_swizzle=use_swizzle, - use_q_shared=use_q_shared, - ) - o1, lse1 = kernel(q, k_bf16, indices, attn_sink) - - if not_use_extra_k: - return o1, lse1 - else: - # extra k cache - extra_k_bf16 = quant.dequantize_k_cache( - extra_k_cache.view(FP8), quant.FP8KVCacheLayout.MODEL1_FP8Sparse - ) - num_blocks, block_size = extra_k_bf16.shape[0], extra_k_bf16.shape[1] - seq_len_kv = num_blocks * block_size - extra_k_bf16 = ( - extra_k_bf16.reshape(seq_len_kv, 1, 512).unsqueeze(0).contiguous() - ) - - # indices - extra_indices = extra_indices_in_kvcache.unsqueeze(2).contiguous() - extra_topk = extra_indices.shape[-1] - if extra_topk_length is not None: - for bi in range(batch): - valid = int(extra_topk_length[bi].item()) - if valid < extra_topk: - extra_indices[bi, :, :, valid:] = -1 - - kernel = dpsk_v4_bf16_sparse_attention_kernel( - num_heads, - extra_topk, - dim=448, - tail_dim=64, - sm_scale=softmax_scale, - block_I=block_I, - num_stages=num_stages, - threads=threads, - use_attn_sink=not_use_extra_k, - use_swizzle=use_swizzle, - use_q_shared=use_q_shared, - ) - o2, lse2 = kernel(q, extra_k_bf16, extra_indices, attn_sink) - - def _merge_two_attn_out_lse( - o1: torch.Tensor, - lse1: torch.Tensor, - o2: torch.Tensor, - lse2: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - both_finite = torch.isfinite(lse1) & torch.isfinite(lse2) - lse_total = torch.where( - both_finite, - torch.logsumexp(torch.stack([lse1, lse2], dim=0), dim=0), - torch.where(torch.isfinite(lse1), lse1, lse2), - ) - w1 = torch.where( - both_finite, - torch.exp(lse1 - lse_total), - torch.where( - torch.isfinite(lse1), - torch.ones_like(lse1), - torch.zeros_like(lse1), - ), - ) - w2 = torch.where( - both_finite, - torch.exp(lse2 - lse_total), - torch.where( - torch.isfinite(lse2), - torch.ones_like(lse2), - torch.zeros_like(lse2), - ), - ) - o_total = w1.unsqueeze(-1) * o1.float() + w2.unsqueeze(-1) * o2.float() - return o_total, lse_total - - o1, lse1 = _merge_two_attn_out_lse(o1, lse1, o2, lse2) - - 
attn_sink_br = attn_sink.view(1, 1, -1) - o_scale = torch.sigmoid(lse1 - attn_sink_br) - output = (o1.float() * o_scale.unsqueeze(-1)).to(q.dtype) - lse_ok = torch.isfinite(lse1).unsqueeze(-1) - output = torch.where(lse_ok, output, torch.zeros_like(output)) - return output, lse1 diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index cc176116d6b3..397932f38379 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -969,7 +969,6 @@ def init_forward_metadata_replay_cuda_graph( spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: Optional[torch.Tensor] = None, - **kwargs, ): """Initialize forward metadata for replaying CUDA graph.""" assert seq_lens_cpu is not None diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py index f56214a64d38..494d82d808e8 100644 --- a/python/sglang/srt/layers/attention/tbo_backend.py +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -79,7 +79,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: "ForwardMode", spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): self.primary.init_forward_metadata_replay_cuda_graph( bs=bs, @@ -90,7 +89,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode=forward_mode, spec_info=spec_info, seq_lens_cpu=seq_lens_cpu, - **kwargs, ) self._init_forward_metadata_cuda_graph_children( diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index ba0573f5872d..d37f9101aa73 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -720,7 +720,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index eb27b4a3193d..0fe16379e866 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -450,7 +450,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): """Replay CUDA graph with new inputs.""" seq_lens = seq_lens[:bs] diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index c1ad49d129e4..65c1cdb549b5 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -518,7 +518,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): """Replay CUDA graph with new inputs.""" # Delegate to parent for non-decode modes. 
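For reference, the dpsk_v4_bf16_sparse_attention_fwd path removed from tilelang_kernel.py above combined the main-KV and extra-KV attention results with a log-sum-exp merge (_merge_two_attn_out_lse). A minimal standalone sketch of that merge step, independent of tilelang and with illustrative shapes, assuming each partial result carries the log-sum-exp of its softmax denominator and uses a non-finite LSE as the "empty" sentinel:

import torch

def merge_attention_outputs(o1, lse1, o2, lse2):
    # o1/o2: (batch, seq, heads, head_dim) partial attention outputs
    # lse1/lse2: (batch, seq, heads) log-sum-exp of each partial softmax denominator
    both = torch.isfinite(lse1) & torch.isfinite(lse2)
    lse = torch.where(
        both,
        torch.logsumexp(torch.stack([lse1, lse2], dim=0), dim=0),
        torch.where(torch.isfinite(lse1), lse1, lse2),
    )
    # Re-weight each partial output by exp(lse_i - lse); an "empty" partial
    # (non-finite lse_i) gets weight 0 and the other partial gets weight 1.
    w1 = torch.where(both, torch.exp(lse1 - lse), torch.isfinite(lse1).float())
    w2 = torch.where(both, torch.exp(lse2 - lse), torch.isfinite(lse2).float())
    o = w1.unsqueeze(-1) * o1.float() + w2.unsqueeze(-1) * o2.float()
    return o, lse

# usage with illustrative shapes
# o1, o2 = torch.randn(2, 1, 8, 512), torch.randn(2, 1, 8, 512)
# lse1, lse2 = torch.randn(2, 1, 8), torch.randn(2, 1, 8)
# o, lse = merge_attention_outputs(o1, lse1, o2, lse2)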
diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 9bc6c7dc09af..9669a4568106 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -479,7 +479,6 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], - **kwargs, ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 74429337bf9b..5dcc968e949e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -1152,9 +1152,7 @@ def run_once(): # Python branch (if self.swa_loc is not None) takes the fast path, # and the graph records GPU ops using this buffer instead of the # per-layer translate_loc_from_full_to_swa fallback. - if self.buffers.out_cache_loc_swa is not None and hasattr( - self.model_runner.token_to_kv_pool, "set_swa_loc" - ): + if self.buffers.out_cache_loc_swa is not None: self.model_runner.token_to_kv_pool.set_swa_loc( self.buffers.out_cache_loc_swa[:num_tokens] ) @@ -1164,9 +1162,6 @@ def run_once(): self.model_runner.tp_group.barrier() run_once() - if hasattr(attn_backend, "on_after_cuda_graph_warmup_pass"): - attn_backend.on_after_cuda_graph_warmup_pass() - if get_global_graph_memory_pool() is None: set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory @@ -1272,7 +1267,6 @@ def replay_prepare( attn_backend = self.model_runner.decode_attn_backend_group[stream_idx] else: attn_backend = self.attn_backend - num_tokens = bs * self.num_tokens_per_bs attn_backend.init_forward_metadata_replay_cuda_graph( bs, buffers.req_pool_indices[:bs], @@ -1282,8 +1276,6 @@ def replay_prepare( self.capture_forward_mode, forward_batch.spec_info, seq_lens_cpu=buffers.seq_lens_cpu[:bs], - out_cache_loc=buffers.out_cache_loc[:num_tokens], - actual_forward_mode=forward_batch.forward_mode, ) # Store fields diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e8d0c74086f8..e8251720542f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -633,7 +633,6 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, - input_ids_global=input_ids_global, ) else: return self.forward_normal( @@ -655,7 +654,6 @@ def forward_normal_dual_stream( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, - input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) @@ -671,12 +669,10 @@ def forward_normal_dual_stream( with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) - topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} topk_output = self.topk( hidden_states, router_logits, expert_location_dispatch_info=dispatch_info, - **topk_kwargs, ) final_hidden_states = self.experts(hidden_states, topk_output) if not (_is_cuda or _is_musa) or isinstance( diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py 
index 5c76c9caf9dc..d227e42d7116 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -2002,7 +2002,7 @@ def __init__( ) self.rms_norm_eps = config.rms_norm_eps self.alt_streams = ( - [torch.cuda.Stream() for _ in range(5)] if (_is_cuda) else None + [torch.cuda.Stream() for _ in range(5)] if (_is_cuda or _is_hip) else None ) self.layers, self.start_layer, self.end_layer = make_layers( config.num_hidden_layers,
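As context for the alt_streams change just above (streams are now created on HIP as well as CUDA), the dual-stream pattern those streams feed — visible in forward_normal_dual_stream in deepseek_v2.py earlier in this diff — overlaps two independent branches on separate device streams. A minimal sketch of that pattern; branch_a and branch_b are hypothetical placeholders, not the model's actual layers:

import torch

def dual_stream_forward(x, branch_a, branch_b, alt_stream):
    # Run branch_a on the current stream and branch_b on alt_stream concurrently.
    # Both branches must depend only on x, not on each other's outputs.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)      # alt stream waits until x is ready
    out_a = branch_a(x)                  # issued on the current stream
    with torch.cuda.stream(alt_stream):
        out_b = branch_b(x)              # issued concurrently on alt_stream
    current.wait_stream(alt_stream)      # join before out_b is consumed downstream
    return out_a, out_b

# usage (hypothetical modules; torch.cuda.Stream covers both CUDA and ROCm builds)
# alt = torch.cuda.Stream()
# attn_out, router_logits = dual_stream_forward(hidden_states, attn, gate, alt)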