-
-
Notifications
You must be signed in to change notification settings - Fork 16.6k
[GDN] Enable FI Blackwell GDN prefill kernel #40717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6aaab0a
560797f
be4dfdb
cb71f43
c93ebc1
d79e7fe
d897bfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,6 +67,61 @@ | |
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) -> bool: | ||
| """Whether to use FlashInfer's GDN prefill kernel instead of the | ||
| Triton/FLA fallback. | ||
|
|
||
| Requirements: | ||
| * ``requested in ["flashinfer", "auto"]``; | ||
| * ``platform == cuda``; | ||
| * one of the following: | ||
| - Hopper (SM90) — no further constraints; | ||
| - Blackwell (SM10.x) with ``head_k_dim == 128`` and ``cuda_runtime >= 13``. | ||
| """ | ||
| if backend not in ["flashinfer", "auto"]: | ||
| return False | ||
| if not current_platform.is_cuda(): | ||
| return False | ||
| if current_platform.is_device_capability(90): | ||
| return True # Hopper — no further constraints. | ||
| if not current_platform.is_device_capability_family(100): | ||
| return False # Neither Hopper nor Blackwell. | ||
| if head_k_dim != 128: | ||
| return False | ||
| return current_platform.get_cuda_runtime_major() >= 13 | ||
|
|
||
|
|
||
def _log_gdn_backend_decision(
    backend: str, head_k_dim: int | None, use_flashinfer: bool
) -> None:
    """Dump the inputs to the backend decision and the final choice.

    Logs (once) the requested backend, platform, CUDA runtime version,
    device capability, and ``head_k_dim``, followed by which GDN prefill
    kernel was chosen.
    """
    is_cuda = current_platform.is_cuda()
    # On non-CUDA platforms, report the platform's own name so the log
    # line is still meaningful.
    platform = "cuda" if is_cuda else current_platform.device_name
    # torch.version.cuda is None on CPU-only builds; substitute "n/a".
    cuda_runtime = torch.version.cuda or "n/a"
    # Device capability can only be queried on CUDA platforms.
    device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
    logger.info_once(
        "GDN prefill backend inputs:\n"
        "  requested=%s\n"
        "  platform=%s, cuda_runtime=%s, device_capability=%s\n"
        "  head_k_dim=%s",
        backend,
        platform,
        cuda_runtime,
        device_cap,
        head_k_dim,
        # NOTE(review): scope="local" presumably restricts the once-per-process
        # dedup to the local rank — confirm against the project's logger.
        scope="local",
    )
    if use_flashinfer:
        logger.info_once("Using FlashInfer GDN prefill kernel")
        # Warn about JIT compile latency on first use and how to opt out.
        logger.info_once(
            "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
            "take a while to compile. Set `--gdn-prefill-backend triton` to "
            "avoid JIT compile time.",
        )
    else:
        logger.info_once("Using Triton/FLA GDN prefill kernel")
|
|
||
|
|
||
| def fi_chunk_gated_delta_rule( | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
|
|
@@ -118,39 +173,21 @@ def fi_chunk_gated_delta_rule( | |
|
|
||
| @CustomOp.register("chunk_gated_delta_rule") | ||
| class ChunkGatedDeltaRule(CustomOp): | ||
| def __init__(self) -> None: | ||
| def __init__(self, head_k_dim: int | None = None) -> None: | ||
| super().__init__() | ||
| additional_config = get_current_vllm_config().additional_config | ||
| assert isinstance(additional_config, dict) | ||
| backend_cfg = additional_config.get("gdn_prefill_backend", "auto") | ||
| backend = str(backend_cfg).strip().lower() | ||
|
|
||
| supports_flashinfer = ( | ||
| current_platform.is_cuda() and current_platform.is_device_capability(90) | ||
| ) | ||
|
|
||
| if backend == "flashinfer": | ||
| use_flashinfer = supports_flashinfer | ||
| if not use_flashinfer: | ||
| logger.warning_once( | ||
| "GDN prefill backend 'flashinfer' is selected but " | ||
| "cannot use this kernel on the current platform. " | ||
| "Falling back to Triton/FLA." | ||
| ) | ||
| elif backend == "triton": | ||
| use_flashinfer = False | ||
| else: | ||
| use_flashinfer = supports_flashinfer | ||
|
|
||
| if use_flashinfer: | ||
| logger.info_once("Using FlashInfer GDN prefill kernel") | ||
| logger.info_once( | ||
| "FlashInfer GDN prefill kernel is JIT-compiled; first run may " | ||
| "take a while to compile. Set `--gdn-prefill-backend triton` to " | ||
| "avoid JIT compile time.", | ||
| use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim) | ||
| if backend == "flashinfer" and not use_flashinfer: | ||
| logger.warning_once( | ||
| "GDN prefill backend 'flashinfer' is selected but " | ||
| "cannot use this kernel on the current platform. " | ||
| "Falling back to Triton/FLA." | ||
| ) | ||
| else: | ||
| logger.info_once("Using Triton/FLA GDN prefill kernel") | ||
| _log_gdn_backend_decision(backend, head_k_dim, use_flashinfer) | ||
|
arpera marked this conversation as resolved.
|
||
|
|
||
| self._forward_method = ( | ||
| self.forward_cuda if use_flashinfer else self.forward_native | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my flavor, I'd place
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure? Currently we set |
||
|
|
@@ -380,7 +417,7 @@ def __init__( | |
| prefix=f"{prefix}.out_proj", | ||
| ) | ||
|
|
||
| self.chunk_gated_delta_rule = ChunkGatedDeltaRule() | ||
| self.chunk_gated_delta_rule = ChunkGatedDeltaRule(head_k_dim=self.head_k_dim) | ||
| self.enable_packed_recurrent_decode = ( | ||
| envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.