-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Fuse MLA RoPE + FP8 Quantization + KV Cache Write into Single CUDA Kernel #12503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
YAMY1234
wants to merge
7
commits into
sgl-project:main
Choose a base branch
from
YAMY1234:mtp_fuse_kernel
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,373
−50
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
97530d2
fuse kernel initial code
YAMY1234 1cd4f8e
single req successful
YAMY1234 2d4c444
run successfully parallel
YAMY1234 de83169
small code adjustment
YAMY1234 a97c741
minor fix
YAMY1234 de2b55b
update unit tests
YAMY1234 5d373e6
code cleaning
YAMY1234 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| Microbenchmark: fused vs baseline (emulated) for MLA RoPE + FP8 + KV write. | ||
| Uses the sgl_kernel.mla_rope_quantize_fp8_fused extension. | ||
| """ | ||
| import time | ||
|
|
||
| import torch | ||
|
|
||
| _has_sgl_kernel = False | ||
| mla_rope_quantize_fp8_fused = None | ||
| try: | ||
| from mla_fusion_kernel import mla_rope_quantize_fp8_fused | ||
|
|
||
| _has_sgl_kernel = True | ||
| print("Using standalone mla_fusion_kernel") | ||
| except ImportError: | ||
| try: | ||
| from sgl_kernel import mla_rope_quantize_fp8_fused | ||
|
|
||
| _has_sgl_kernel = True | ||
| print("Using sgl_kernel.mla_rope_quantize_fp8_fused") | ||
| except ImportError: | ||
| print( | ||
| "ERROR: Fusion kernel not available. Please build mla_fusion_standalone first." | ||
| ) | ||
| _has_sgl_kernel = False | ||
|
|
||
|
|
||
| def run_one(nnz=1024, Dn=512, Dr=64, iters=200, warmup=20, device="cuda"): | ||
| if not _has_sgl_kernel: | ||
| return 0, 0, 0 | ||
|
|
||
| torch.manual_seed(0) | ||
|
|
||
| q_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16) | ||
| q_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16) | ||
| k_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16) | ||
| k_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16) | ||
|
|
||
| max_seq = max(2048, nnz) | ||
| t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None] | ||
| idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :] | ||
| freqs = 0.1 * (idx + 1.0) | ||
| cos = torch.cos(t * freqs) | ||
| sin = torch.sin(t * freqs) | ||
| cos_sin = torch.cat([cos, sin], dim=1) | ||
| pos_ids = torch.randint( | ||
| low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long | ||
| ) | ||
|
|
||
| slots = nnz + 8 | ||
| loc = torch.arange(nnz, device=device, dtype=torch.long) | ||
|
|
||
| q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8) | ||
| k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8) | ||
| k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8) | ||
| kv_base = torch.zeros(slots, Dn + Dr, device=device, dtype=torch.uint8) | ||
|
|
||
| # baselines | ||
| def baseline(): | ||
| mla_rope_quantize_fp8_fused( | ||
| q_nope, | ||
| q_rope, | ||
| k_nope, | ||
| k_rope, | ||
| cos_sin, | ||
| pos_ids, | ||
| False, | ||
| q_out, | ||
| k_nope_out, | ||
| k_rope_out, | ||
| None, | ||
| None, | ||
| ) | ||
| kv_base.zero_() | ||
| kv_base[loc, :Dn] = k_nope_out | ||
| kv_base[loc, Dn:] = k_rope_out | ||
|
|
||
| kv_fused = torch.zeros_like(kv_base) | ||
|
|
||
| def fused(): | ||
| mla_rope_quantize_fp8_fused( | ||
| q_nope, | ||
| q_rope, | ||
| k_nope, | ||
| k_rope, | ||
| cos_sin, | ||
| pos_ids, | ||
| False, | ||
| q_out, | ||
| None, | ||
| None, | ||
| kv_fused, | ||
| loc, | ||
| ) | ||
|
|
||
| # warmup | ||
| for _ in range(warmup): | ||
| baseline() | ||
| torch.cuda.synchronize() | ||
| t0 = time.time() | ||
| for _ in range(iters): | ||
| baseline() | ||
| torch.cuda.synchronize() | ||
| t1 = time.time() | ||
| baseline_ms = (t1 - t0) * 1000.0 / iters | ||
|
|
||
| for _ in range(warmup): | ||
| fused() | ||
| torch.cuda.synchronize() | ||
| t0 = time.time() | ||
| for _ in range(iters): | ||
| fused() | ||
| torch.cuda.synchronize() | ||
| t1 = time.time() | ||
| fused_ms = (t1 - t0) * 1000.0 / iters | ||
|
|
||
| return baseline_ms, fused_ms, baseline_ms / fused_ms | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| if not _has_sgl_kernel: | ||
| print("Benchmark skipped: sgl_kernel not available") | ||
| exit(1) | ||
|
|
||
| print("MLA RoPE + FP8 Quantization + KV Cache Write Fusion Benchmark") | ||
| print("=" * 70) | ||
| print("Config: Dn=512, Dr=64, iters=1000, warmup=100") | ||
| print("=" * 70) | ||
|
|
||
| # Test larger batch sizes and more iterations for stable measurements | ||
| for nnz in [1024, 4096, 8192, 16384, 32768]: | ||
| b, f, s = run_one(nnz=nnz, iters=1000, warmup=100) | ||
| if b > 0: | ||
| speedup_pct = (s - 1.0) * 100 | ||
| print( | ||
| f"nnz={nnz:5d} | baseline={b:7.3f} ms | fused={f:7.3f} ms | speedup x{s:4.2f} ({speedup_pct:+5.1f}%)" | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| [build-system] | ||
| requires = ["setuptools", "wheel", "torch"] | ||
| build-backend = "setuptools.build_meta" | ||
|
|
||
| [project] | ||
| name = "mla-fusion-kernel" | ||
| version = "0.1.0" | ||
| description = "Standalone MLA RoPE + FP8 Fusion Kernel" | ||
| requires-python = ">=3.8" | ||
| dependencies = ["torch"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| """ | ||
| Standalone build for MLA RoPE FP8 Fusion kernel | ||
| """ | ||
|
|
||
| import os | ||
| import sys | ||
|
|
||
| from setuptools import setup | ||
|
|
||
|
|
||
| # Delay torch import until build time | ||
| def get_cuda_arch(): | ||
| try: | ||
| import torch | ||
|
|
||
| cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) | ||
| if cuda_arch_list is None: | ||
| # Auto-detect | ||
| if torch.cuda.is_available(): | ||
| capability = torch.cuda.get_device_capability() | ||
| cuda_arch_list = f"{capability[0]}.{capability[1]}" | ||
| else: | ||
| # Default to common architectures | ||
| cuda_arch_list = "8.0;9.0;10.0" | ||
| print(f"Building for CUDA architectures: {cuda_arch_list}") | ||
| return cuda_arch_list | ||
| except Exception as e: | ||
| print(f"Warning: Could not detect CUDA arch, using defaults: {e}") | ||
| return "8.0;9.0;10.0" | ||
|
|
||
|
|
||
| def get_extensions(): | ||
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
|
||
| cuda_arch_list = get_cuda_arch() | ||
|
|
||
| return [ | ||
| CUDAExtension( | ||
| name="mla_fusion_kernel", | ||
| sources=[ | ||
| "../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu", | ||
| ], | ||
| include_dirs=[ | ||
| "../sgl-kernel/include", | ||
| ], | ||
| extra_compile_args={ | ||
| "cxx": ["-O3", "-std=c++17"], | ||
| "nvcc": [ | ||
| "-O3", | ||
| "--use_fast_math", | ||
| "-U__CUDA_NO_HALF_OPERATORS__", | ||
| "-U__CUDA_NO_HALF_CONVERSIONS__", | ||
| "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", | ||
| "-U__CUDA_NO_HALF2_OPERATORS__", | ||
| "--expt-relaxed-constexpr", | ||
| "--expt-extended-lambda", | ||
| ] | ||
| + [ | ||
| f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}' | ||
| for arch in cuda_arch_list.split(";") | ||
| ], | ||
| }, | ||
| ) | ||
| ] | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from torch.utils.cpp_extension import BuildExtension | ||
|
|
||
| setup( | ||
| name="mla_fusion_kernel", | ||
| ext_modules=get_extensions(), | ||
| cmdclass={"build_ext": BuildExtension}, | ||
| python_requires=">=3.8", | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
sourcesandinclude_dirspaths for theCUDAExtensionuse relative paths with... This makes the build script dependent on the current working directory from whichsetup.pyis invoked. To make it more robust, you could construct absolute paths based on the location of thesetup.pyfile itself.