-
Notifications
You must be signed in to change notification settings - Fork 6.2k
[FlashInfer v0.6.10] [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk
#22851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
f665246
Initial impl
zianglih 4d052c0
Add docs
zianglih c70bc30
Add docs_new
zianglih 92a2e68
Minor reformat
zianglih 660d7ef
Clean up after rebase
zianglih f4e3705
Update tie break docs
zianglih dcce26d
Refactor
zianglih f20dcdd
Clean up
zianglih fb7fa68
Refactor to `[None, "small", "large"]`
zianglih aeeedc8
Clean up
zianglih c430da9
Address DSA topk review comments
zianglih 2685a15
Merge branch 'main' into torch-topk
Fridge003 23d7552
Merge branch 'main' into torch-topk
zianglih File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
271 changes: 271 additions & 0 deletions
271
python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,271 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from enum import Enum, IntEnum, auto | ||
| from typing import Callable, Dict, List, Optional, Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.environ import envs | ||
|
|
||
# Accepted (lower-cased) values of SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK and the
# integer code each one is translated to before being passed to FlashInfer as
# ``tie_break``. An unset variable is handled separately by
# ``_flashinfer_tie_break_value`` and maps to 0.
_FLASHINFER_TIE_BREAK_VALUES: Dict[str, int] = dict(small=1, large=2)
|
|
||
|
|
||
class TopkTransformMethod(IntEnum):
    """How selected top-k indices are rewritten by the fused transform."""

    # Top-k indices become indices into the page table (page_size = 1).
    PAGED = 1
    # Top-k indices become indices into ragged (non-paged) kv.
    RAGGED = 2
|
|
||
|
|
||
class DSATopKBackend(Enum):
    """Selectable implementations of the DSA top-k step.

    Each member's value is the backend's string name. ``sgl_kernel`` and
    ``flashinfer`` are imported lazily, only inside the code paths that
    actually use them.
    """

    SGL_KERNEL = "sgl-kernel"
    TORCH = "torch"
    FLASHINFER = "flashinfer"

    def is_sgl_kernel(self) -> bool:
        """Return True if this member is ``SGL_KERNEL``."""
        return self is DSATopKBackend.SGL_KERNEL

    def is_torch(self) -> bool:
        """Return True if this member is ``TORCH``."""
        return self is DSATopKBackend.TORCH

    def is_flashinfer(self) -> bool:
        """Return True if this member is ``FLASHINFER``."""
        return self is DSATopKBackend.FLASHINFER

    def topk_func(
        self,
        score: torch.Tensor,
        lengths: torch.Tensor,
        topk: int,
        row_starts: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run top-k selection without the fused index transform.

        ``SGL_KERNEL`` forwards to ``sgl_kernel.fast_topk_v2``. ``TORCH`` and
        ``FLASHINFER`` both go through ``_topk_unfused``, differing only in
        the top-k operator and the keyword arguments handed to it.

        Args:
            score: Scores to select from.
            lengths: Per-row lengths forwarded to the backend.
            topk: Number of entries to select per row.
            row_starts: Optional per-row start offsets forwarded to the backend.

        Returns:
            Whatever the selected backend returns for the top-k indices.

        Raises:
            RuntimeError: If the member is none of the known backends.
        """
        if self.is_sgl_kernel():
            from sgl_kernel import fast_topk_v2

            return fast_topk_v2(score, lengths, topk, row_starts=row_starts)

        if self.is_torch():
            selector = torch.topk
            selector_kwargs = {"dim": -1}
        elif self.is_flashinfer():
            import flashinfer

            selector = flashinfer.top_k
            selector_kwargs = {
                "sorted": False,
                "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
                "tie_break": _flashinfer_tie_break_value(),
                "dsa_graph_safe": True,
            }
        else:
            raise RuntimeError(f"Unsupported {self = }.")

        return _topk_unfused(
            score,
            lengths,
            topk,
            row_starts=row_starts,
            topk_op=selector,
            topk_op_kwargs=selector_kwargs,
        )

    def topk_transform(
        self,
        logits: torch.Tensor,
        lengths: torch.Tensor,
        topk: int,
        topk_transform_method: TopkTransformMethod,
        attn_metadata,
        cu_seqlens_q_topk: Optional[torch.Tensor] = None,
        topk_indices_offset: Optional[torch.Tensor] = None,
        row_starts: Optional[torch.Tensor] = None,
        batch_idx_list: Optional[List[int]] = None,
        force_unfused_topk: bool = False,
    ) -> torch.Tensor:
        """Run top-k selection, fused with the index transform when enabled.

        If ``SGLANG_DSA_FUSE_TOPK`` is off or ``force_unfused_topk`` is set,
        this simply delegates to ``topk_func`` (no transform is applied here).
        Otherwise the fused ``sgl_kernel`` or ``flashinfer`` kernels are used,
        chosen by ``topk_transform_method``.

        Args:
            logits: Scores to select from.
            lengths: Per-row lengths forwarded to the backend.
            topk: Number of entries to select per row.
            topk_transform_method: ``PAGED`` or ``RAGGED``.
            attn_metadata: Object providing ``page_table_1`` (and, on the
                FlashInfer paged path with ``row_starts``, ``cu_seqlens_k``).
            cu_seqlens_q_topk: Optional cumulative query lengths, used by the
                paged paths.
            topk_indices_offset: Offsets required by the ``RAGGED`` paths.
            row_starts: Optional per-row start offsets.
            batch_idx_list: Optional batch indices for the paged paths.
            force_unfused_topk: Force the ``topk_func`` path.

        Returns:
            Whatever the selected kernel returns.

        Raises:
            RuntimeError: If the backend has no fused implementation, the
                transform method is unknown, or ``RAGGED`` is requested
                without ``topk_indices_offset``.
        """
        # Read the flag first so it is consulted even when the caller forces
        # the unfused path.
        fuse_enabled = envs.SGLANG_DSA_FUSE_TOPK.get()
        if force_unfused_topk or not fuse_enabled:
            return self.topk_func(logits, lengths, topk, row_starts=row_starts)

        if self.is_sgl_kernel():
            return self._fused_sgl_kernel(
                logits,
                lengths,
                topk,
                topk_transform_method,
                attn_metadata,
                cu_seqlens_q_topk,
                topk_indices_offset,
                row_starts,
                batch_idx_list,
            )

        if self.is_flashinfer():
            return self._fused_flashinfer(
                logits,
                lengths,
                topk,
                topk_transform_method,
                attn_metadata,
                cu_seqlens_q_topk,
                topk_indices_offset,
                row_starts,
                batch_idx_list,
            )

        raise RuntimeError(f"Unsupported {self = } for SGLANG_DSA_FUSE_TOPK.")

    def _fused_sgl_kernel(
        self,
        logits: torch.Tensor,
        lengths: torch.Tensor,
        topk: int,
        topk_transform_method: TopkTransformMethod,
        attn_metadata,
        cu_seqlens_q_topk: Optional[torch.Tensor],
        topk_indices_offset: Optional[torch.Tensor],
        row_starts: Optional[torch.Tensor],
        batch_idx_list: Optional[List[int]],
    ) -> torch.Tensor:
        """Fused top-k + transform through the ``sgl_kernel`` kernels."""
        from sgl_kernel import (
            fast_topk_transform_fused,
            fast_topk_transform_ragged_fused,
        )

        if topk_transform_method == TopkTransformMethod.PAGED:
            page_table = attn_metadata.page_table_1
            if batch_idx_list is not None:
                # Restrict the page table to the requested batch rows.
                page_table = page_table[batch_idx_list]
            return fast_topk_transform_fused(
                score=logits,
                lengths=lengths,
                page_table_size_1=page_table,
                cu_seqlens_q=cu_seqlens_q_topk,
                topk=topk,
                row_starts=row_starts,
            )

        if topk_transform_method == TopkTransformMethod.RAGGED:
            if topk_indices_offset is None:
                raise RuntimeError(
                    "RAGGED topk_transform requires topk_indices_offset; "
                    "expected extend-without-speculative metadata."
                )
            return fast_topk_transform_ragged_fused(
                score=logits,
                lengths=lengths,
                topk_indices_offset=topk_indices_offset,
                topk=topk,
                row_starts=row_starts,
            )

        raise RuntimeError(f"Unsupported {topk_transform_method = }.")

    def _fused_flashinfer(
        self,
        logits: torch.Tensor,
        lengths: torch.Tensor,
        topk: int,
        topk_transform_method: TopkTransformMethod,
        attn_metadata,
        cu_seqlens_q_topk: Optional[torch.Tensor],
        topk_indices_offset: Optional[torch.Tensor],
        row_starts: Optional[torch.Tensor],
        batch_idx_list: Optional[List[int]],
    ) -> torch.Tensor:
        """Fused top-k + transform through the ``flashinfer`` kernels."""
        import flashinfer

        if topk_transform_method == TopkTransformMethod.PAGED:
            # The full page table is passed; rows are mapped to batches via
            # ``row_to_batch`` rather than by slicing the table.
            row_to_batch, local_row_starts = _build_flashinfer_paged_args(
                attn_metadata=attn_metadata,
                row_starts=row_starts,
                cu_seqlens_q_topk=cu_seqlens_q_topk,
                batch_idx_list=batch_idx_list,
                device=logits.device,
                num_rows=logits.shape[0],
            )
            return flashinfer.top_k_page_table_transform(
                logits.contiguous(),
                attn_metadata.page_table_1.contiguous(),
                lengths.contiguous(),
                topk,
                row_to_batch=row_to_batch,
                deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
                tie_break=_flashinfer_tie_break_value(),
                dsa_graph_safe=True,
                row_starts=local_row_starts,
            )

        if topk_transform_method == TopkTransformMethod.RAGGED:
            if topk_indices_offset is None:
                raise RuntimeError(
                    "RAGGED topk_transform requires topk_indices_offset; "
                    "expected extend-without-speculative metadata."
                )
            return flashinfer.top_k_ragged_transform(
                logits.contiguous(),
                topk_indices_offset.contiguous(),
                lengths.contiguous(),
                topk,
                deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
                tie_break=_flashinfer_tie_break_value(),
                dsa_graph_safe=True,
                row_starts=row_starts,
            )

        raise RuntimeError(f"Unsupported {topk_transform_method = }.")
|
|
||
|
|
||
def _topk_unfused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
    topk_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]] = torch.topk,
    topk_op_kwargs: Optional[Dict[str, object]] = None,
) -> torch.Tensor:
    """Pick per-row top-k columns inside a window using a generic top-k op.

    For row ``i`` only the columns in
    ``[row_starts[i], row_starts[i] + lengths[i])`` are eligible; every other
    column is filled with ``-inf`` before ``topk_op`` is invoked.

    Args:
        score: 2-D score tensor (unpacked as ``batch_size, max_score_len``).
        lengths: Per-row window lengths.
        topk: Number of output slots per row.
        row_starts: Per-row first eligible column; zeros when ``None``.
        topk_op: Called as ``topk_op(masked, k, **topk_op_kwargs)`` and
            expected to return ``(values, indices)``.
        topk_op_kwargs: Extra keyword arguments for ``topk_op``.

    Returns:
        An int32 tensor of shape ``(batch_size, topk)`` on ``score``'s device.
        Entries are column indices relative to the row's start. A slot is
        ``-1`` when its selected score is ``-inf`` or when it lies beyond
        ``min(topk, max_score_len)``.
    """
    num_rows, num_cols = score.shape
    device = score.device

    # Pre-filled with -1 so any slot that is never written stays "empty".
    result = score.new_full((num_rows, topk), -1, dtype=torch.int32)
    if num_rows == 0 or num_cols == 0 or topk == 0:
        return result

    if row_starts is not None:
        starts = row_starts.to(dtype=torch.int32, device=device)
    else:
        starts = torch.zeros_like(lengths, dtype=torch.int32, device=device)
    window_lens = lengths.to(dtype=torch.int32, device=device)

    starts_col = starts.unsqueeze(1)
    ends_col = (starts + window_lens).unsqueeze(1)
    columns = torch.arange(num_cols, dtype=torch.int32, device=device).unsqueeze(0)

    # A column is ineligible when it falls before the start or at/after the end.
    outside_window = (columns < starts_col) | (columns >= ends_col)
    masked_scores = score.masked_fill(outside_window, float("-inf"))

    # The operator cannot return more entries than there are columns.
    k = min(topk, num_cols)
    extra_kwargs = topk_op_kwargs if topk_op_kwargs else {}
    picked_scores, picked_cols = topk_op(masked_scores, k, **extra_kwargs)

    # Convert absolute columns to window-relative ones, then blank out picks
    # whose score is -inf.
    relative_cols = picked_cols.to(torch.int32) - starts_col
    relative_cols = relative_cols.masked_fill(picked_scores == float("-inf"), -1)
    result[:, :k] = relative_cols

    return result
|
|
||
|
|
||
def _build_flashinfer_paged_args(
    attn_metadata,
    row_starts: Optional[torch.Tensor],
    cu_seqlens_q_topk: Optional[torch.Tensor],
    batch_idx_list: Optional[List[int]],
    device: torch.device,
    num_rows: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Build ``row_to_batch`` and batch-local ``row_starts`` for the paged path.

    Args:
        attn_metadata: Object whose ``cu_seqlens_k`` is read when both
            ``row_starts`` and a row-to-batch mapping are present.
        row_starts: Optional per-row start offsets.
        cu_seqlens_q_topk: Optional cumulative query lengths; consecutive
            differences give the number of rows per batch.
        batch_idx_list: Optional explicit batch indices.
        device: Device on which new tensors are created.
        num_rows: Number of score rows.

    Returns:
        ``(row_to_batch, local_row_starts)``. ``row_to_batch`` is an int32
        tensor or ``None``. ``local_row_starts`` is ``row_starts`` minus
        ``attn_metadata.cu_seqlens_k[:-1][row_to_batch]`` when ``row_starts``
        is given, otherwise ``None``.

    Raises:
        RuntimeError: If ``row_starts`` is given but no row-to-batch mapping
            could be derived.
    """

    def rows_per_batch() -> torch.Tensor:
        # Consecutive differences of the cumulative query lengths.
        return (cu_seqlens_q_topk[1:] - cu_seqlens_q_topk[:-1]).to(
            dtype=torch.int32, device=device
        )

    has_q_offsets = cu_seqlens_q_topk is not None

    if batch_idx_list is not None:
        row_to_batch = torch.as_tensor(
            batch_idx_list, dtype=torch.int32, device=device
        )
        # One entry per batch rather than per row: expand to per-row form.
        if has_q_offsets and row_to_batch.shape[0] != num_rows:
            row_to_batch = torch.repeat_interleave(row_to_batch, rows_per_batch())
    else:
        row_to_batch = None
        if has_q_offsets:
            # Decode-like case (one query row per batch) does not need an explicit mapping.
            # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe.
            num_batches = cu_seqlens_q_topk.shape[0] - 1
            decode_like = row_starts is None and num_rows == num_batches
            if not decode_like:
                counts = rows_per_batch()
                batch_ids = torch.arange(
                    counts.shape[0], dtype=torch.int32, device=device
                )
                row_to_batch = torch.repeat_interleave(batch_ids, counts)

    if row_starts is None:
        return row_to_batch, None

    if row_to_batch is None:
        raise RuntimeError(
            "PAGED topk_transform with row_starts requires cu_seqlens_q metadata."
        )

    # Subtract each row's batch start so the offsets become batch-local.
    batch_k_starts = attn_metadata.cu_seqlens_k[:-1]
    return row_to_batch, row_starts - batch_k_starts[row_to_batch]
|
|
||
|
|
||
def _flashinfer_tie_break_value() -> int:
    """Translate SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK into an integer code.

    Returns:
        ``0`` when the setting is ``None``; otherwise the code registered in
        ``_FLASHINFER_TIE_BREAK_VALUES`` for the lower-cased setting.

    Raises:
        RuntimeError: If the lower-cased setting is not a known name.
    """
    setting = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get()
    if setting is None:
        return 0

    # Matching is case-insensitive; the error below reports the lowered form.
    mode = setting.lower()
    code = _FLASHINFER_TIE_BREAK_VALUES.get(mode)
    if code is None:
        raise RuntimeError(
            "SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK must be one of "
            f"{tuple(_FLASHINFER_TIE_BREAK_VALUES)} or unset, got {mode!r}."
        )
    return code
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.