
[fusion] add composable fusion pass framework#10549

Open
DevashishLal-CB wants to merge 19 commits into sgl-project:main from DevashishLal-CB:gh/dlal/sgl-fusion

Conversation

@DevashishLal-CB
Contributor

@DevashishLal-CB DevashishLal-CB commented Sep 17, 2025

Motivation

Initial implementation of the changes proposed in #10118

Modifications

This PR adds the fusion passes and integration tests for them.

Passes added:

  • gate_up projection + SiLU-and-mul (Triton kernel; a CUTLASS kernel is possible)
  • gate_up projection + SiLU-and-mul + quant (Triton kernel; a CUTLASS kernel is possible)
  • rmsnorm + quant (kernel from vLLM)
  • fused_add_rmsnorm + quant (kernel from vLLM)
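As a reference for what the first two passes fuse, here is a minimal plain-Python sketch of the SiLU-and-mul activation (the function names are illustrative, not the actual kernel API):

```python
import math

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x * (1.0 / (1.0 + math.exp(-x)))

def silu_and_mul(gate_up):
    # gate_up is the concatenated output of the gate_up projection:
    # the first half is the gate, the second half is the up projection.
    n = len(gate_up) // 2
    gate, up = gate_up[:n], gate_up[n:]
    return [silu(g) * u for g, u in zip(gate, up)]
```

The fusion passes replace the mm producing gate_up plus this elementwise step with a single fused kernel call.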

For the fusion passes to work with the CUDA graph runner, I had to get rid of the model patching (alternatively, I could rewrite the pass with the pattern functions matching pure PyTorch code). We should avoid this model patching regardless, as it will interfere with the compilation process.
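The pattern-matching rewrite mentioned above can be sketched without torch.fx at all; the toy graph representation and op names below are purely illustrative:

```python
# Toy graph: list of (op, output_name, input_names) tuples in topological order.
def fuse_mm_silu_mul(graph):
    """Replace an mm whose result feeds silu_and_mul with one fused node.

    For simplicity this sketch assumes the mm output has a single consumer.
    """
    producers = {name: (op, inputs) for op, name, inputs in graph}
    fused, consumed = [], set()
    for op, name, inputs in graph:
        if op == "silu_and_mul" and producers.get(inputs[0], ("",))[0] == "mm":
            # Rewire the fused node directly to the mm's inputs.
            consumed.add(inputs[0])
            fused.append(("fused_swiglu", name, producers[inputs[0]][1]))
        else:
            fused.append((op, name, inputs))
    # Drop mm nodes that were absorbed into the fused op.
    return [(op, n, i) for op, n, i in fused if n not in consumed]
```

A real pass operates on the functionalized FX graph instead of tuples, but the shape of the transformation is the same: match the producer/consumer pair, insert the fused op, and erase the absorbed nodes.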

I have also added model_bench.py. The idea is to provide a stripped-down sglang runtime where each layer can be instantiated in isolation, which helps in writing integration and accuracy tests for fusion passes and fused kernels.
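A rough illustration of the layer-in-isolation idea (the registry-based API below is hypothetical, not the actual model_bench.py interface):

```python
# Hypothetical registry-based harness: each entry builds one layer in
# isolation so a fused kernel can be checked against a reference impl.
LAYER_REGISTRY = {}

def register_layer(name):
    def deco(factory):
        LAYER_REGISTRY[name] = factory
        return factory
    return deco

@register_layer("silu_and_mul")
def build_silu_and_mul():
    import math
    def forward(xs):
        # Reference implementation: silu(gate) * up on plain floats.
        n = len(xs) // 2
        return [g / (1.0 + math.exp(-g)) * u for g, u in zip(xs[:n], xs[n:])]
    return forward

def check_against_reference(name, fused_impl, sample):
    # Compare a fused implementation against the registered reference.
    ref = LAYER_REGISTRY[name]()(sample)
    got = fused_impl(sample)
    return all(abs(a - b) < 1e-6 for a, b in zip(ref, got))
```

The real harness instantiates actual sglang layers with weights, but the testing pattern is the same: build the layer in isolation, run both paths, compare.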

Accuracy Tests

Benchmarking and Profiling

MM + Silu and Mul fusion

 ===== Before_FusedActivationPass =====
 <eval_with_key>.198 from /usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py:1301 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[1, 2048][2048, 1]cuda:0", arg1_1: "bf16[1, 2048][2048, 1]cuda:0", arg2_1: "bf16[2048][1]cuda:0", arg3_1: "bf16[16384, 2048][2048, 1]cuda:0", arg4_1: "bf16[2048, 8192][8192, 1]cuda:0"):
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.fused_add_rmsnorm.default, weight = arg2_1, eps = 1e-05, enable_pdl = False, _input_base_index = 0, _residual_base_index = 1, _all_bases = [arg0_1, arg1_1]);  arg2_1 = None
        getitem_1: "bf16[1, 2048][2048, 1]cuda:0" = auto_functionalized_v2[1]
        getitem_2: "bf16[1, 2048][2048, 1]cuda:0" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/unquant.py:135 in apply, code: return F.linear(x, layer.weight, bias)
        permute: "bf16[2048, 16384][1, 2048]cuda:0" = torch.ops.aten.permute.default(arg3_1, [1, 0]);  arg3_1 = None
        mm: "bf16[1, 16384][16384, 1]cuda:0" = torch.ops.aten.mm.default(getitem_1, permute);  permute = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/activation.py:69 in forward_cuda, code: out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        empty: "bf16[1, 8192][8192, 1]cuda:0" = torch.ops.aten.empty.memory_format([1, 8192], dtype = torch.bfloat16, device = device(type='cuda', index=0), pin_memory = False)
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:183 in silu_and_mul, code: torch.ops.sgl_kernel.silu_and_mul.default(out, input)
        auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.silu_and_mul.default, input = mm, _out_base_index = 0, _all_bases = [empty]);  mm = empty = None
        getitem_4: "bf16[1, 8192][8192, 1]cuda:0" = auto_functionalized_v2_1[1];  auto_functionalized_v2_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/unquant.py:135 in apply, code: return F.linear(x, layer.weight, bias)
        permute_1: "bf16[8192, 2048][1, 8192]cuda:0" = torch.ops.aten.permute.default(arg4_1, [1, 0]);  arg4_1 = None
        mm_1: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.mm.default(getitem_4, permute_1);  getitem_4 = permute_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        copy_: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        return (mm_1,)
        
[2025-09-17 02:10:27] TRACED GRAPH
 ===== After_FusedActivationPass =====
 <eval_with_key>.198 from /usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py:1301 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[1, 2048][2048, 1]cuda:0", arg1_1: "bf16[1, 2048][2048, 1]cuda:0", arg2_1: "bf16[2048][1]cuda:0", arg3_1: "bf16[16384, 2048][2048, 1]cuda:0", arg4_1: "bf16[2048, 8192][8192, 1]cuda:0"):
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.fused_add_rmsnorm.default, weight = arg2_1, eps = 1e-05, enable_pdl = False, _input_base_index = 0, _residual_base_index = 1, _all_bases = [arg0_1, arg1_1]);  arg2_1 = None
        getitem_1: "bf16[1, 2048][2048, 1]cuda:0" = auto_functionalized_v2[1]
        getitem_2: "bf16[1, 2048][2048, 1]cuda:0" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/unquant.py:135 in apply, code: return F.linear(x, layer.weight, bias)
        permute: "bf16[2048, 16384][1, 2048]cuda:0" = torch.ops.aten.permute.default(arg3_1, [1, 0]);  arg3_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/activation.py:69 in forward_cuda, code: out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        empty: "bf16[1, 8192][8192, 1]cuda:0" = torch.ops.aten.empty.memory_format([1, 8192], dtype = torch.bfloat16, device = device(type='cuda', index=0), pin_memory = False);  empty = None
        
        # No stacktrace found for following nodes
        fused_swiglu_default: "bf16[1, 8192][8192, 1]cuda:0" = torch.ops.sglang.fused_swiglu.default(getitem_1, permute);  permute = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/unquant.py:135 in apply, code: return F.linear(x, layer.weight, bias)
        permute_1: "bf16[8192, 2048][1, 8192]cuda:0" = torch.ops.aten.permute.default(arg4_1, [1, 0]);  arg4_1 = None
        mm_1: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.mm.default(fused_swiglu_default, permute_1);  fused_swiglu_default = permute_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        copy_: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "bf16[1, 2048][2048, 1]cuda:0" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        return (mm_1,)
        
[2025-09-17 02:10:27] FusedActivationPass completed in 19.8 ms, matched 1 times

MM + SiLU and Mul + Quant (I have a small diff to use sgl_per_tensor_quant_fp8 for quantization instead of the Triton quant kernel; I will add support for the default quant kernel before merge)

[2025-09-17 02:10:39] TRACED GRAPH
 ===== Before_FusedActivationPass =====
 <eval_with_key>.56 from /usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py:1301 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f16[6, 4096][4096, 1]cuda:0", arg1_1: "f16[6, 4096][4096, 1]cuda:0", arg2_1: "f16[4096][1]cuda:0", arg3_1: "f8e4m3fn[4096, 22016][1, 4096]cuda:0", arg4_1: "f32[][]cuda:0", arg5_1: "f32[][]cuda:0", arg6_1: "f8e4m3fn[11008, 4096][1, 11008]cuda:0", arg7_1: "f32[][]cuda:0", arg8_1: "f32[][]cuda:0"):
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.fused_add_rmsnorm.default, weight = arg2_1, eps = 1e-05, enable_pdl = False, _input_base_index = 0, _residual_base_index = 1, _all_bases = [arg0_1, arg1_1]);  arg2_1 = None
        getitem_1: "f16[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2[1]
        getitem_2: "f16[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_kernel.py:1417 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
        empty: "f8e4m3fn[6, 4096][4096, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 4096], dtype = torch.float8_e4m3fn, device = device(type='cuda', index=0), pin_memory = False)
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/gemm.py:136 in sgl_per_tensor_quant_fp8, code: torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default, input = getitem_1, output_s = arg5_1, is_static = True, _output_q_base_index = 0, _all_bases = [empty]);  empty = None
        getitem_4: "f8e4m3fn[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2_1[1];  auto_functionalized_v2_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_utils.py:803 in apply_fp8_linear, code: output = torch._scaled_mm(
        _scaled_mm: "f16[6, 22016][22016, 1]cuda:0" = torch.ops.aten._scaled_mm.default(getitem_4, arg3_1, arg5_1, arg4_1, None, None, torch.float16);  getitem_4 = arg3_1 = arg5_1 = arg4_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/activation.py:69 in forward_cuda, code: out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        empty_1: "f16[6, 11008][11008, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 11008], dtype = torch.float16, device = device(type='cuda', index=0), pin_memory = False)
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:183 in silu_and_mul, code: torch.ops.sgl_kernel.silu_and_mul.default(out, input)
        auto_functionalized_v2_2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.silu_and_mul.default, input = _scaled_mm, _out_base_index = 0, _all_bases = [empty_1]);  _scaled_mm = empty_1 = None
        getitem_6: "f16[6, 11008][11008, 1]cuda:0" = auto_functionalized_v2_2[1];  auto_functionalized_v2_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_kernel.py:1417 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
        empty_2: "f8e4m3fn[6, 11008][11008, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 11008], dtype = torch.float8_e4m3fn, device = device(type='cuda', index=0), pin_memory = False)
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/gemm.py:136 in sgl_per_tensor_quant_fp8, code: torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        auto_functionalized_v2_3 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default, input = getitem_6, output_s = arg8_1, is_static = True, _output_q_base_index = 0, _all_bases = [empty_2]);  getitem_6 = empty_2 = None
        getitem_8: "f8e4m3fn[6, 11008][11008, 1]cuda:0" = auto_functionalized_v2_3[1];  auto_functionalized_v2_3 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_utils.py:803 in apply_fp8_linear, code: output = torch._scaled_mm(
        _scaled_mm_1: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten._scaled_mm.default(getitem_8, arg6_1, arg8_1, arg7_1, None, None, torch.float16);  getitem_8 = arg6_1 = arg8_1 = arg7_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        copy_: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        return (_scaled_mm_1,)
        
[2025-09-17 02:10:39] TRACED GRAPH
 ===== After_FusedActivationPass =====
 <eval_with_key>.56 from /usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py:1301 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f16[6, 4096][4096, 1]cuda:0", arg1_1: "f16[6, 4096][4096, 1]cuda:0", arg2_1: "f16[4096][1]cuda:0", arg3_1: "f8e4m3fn[4096, 22016][1, 4096]cuda:0", arg4_1: "f32[][]cuda:0", arg5_1: "f32[][]cuda:0", arg6_1: "f8e4m3fn[11008, 4096][1, 11008]cuda:0", arg7_1: "f32[][]cuda:0", arg8_1: "f32[][]cuda:0"):
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.fused_add_rmsnorm.default, weight = arg2_1, eps = 1e-05, enable_pdl = False, _input_base_index = 0, _residual_base_index = 1, _all_bases = [arg0_1, arg1_1]);  arg2_1 = None
        getitem_1: "f16[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2[1]
        getitem_2: "f16[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_kernel.py:1417 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
        empty: "f8e4m3fn[6, 4096][4096, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 4096], dtype = torch.float8_e4m3fn, device = device(type='cuda', index=0), pin_memory = False)
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/gemm.py:136 in sgl_per_tensor_quant_fp8, code: torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default, input = getitem_1, output_s = arg5_1, is_static = True, _output_q_base_index = 0, _all_bases = [empty]);  empty = None
        getitem_4: "f8e4m3fn[6, 4096][4096, 1]cuda:0" = auto_functionalized_v2_1[1];  auto_functionalized_v2_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/activation.py:69 in forward_cuda, code: out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        empty_1: "f16[6, 11008][11008, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 11008], dtype = torch.float16, device = device(type='cuda', index=0), pin_memory = False);  empty_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_kernel.py:1417 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
        empty_2: "f8e4m3fn[6, 11008][11008, 1]cuda:0" = torch.ops.aten.empty.memory_format([6, 11008], dtype = torch.float8_e4m3fn, device = device(type='cuda', index=0), pin_memory = False);  empty_2 = None
        
        # No stacktrace found for following nodes
        fused_swiglu_default: "f8e4m3fn[6, 11008][11008, 1]cuda:0" = torch.ops.sglang.fused_swiglu.default(getitem_4, arg3_1, arg5_1, arg4_1, arg8_1);  getitem_4 = arg3_1 = arg5_1 = arg4_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sglang/srt/layers/quantization/fp8_utils.py:803 in apply_fp8_linear, code: output = torch._scaled_mm(
        _scaled_mm_1: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten._scaled_mm.default(fused_swiglu_default, arg6_1, arg8_1, arg7_1, None, None, torch.float16);  fused_swiglu_default = arg6_1 = arg8_1 = arg7_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/sgl_kernel/elementwise.py:81 in fused_add_rmsnorm, code: torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        copy_: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "f16[6, 4096][4096, 1]cuda:0" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        return (_scaled_mm_1,)
        
[2025-09-17 02:10:39] FusedActivationPass completed in 17.4 ms, matched 1 times

Checklist


@DevashishLal-CB DevashishLal-CB changed the title [fusion] add fusion pass manager, fusion passes and fused activation pass [DRAFT][fusion] add fusion pass manager, fusion passes and fused activation pass Sep 17, 2025
@DevashishLal-CB
Contributor Author

DevashishLal-CB commented Sep 17, 2025

Things pending as of now:

  • CUDA graph support: currently the fused activation pass doesn't work with CUDA graphs enabled
  • RMSNorm + quant fusion pass
  • Unit tests for fusion passes

@DevashishLal-CB DevashishLal-CB marked this pull request as draft September 17, 2025 05:44
@DevashishLal-CB DevashishLal-CB changed the title [DRAFT][fusion] add fusion pass manager, fusion passes and fused activation pass [fusion] add fusion pass manager, fusion passes and fused activation pass Sep 17, 2025
@DevashishLal-CB DevashishLal-CB changed the title [fusion] add fusion pass manager, fusion passes and fused activation pass [fusion] add fusion pass manager, base fusion pass and fused activation pass Sep 17, 2025
@DevashishLal-CB DevashishLal-CB force-pushed the gh/dlal/sgl-fusion branch 2 times, most recently from 56aec73 to 33aa252 Compare September 19, 2025 05:47
@BBuf
Collaborator

BBuf commented Sep 19, 2025

Can we add an sgl-kernel fused-kernel pass example, such as topk_softmax?

@DevashishLal-CB
Contributor Author

Can we add a sgl-kernel fuse kernel pass example? Such as topk_softmax

@BBuf Added the example for topk_softmax fusion. Also added the rmsnorm_quant fusion pass with tests.

This PR is ready for review; I will look into CUDA graph support in a separate PR.

Will collaborate with @yuan-luo

@BBuf
Collaborator

BBuf commented Sep 26, 2025

Can we add a sgl-kernel fuse kernel pass example? Such as topk_softmax

@BBuf Added the example for topk_softmax fusion, Also added rmsnorm_quant fusion pass with tests

This MR is ready for review, will look into cuda graph support and do it as a separate MR

Will collaborate with @yuan-luo

Cool, we'll review ASAP.

from sglang.srt.server_args import ServerArgs


class FusionManager(CustomGraphPass):
Collaborator

@yuan-luo yuan-luo Sep 26, 2025


Instead of FusionManager, we would prefer to abstract this into a PassManager, in which fusion is just one of several pass types, similar to the LLVM pass concept. There can be other pass types such as AsyncTPPass, AllReduceFusionPass, RMSNormQuantFusionPass, etc.
Refer to https://github.com/sgl-project/sglang/pull/10987/files#diff-61475915ef47a86d47da62c647cd346f64c4b702c94728ab84172aed428e4fc0
for more details.
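The suggested LLVM-style abstraction could look roughly like this (class names are the reviewer's examples; the interface is a hypothetical sketch, not the API from #10987):

```python
class Pass:
    """Base class: every pass takes a graph and returns a (possibly new) graph."""
    def run(self, graph):
        raise NotImplementedError

class PassManager:
    """Runs registered passes in order, LLVM-pipeline style."""
    def __init__(self):
        self.passes = []

    def add(self, p):
        self.passes.append(p)
        return self  # allow chaining

    def run(self, graph):
        for p in self.passes:
            graph = p.run(graph)
        return graph

class RMSNormQuantFusionPass(Pass):
    def run(self, graph):
        # Placeholder: a real pass would pattern-match and rewrite the graph.
        return graph + ["rmsnorm_quant_fused"]
```

Under this design, fusion passes, communication passes, and layout passes all share one interface, and enabling a feature is just registering another pass.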

@yuan-luo
Collaborator

#10987

from sglang.srt.server_args import ServerArgs

try:
from vllm import _custom_ops # noqa: F401
Collaborator


Don't depend on vllm.

Contributor Author


I'll port over the kernel

@@ -147,14 +156,21 @@ def patch_model(
tp_group.ca_comm = backup_ca_comm


def set_torch_compile_config():
def set_torch_compile_config(server_args, model_config):
Collaborator


Parameters in the def should have type annotations.

@@ -1788,6 +1788,8 @@ def init_device_graphs(self):
return

if self.device != "cpu" and self.server_args.disable_cuda_graph:
if self.server_args.enable_torch_compile:
Collaborator


Do we need to run torch_compile when disable_cuda_graph is set?

Contributor Author


I haven't looked into it much, but two of the passes I added weren't working with CUDA graphs enabled. Also, I am not sure whether all other hardware platforms support CUDA graphs.

# limitations under the License.
# ==============================================================================

import logging
Collaborator


We'd better put this configuration file in the python/sglang/srt/configs/ directory.

Comment on lines +118 to +122
return torch.compile(
torch.no_grad()(forward),
mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
dynamic=False,
)
Contributor


You have to use fullgraph=True. It's a merge stopper, isn't it?

Contributor Author


Currently Dynamo encounters graph breaks on attention; a unified attention op would solve this, as done in #10062.

@@ -114,6 +114,21 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
_to_torch(sub, reverse, num_tokens)


def _torch_compile_wrapper(forward):
Contributor


No more design patterns in 2025 except Wrapper and Manager, right? [sarcasm]
Your function is a Decorator, not a Wrapper.

Contributor Author


Yeah, this entry point is supposed to be a placeholder until we have a custom backend (which will be required by piecewise CUDA graphs) to manage this invocation; I didn't want to make a big diff.

from sglang.srt.compilation.fusion.fusion_pass import FusionPass


class RMSNormQuantPass(FusionPass):
Contributor


It is not clear from the name and namespace which type of quantization is supported: fp8, int8, int4, or binary?

Signed-off-by: Devashish Lal <laldevashish@gmail.com>
@BLaZeKiLL

Some performance numbers from sglang on an RTX 5090 running Llama 3.1 8B FP8 on a 16-prompt benchmark for the rmsnorm_quant fusion pass:

server config                                              throughput (tok/s)
cuda graph                                                 503.25
cuda graph + torch compile                                 600.37
cuda graph + torch compile + fusion (vllm kernels)         612.58
cuda graph + torch compile + fusion (flashinfer kernels)   619.11

@DevashishLal-CB DevashishLal-CB requested a review from AniZpZ as a code owner March 11, 2026 23:24
@DevashishLal-CB DevashishLal-CB changed the title [fusion] add fused activate and rmsnorm + quant fusions pass [fusion] add composable fusion pass framework Mar 12, 2026
Signed-off-by: Devashish Lal <devcode@fb.com>
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…#2243)


## 📌 Description

FP8 model inference requires multiple intermediate quantization kernels,
which can be avoided by fusing norm and quantization kernels. Consumers
like sglang and vllm can lower to these norm + quant fusion kernels
using custom torch compile passes


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

### Reference
I have been working on adding custom fusion passes to sglang as part of
the following [RFC](sgl-project/sglang#10118)
and would like to use flashinfer's norm kernels for the norm quant
fusions instead of migrating vllm kernels to sglang as part of the
following [MR](sgl-project/sglang#10549)

### Implementation
I realise that existing kernels (at least for rmsnorm) can be modified
to add the scale parameter as an optional parameter, thereby avoiding
most code duplication. However, as an initial implementation, I have
opted for a separate implementation route. This can be refactored if
required.

For fused_add_rmsnorm_quant, I don't think an in-place update would be
possible since dtypes for input and output differ

Currently, the FP8_E4M3 numeric limit (448) has been hard-coded, as I am not aware of a way to get this value at compile time without including c10 headers from torch, and I am not sure whether that is acceptable post the tvm ffi migration.

The following is a snippet from vLLM; I have seen similar code elsewhere for getting the FP8 numeric limits:
```cpp
#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};
```
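As a sanity check, the hard-coded 448 can also be derived from the e4m3fn format itself; a small Python sketch (assuming, as the fn variant specifies, that the all-ones mantissa at max exponent is reserved for NaN):

```python
def fp8_e4m3fn_max():
    # e4m3fn: 4 exponent bits (bias 7), 3 mantissa bits, no infinities.
    # The all-ones mantissa at max exponent encodes NaN, so the largest
    # finite value uses mantissa 0b110: 1.75 * 2^(15 - 7) = 448.
    mantissa = 1 + 6 / 8           # 0b110 mantissa -> 1.75
    exponent = 2 ** (15 - 7)       # max exponent field 0b1111, bias 7
    return mantissa * exponent
```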

The best option in my mind is to introduce `include/flashinfer/fp8.h`
containing something similar to the above snippet, and also support e5m2

### Tests
atol and rtol for the fp8 assertions had to be high due to the low-precision nature of the data, but with tolerances of 1e-2 only a few tests fail, each with a single-element mismatch.

## Summary by CodeRabbit

* **New Features**
* Added quantized RMSNorm and fused quantized RMSNorm (residual-add)
with configurable scale, eps, and PDL toggle.
* Supports FP16/FP8 paths and optional per-token or per-tensor scaling;
outputs are clamped for quantized formats.

* **Tests**
* Added tests validating quantized normalization and fused-residual
flows across dtypes, batch sizes, scaling modes, and PDL configurations.


Signed-off-by: Devashish Lal <laldevashish@gmail.com>
Devashish Lal added 3 commits March 13, 2026 17:21
Signed-off-by: Devashish Lal <devcode@fb.com>
these kernels are faster for all benchmarks when compared
against aot sglang, fused flashinfer (cutedsl) and unfused impl

Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
