[AMD] Fuse RMSNorm + FP8 per-token quant for GLM-4.7-FP8 #21403
Jacob0226 wants to merge 2 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist)

This pull request optimizes FP8 per-token quantization by fusing the RMSNorm and quantization operations into a single kernel. The change reduces memory overhead and improves efficiency, particularly for models such as GLM-4.7-FP8. It also adds detection of the appropriate quantization scheme and routes tensors to the optimized kernel, yielding 95 kernel fusions across attention and dense MLP layers.
Code Review
This pull request introduces FP8 per-token quantization support using the aiter library for RMSNorm and linear layers within the SGLang framework, specifically for attention and MLP blocks. Key changes include new aiter RMSNorm quantization functions, logic in communicator.py to handle the fp8_per_token format by returning quantized tensors and scales as a tuple, updates to fp8_utils.py to accept this (tensor, scale) tuple input for FP8 linear operations, and modifications to glm4_moe.py to dynamically detect and pass the appropriate quantization format to the communicator layers and to handle tuple inputs for hidden states in its forward methods. Review comments suggest:

- improving consistency in quant_format string matching
- addressing a potential bug where the scale tensor might be discarded in GLM4MoEBlock.forward_prepare
- refactoring duplicated code in communicator.py
- restoring missing type hints in apply_fp8_ptpc_linear
- investigating a potential dead code block in GLM4MoEBlock.forward related to tuple handling for MoE layers
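To make the tuple-handling change concrete, here is a minimal numpy sketch of the pattern the review describes: a linear entry point that accepts either a raw activation tensor or a pre-quantized (tensor, scale) pair and skips the redundant quantization in the latter case. The names `quantize_per_token` and `apply_ptpc_linear` are illustrative stand-ins, not SGLang's actual API, and dequantize-then-matmul stands in for the real FP8 GEMM.

```python
import numpy as np

FP8_MAX = 448.0  # max finite value of float8_e4m3fn


def quantize_per_token(x):
    """Dynamic per-token quantization: one scale per row (token)."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / FP8_MAX
    scale = np.maximum(scale, 1e-12)  # guard all-zero rows
    q = np.clip(np.round(x / scale), -FP8_MAX, FP8_MAX)
    return q, scale


def apply_ptpc_linear(x, weight):
    # If the input is already a (quantized, scale) tuple -- as produced by
    # the fused RMSNorm+quant kernel -- skip the redundant quantization.
    if isinstance(x, tuple):
        q, scale = x
    else:
        q, scale = quantize_per_token(x)
    # Dequantize-then-matmul stands in for the real FP8 GEMM epilogue.
    return (q * scale) @ weight
```

Both call forms produce identical results; the tuple path simply avoids quantizing the same activations twice.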
Fuse `add_rmsnorm_quant_kernel` (RMSNorm) with `dynamic_per_token_scaled_quant_kernel` (FP8 quantization) into a single kernel call using aiter's `add_rmsnorm_quant` with `FUSE_QUANT=true`. This eliminates redundant global-memory round-trips between the RMSNorm output and the FP8 quantization input for models using `CompressedTensorsW8A8Fp8` with per-channel weight quantization (e.g. GLM-4.7-FP8).

Changes:
- communicator.py: add an `fp8_per_token` path in `prepare_attn` using aiter `add_rmsnorm_quant` (`group_size=0` for per-token)
- glm4_moe.py: auto-detect the FP8 per-token quant scheme on `qkv_proj`, pass `quant_format` to `prepare_attn`, handle the `(fp8, scale)` tuple
- fp8_utils.py: handle pre-quantized `(fp8, scale)` tuple input in `apply_fp8_ptpc_linear` to skip a redundant `per_token_quant_hip`
- Fix `"fp8" in quant_format` -> `quant_format == "fp8"` to prevent `fp8_per_token` from being intercepted by the fp8 group-quant path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
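The math the fused kernel computes in a single pass can be sketched as an unfused numpy reference: residual add, RMSNorm, then dynamic per-token FP8 quantization. This is only a readability aid, not the aiter HIP kernel; `FP8_MAX` assumes the float8_e4m3fn format, and the real kernel keeps the intermediate normalized tensor in registers/LDS instead of writing it to global memory, which is where the round-trip savings come from.

```python
import numpy as np

FP8_MAX = 448.0  # max finite value of float8_e4m3fn


def add_rmsnorm_quant_ref(x, residual, weight, eps=1e-6):
    """Residual add -> RMSNorm -> dynamic per-token FP8 quant (unfused reference)."""
    h = x + residual                                       # fused residual add
    rms = np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)
    y = (h / rms) * weight                                 # RMSNorm
    scale = np.abs(y).max(axis=-1, keepdims=True) / FP8_MAX
    scale = np.maximum(scale, 1e-12)                       # guard all-zero rows
    q = np.clip(np.round(y / scale), -FP8_MAX, FP8_MAX)
    return q, scale, h  # h is the updated residual fed to the next layer
```

In the unfused pipeline the normalized output `y` makes a full global-memory round-trip before the quantization kernel reads it back; fusing removes that traffic entirely.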
/tag-and-rerun-ci
HaiShaw left a comment:

@Jacob0226 lint fix pls
Summary
- Fuse `add_rmsnorm_quant_kernel` (RMSNorm) with `dynamic_per_token_scaled_quant_kernel` (FP8 quantization) into a single kernel call using aiter's `add_rmsnorm_quant` with `FUSE_QUANT=true`, eliminating redundant global-memory round-trips
- Target models using `CompressedTensorsW8A8Fp8` with per-channel weight quantization (e.g. GLM-4.7-FP8) and enable the fused path via `quant_format="fp8_per_token"`
- Fix `"fp8" in quant_format` → `quant_format == "fp8"` to prevent `fp8_per_token` from being intercepted by the existing `fused_rms_fp8_group_quant` path

Changes
- `communicator.py`: add an `fp8_per_token` path in `prepare_attn` using aiter `add_rmsnorm_quant` (`group_size=0` for per-token)
- `glm4_moe.py`: auto-detect the FP8 per-token quant scheme on `qkv_proj`, pass `quant_format` to `prepare_attn`, handle the `(fp8, scale)` tuple
- `fp8_utils.py`: handle pre-quantized `(fp8, scale)` tuple input in `apply_fp8_ptpc_linear` to skip a redundant `per_token_quant_hip`

Scope
- `prepare_attn`, applied across all 92 decoder layers

Test plan
Accuracy (GSM8K on MI355X TP8 with GLM-4.7-FP8):
Within margin of error.
Performance (InferenceMax config on MI355X TP8):
Profiling:

As shown on the right side of the figure, `dynamic_per_token_scaled_quant_kernel` is fused into `add_rmsnorm_quant_kernel`.

🤖 Generated with Claude Code
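As a footnote, the `quant_format` routing fix mentioned in the summary comes down to substring versus exact string matching; the snippet below only illustrates the two checks with plain strings and is not SGLang code.

```python
fmt = "fp8_per_token"

# Old check: substring match also catches the new per-token format,
# so it was misrouted into the fp8 group-quant path.
assert "fp8" in fmt

# New check: exact comparison lets fp8_per_token fall through to the
# fused per-token path, while plain "fp8" still takes group-quant.
assert fmt != "fp8"
assert "fp8" == "fp8"
```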