[fusion] add composable fusion pass framework #10549
DevashishLal-CB wants to merge 19 commits into sgl-project:main
Conversation
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
Things Pending as of now
|
Can we add an sgl-kernel fused-kernel pass example? Such as
@BBuf Added the example for topk_softmax fusion, and also added the rmsnorm_quant fusion pass with tests. This MR is ready for review; I will look into cuda graph support and do it as a separate MR, collaborating with @yuan-luo.
Cool, we'll review ASAP.
```python
from sglang.srt.server_args import ServerArgs


class FusionManager(CustomGraphPass):
```
Instead of FusionManager, we prefer to do an abstraction and form a PassManager, in which fusion is one of the Pass types, like the LLVM pass concept. There can be other Pass types such as AsyncTPPass, AllReduceFusionPass, RMSNormQuantFusionPass, etc.
Refer to https://github.com/sgl-project/sglang/pull/10987/files#diff-61475915ef47a86d47da62c647cd346f64c4b702c94728ab84172aed428e4fc0
for more details.
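The LLVM-style abstraction the reviewer suggests could look roughly like the sketch below. All names here (`InductorPass`, `PassManager`, the concrete pass classes, and the toy list-based "graph") are illustrative, not the actual sglang API:

```python
from abc import ABC, abstractmethod


class InductorPass(ABC):
    """Base class for all graph passes (illustrative sketch)."""

    @abstractmethod
    def apply(self, graph):
        ...


class RMSNormQuantFusionPass(InductorPass):
    def apply(self, graph):
        # real version: pattern-match rmsnorm + quant, replace with fused op
        graph.append("rmsnorm_quant_fused")
        return graph


class AllReduceFusionPass(InductorPass):
    def apply(self, graph):
        graph.append("allreduce_fused")
        return graph


class PassManager:
    """Runs registered passes in order, LLVM pass-manager style."""

    def __init__(self):
        self._passes = []

    def add_pass(self, p: InductorPass):
        self._passes.append(p)

    def run(self, graph):
        for p in self._passes:
            graph = p.apply(graph)
        return graph


pm = PassManager()
pm.add_pass(RMSNormQuantFusionPass())
pm.add_pass(AllReduceFusionPass())
print(pm.run([]))  # -> ['rmsnorm_quant_fused', 'allreduce_fused']
```

With this shape, fusion passes are just one subclass family, and non-fusion passes (e.g. an AsyncTPPass) register through the same `add_pass` entry point.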
```python
from sglang.srt.server_args import ServerArgs

try:
    from vllm import _custom_ops  # noqa: F401
```
I'll port over the kernel.
```diff
@@ -147,14 +156,21 @@ def patch_model(
     tp_group.ca_comm = backup_ca_comm


-def set_torch_compile_config():
+def set_torch_compile_config(server_args, model_config):
```
Parameters in the def should have type annotations.
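The annotated signature the reviewer asks for might look like the sketch below. The real `ServerArgs` and `ModelConfig` live in sglang; the dataclass stand-ins and the body here are illustrative only:

```python
from dataclasses import dataclass


# Stand-ins for sglang's ServerArgs / ModelConfig (illustrative only).
@dataclass
class ServerArgs:
    enable_torch_compile: bool = False


@dataclass
class ModelConfig:
    hidden_size: int = 4096


def set_torch_compile_config(
    server_args: ServerArgs, model_config: ModelConfig
) -> None:
    """Signature with type annotations as requested; body is a placeholder."""
    if server_args.enable_torch_compile:
        print(f"compiling model with hidden_size={model_config.hidden_size}")


set_torch_compile_config(ServerArgs(enable_torch_compile=True), ModelConfig())
```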
```diff
@@ -1788,6 +1788,8 @@ def init_device_graphs(self):
             return

+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            if self.server_args.enable_torch_compile:
```
Do we need to run torch_compile when disable_cuda_graph is set?
I haven't looked into it much, but the two passes I added weren't working with cuda graph enabled. I'm also not sure whether all the other hardware platforms support cuda graph.
```python
# limitations under the License.
# ==============================================================================

import logging
```
We'd better put this configuration file in the python/sglang/srt/configs/ directory.
```python
return torch.compile(
    torch.no_grad()(forward),
    mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
    dynamic=False,
)
```
You have to use fullgraph=True. It's a merge stopper, isn't it?
Currently dynamo encounters graph breaks on attention; a unified attention op would solve this, as was done in #10062.
```diff
@@ -114,6 +114,21 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
         _to_torch(sub, reverse, num_tokens)


+def _torch_compile_wrapper(forward):
```
No more design patterns in 2025 except Wrapper and Manager, right? [sarcasm]
Your function is a Decorator, not a Wrapper.
Yeah, this entry point is supposed to be a placeholder; once we have a custom backend (which will be required by piecewise cuda graphs), it will manage this invocation. I didn't want to do a big diff.
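On the naming point: the function does fit the decorator shape, since it takes a callable and returns a new callable with the same interface. A minimal pure-Python sketch (the compile step is simulated; the real version would call `torch.compile(torch.no_grad()(forward), ...)`, and the `is_compiled` marker is illustrative only):

```python
import functools


def torch_compile_decorator(forward):
    """Entry point phrased as a decorator rather than a 'wrapper'."""

    @functools.wraps(forward)
    def compiled(*args, **kwargs):
        # stand-in for the compiled callable returned by torch.compile
        return forward(*args, **kwargs)

    compiled.is_compiled = True  # illustrative marker, not a real attribute
    return compiled


@torch_compile_decorator
def forward(x):
    return x * 2


print(forward(3), forward.is_compiled)  # -> 6 True
```

`functools.wraps` keeps the original `__name__`/`__doc__`, which matters when other code introspects the forward function.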
```python
from sglang.srt.compilation.fusion.fusion_pass import FusionPass


class RMSNormQuantPass(FusionPass):
```
Not clear from the name and namespace: which type of quantization is supported (fp8 / int8 / int4 / binary)?
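One way to address this review comment is to carry the quantization target explicitly in the pass, so the supported format is visible at the call site. This is a sketch, not the actual sglang class; the `QuantDtype` enum and `name()` helper are illustrative:

```python
from enum import Enum


class QuantDtype(Enum):
    FP8_E4M3 = "fp8_e4m3"
    INT8 = "int8"


class RMSNormQuantPass:
    """Illustrative: makes the quant format (fp8 / int8 / ...) explicit."""

    def __init__(self, quant_dtype: QuantDtype):
        self.quant_dtype = quant_dtype

    def name(self) -> str:
        return f"rmsnorm_{self.quant_dtype.value}_quant_pass"


print(RMSNormQuantPass(QuantDtype.FP8_E4M3).name())
# -> rmsnorm_fp8_e4m3_quant_pass
```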
Signed-off-by: Devashish Lal <laldevashish@gmail.com>
Signed-off-by: Devashish Lal <laldevashish@gmail.com>
Some performance numbers from sglang on an RTX 5090 running Llama 3.1 8B FP8 on a 16-prompt benchmark for the rmsnorm_quant fusion pass:
Signed-off-by: Devashish Lal <devcode@fb.com>
…(#2243)

## 📌 Description

FP8 model inference requires multiple intermediate quantization kernels, which can be avoided by fusing the norm and quantization kernels. Consumers like sglang and vllm can lower to these norm + quant fusion kernels using custom torch compile passes.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

### Reference

I have been working on adding custom fusion passes to sglang as part of the following [RFC](sgl-project/sglang#10118) and would like to use flashinfer's norm kernels for the norm + quant fusions instead of migrating vllm kernels to sglang as part of the following [MR](sgl-project/sglang#10549).

### Implementation

I realise that the existing kernels (at least for rmsnorm) can be modified to add the scale parameter as an optional parameter, thereby avoiding most code duplication. However, as an initial implementation, I have opted for a separate implementation route. This can be refactored if required.


For fused_add_rmsnorm_quant, I don't think an in-place update would be possible since the dtypes of the input and output differ.

Currently, the FP8_E4M3 numeric limits (448) have been hard-coded, as I am not aware of a way to get this value at compile time without including c10 headers from torch, and I am not sure that is acceptable post the tvm ffi migration. The following is a snippet from vllm, and I have seen similar code elsewhere for getting the FP8 numeric limits:

```cpp
#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};
```

The best option in my mind is to introduce `include/flashinfer/fp8.h` containing something similar to the above snippet, and to also support e5m2.

### Tests

atol and rtol for the fp8 assertions had to be high due to the low-precision nature of the data, but with tolerances of 1e-2, just a few tests fail with a single-element mismatch.

## Summary by CodeRabbit

* **New Features**
  * Added quantized RMSNorm and fused quantized RMSNorm (residual-add) with configurable scale, eps, and PDL toggle.
  * Supports FP16/FP8 paths and optional per-token or per-tensor scaling; outputs are clamped for quantized formats.
* **Tests**
  * Added tests validating quantized normalization and fused-residual flows across dtypes, batch sizes, scaling modes, and PDL configurations.

Signed-off-by: Devashish Lal <laldevashish@gmail.com>
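For reference, the maxima that the C++ `quant_type_max` trait would report can be mirrored in Python with a simple lookup (these are the standard representable maxima for the respective formats; the `clamp_for_quant` helper is illustrative, not a flashinfer API):

```python
# Numeric maxima for common quantization target dtypes
# (same values torch.finfo reports for the float8 variants).
QUANT_TYPE_MAX = {
    "float8_e4m3fn": 448.0,
    "float8_e4m3fnuz": 240.0,
    "float8_e5m2": 57344.0,
    "int8": 127,
}


def clamp_for_quant(x: float, dtype: str) -> float:
    """Clamp a value into the representable range of the target quant dtype."""
    m = QUANT_TYPE_MAX[dtype]
    return max(-m, min(m, x))


print(clamp_for_quant(1000.0, "float8_e4m3fn"))  # -> 448.0
```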
Signed-off-by: Devashish Lal <devcode@fb.com>
These kernels are faster on all benchmarks when compared against AOT sglang, fused flashinfer (CuteDSL), and the unfused implementation.

Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Motivation
Initial implementation of the changes proposed in #10118
Modifications
This PR adds the fusion passes and integration tests for them
Passes added
For the fusion passes to work with the cuda graph runner I had to get rid of the model patching (alternatively, I could rewrite the passes with the pattern functions matching pure PyTorch code; we should avoid this model patching anyway, as it will interfere with the compilation process).
I have also added model_bench.py; the idea with this is to provide a stripped-down sglang runtime where each layer can be instantiated in isolation, helping write integration and accuracy tests for fusion passes and fused kernels.
Accuracy Tests
Benchmarking and Profiling
MM + Silu and Mul fusion
MM + Silu and Mul + Quant (I have a small diff to use sgl_per_tensor_quant_fp8 for the quant instead of the triton quant kernel; I will add support for the default quant kernel before merge)
Checklist