139 changes: 139 additions & 0 deletions benchmark/kernels/bench_flashmla_fused_kv.py
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""
Microbenchmark: fused vs baseline (emulated) for MLA RoPE + FP8 quantization + KV-cache write.
Uses mla_rope_quantize_fp8_fused from the standalone mla_fusion_kernel build,
falling back to the sgl_kernel extension.
"""
import time

import torch

_has_sgl_kernel = False
mla_rope_quantize_fp8_fused = None
try:
    from mla_fusion_kernel import mla_rope_quantize_fp8_fused

    _has_sgl_kernel = True
    print("Using standalone mla_fusion_kernel")
except ImportError:
    try:
        from sgl_kernel import mla_rope_quantize_fp8_fused

        _has_sgl_kernel = True
        print("Using sgl_kernel.mla_rope_quantize_fp8_fused")
    except ImportError:
        print(
            "ERROR: Fusion kernel not available. Please build mla_fusion_standalone first."
        )
        _has_sgl_kernel = False


def run_one(nnz=1024, Dn=512, Dr=64, iters=200, warmup=20, device="cuda"):
    if not _has_sgl_kernel:
        return 0, 0, 0

    torch.manual_seed(0)

    q_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16)
    q_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16)
    k_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16)
    k_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16)

    # Synthetic RoPE cos/sin table and random positions into it.
    max_seq = max(2048, nnz)
    t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None]
    idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :]
    freqs = 0.1 * (idx + 1.0)
    cos = torch.cos(t * freqs)
    sin = torch.sin(t * freqs)
    cos_sin = torch.cat([cos, sin], dim=1)
    pos_ids = torch.randint(
        low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long
    )

    # KV cache with a few spare slots; each token writes to its own slot.
    slots = nnz + 8
    loc = torch.arange(nnz, device=device, dtype=torch.long)

    q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8)
    k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8)
    k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8)
    kv_base = torch.zeros(slots, Dn + Dr, device=device, dtype=torch.uint8)

    # Baseline: the kernel does RoPE + FP8 quantization only; the KV-cache
    # write is emulated afterwards with two indexed copies.
    def baseline():
        mla_rope_quantize_fp8_fused(
            q_nope,
            q_rope,
            k_nope,
            k_rope,
            cos_sin,
            pos_ids,
            False,
            q_out,
            k_nope_out,
            k_rope_out,
            None,
            None,
        )
        kv_base.zero_()
        kv_base[loc, :Dn] = k_nope_out
        kv_base[loc, Dn:] = k_rope_out

    kv_fused = torch.zeros_like(kv_base)

    # Fused: the kernel scatters quantized K directly into the KV cache.
    def fused():
        mla_rope_quantize_fp8_fused(
            q_nope,
            q_rope,
            k_nope,
            k_rope,
            cos_sin,
            pos_ids,
            False,
            q_out,
            None,
            None,
            kv_fused,
            loc,
        )

    # Time the baseline path.
    for _ in range(warmup):
        baseline()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        baseline()
    torch.cuda.synchronize()
    t1 = time.time()
    baseline_ms = (t1 - t0) * 1000.0 / iters

    # Time the fused path.
    for _ in range(warmup):
        fused()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fused()
    torch.cuda.synchronize()
    t1 = time.time()
    fused_ms = (t1 - t0) * 1000.0 / iters

    return baseline_ms, fused_ms, baseline_ms / fused_ms


if __name__ == "__main__":
    if not _has_sgl_kernel:
        print("Benchmark skipped: sgl_kernel not available")
        exit(1)

    print("MLA RoPE + FP8 Quantization + KV Cache Write Fusion Benchmark")
    print("=" * 70)
    print("Config: Dn=512, Dr=64, iters=1000, warmup=100")
    print("=" * 70)

    # Test larger batch sizes and more iterations for stable measurements
    for nnz in [1024, 4096, 8192, 16384, 32768]:
        b, f, s = run_one(nnz=nnz, iters=1000, warmup=100)
        if b > 0:
            speedup_pct = (s - 1.0) * 100
            print(
                f"nnz={nnz:5d} | baseline={b:7.3f} ms | fused={f:7.3f} ms | speedup x{s:4.2f} ({speedup_pct:+5.1f}%)"
            )
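
Note: the benchmark above measures latency only. A minimal correctness sketch (a hypothetical helper, reusing the exact call signature and synthetic-input setup from run_one) could compare the fused KV write against the baseline scatter byte-for-byte:

import torch
from mla_fusion_kernel import mla_rope_quantize_fp8_fused  # or: from sgl_kernel import ...

def check_equivalence(nnz=256, Dn=512, Dr=64, device="cuda"):
    """Sanity sketch: the fused KV write should produce the same FP8 bytes
    as the separate k_nope/k_rope writes emulated in baseline()."""
    torch.manual_seed(0)
    q_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16)
    q_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16)
    k_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16)
    k_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16)
    t = torch.linspace(0, 1, steps=nnz, device=device, dtype=torch.float32)[:, None]
    idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :]
    cos_sin = torch.cat(
        [torch.cos(t * 0.1 * (idx + 1.0)), torch.sin(t * 0.1 * (idx + 1.0))], dim=1
    )
    pos_ids = torch.arange(nnz, device=device, dtype=torch.long)
    loc = torch.arange(nnz, device=device, dtype=torch.long)

    q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8)
    k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8)
    k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8)
    kv_ref = torch.zeros(nnz, Dn + Dr, device=device, dtype=torch.uint8)
    kv_fused = torch.zeros_like(kv_ref)

    # Baseline path: RoPE + quantize, then scatter into the cache by hand.
    mla_rope_quantize_fp8_fused(
        q_nope, q_rope, k_nope, k_rope, cos_sin, pos_ids, False,
        q_out, k_nope_out, k_rope_out, None, None,
    )
    kv_ref[loc, :Dn] = k_nope_out
    kv_ref[loc, Dn:] = k_rope_out

    # Fused path: the kernel writes the cache directly.
    mla_rope_quantize_fp8_fused(
        q_nope, q_rope, k_nope, k_rope, cos_sin, pos_ids, False,
        q_out, None, None, kv_fused, loc,
    )
    assert torch.equal(kv_ref, kv_fused), "fused KV write diverges from baseline"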
10 changes: 10 additions & 0 deletions mla_fusion_standalone/pyproject.toml
@@ -0,0 +1,10 @@
[build-system]
requires = ["setuptools", "wheel", "torch"]
build-backend = "setuptools.build_meta"

[project]
name = "mla-fusion-kernel"
version = "0.1.0"
description = "Standalone MLA RoPE + FP8 Fusion Kernel"
requires-python = ">=3.8"
dependencies = ["torch"]
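
Assuming the package builds with this configuration (for example, pip install -e mla_fusion_standalone/ run from the repository root; the path is taken from the file tree above), a minimal post-build import check might look like:

import torch
import mla_fusion_kernel  # extension built by the standalone package above

# Hypothetical smoke test: confirm the extension loaded and a CUDA device is visible.
assert torch.cuda.is_available(), "the fusion kernel requires a CUDA device"
print("loaded:", mla_fusion_kernel.mla_rope_quantize_fp8_fused)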
75 changes: 75 additions & 0 deletions mla_fusion_standalone/setup.py
@@ -0,0 +1,75 @@
"""
Standalone build for MLA RoPE FP8 Fusion kernel
"""

import os
import sys

from setuptools import setup


# Delay torch import until build time
def get_cuda_arch():
try:
import torch

cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if cuda_arch_list is None:
# Auto-detect
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
cuda_arch_list = f"{capability[0]}.{capability[1]}"
else:
# Default to common architectures
cuda_arch_list = "8.0;9.0;10.0"
print(f"Building for CUDA architectures: {cuda_arch_list}")
return cuda_arch_list
except Exception as e:
print(f"Warning: Could not detect CUDA arch, using defaults: {e}")
return "8.0;9.0;10.0"


def get_extensions():
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cuda_arch_list = get_cuda_arch()

return [
CUDAExtension(
name="mla_fusion_kernel",
sources=[
"../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu",
],
include_dirs=[
"../sgl-kernel/include",
Comment on lines +41 to +44 (Contributor, severity: medium)

The sources and include_dirs paths for the CUDAExtension use relative paths with "..". This makes the build script dependent on the current working directory from which setup.py is invoked. To make it more robust, you could construct absolute paths based on the location of the setup.py file itself.

Suggested change:

                os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "csrc", "elementwise", "mla_rope_fp8_kv_fused.cu"),
            ],
            include_dirs=[
                os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "include"),
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": [
                    "-O3",
                    "--use_fast_math",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ]
                + [
                    f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}'
                    for arch in cuda_arch_list.split(";")
                ],
            },
        )
    ]


if __name__ == "__main__":
    from torch.utils.cpp_extension import BuildExtension

    setup(
        name="mla_fusion_kernel",
        ext_modules=get_extensions(),
        cmdclass={"build_ext": BuildExtension},
        python_requires=">=3.8",
    )
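
For reference, a quick worked example of the -gencode expansion in get_extensions(); "8.0;9.0" is an arbitrary illustrative value, and the snippet is runnable on its own:

# Expansion of TORCH_CUDA_ARCH_LIST-style strings into nvcc -gencode flags,
# mirroring the list comprehension in setup.py above.
cuda_arch_list = "8.0;9.0"
flags = [
    f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}'
    for arch in cuda_arch_list.split(";")
]
print(flags)
# ['-gencode=arch=compute_80,code=sm_80', '-gencode=arch=compute_90,code=sm_90']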