[AMD] Fuse RMSNorm + FP8 per-token quant for GLM-4.7-FP8#21403

Open
Jacob0226 wants to merge 2 commits into sgl-project:main from Jacob0226:jacob/fused_rmsnorm_quant

Conversation


@Jacob0226 Jacob0226 commented Mar 25, 2026

🤖 This PR was developed with Claude Code (Claude Opus 4.6)

Summary

  • Fuse add_rmsnorm_quant_kernel (RMSNorm) with dynamic_per_token_scaled_quant_kernel (FP8 quantization) into a single kernel call using aiter's add_rmsnorm_quant with FUSE_QUANT=true, eliminating redundant global memory round-trips
  • Auto-detect CompressedTensorsW8A8Fp8 with per-channel weight quantization (e.g. GLM-4.7-FP8) and enable fused path via quant_format="fp8_per_token"
  • Fix `"fp8" in quant_format` → `quant_format == "fp8"` to prevent fp8_per_token from being intercepted by the existing fused_rms_fp8_group_quant path
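
The third fix can be illustrated with a minimal dispatch sketch (the helper name and returned path labels are hypothetical; only the string-matching logic mirrors the PR):

```python
def select_prepare_attn_path(quant_format: str) -> str:
    """Choose the norm/quant kernel path (hypothetical helper).

    An exact match is required here: the old check `"fp8" in quant_format`
    also matched "fp8_per_token", routing it into the group-quant path.
    """
    if quant_format == "fp8":
        return "fused_rms_fp8_group_quant"  # existing group-quant path
    if quant_format == "fp8_per_token":
        return "add_rmsnorm_quant"          # aiter fused kernel, group_size=0
    return "rmsnorm"                        # unfused fallback
```

With the substring check, `"fp8_per_token"` would have taken the first branch, which is exactly the misrouting the fix prevents.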

Changes

| File | Change |
| --- | --- |
| communicator.py | Add fp8_per_token path in prepare_attn using aiter add_rmsnorm_quant (group_size=0 for per-token) |
| glm4_moe.py | Auto-detect FP8 per-token quant scheme on qkv_proj, pass quant_format to prepare_attn, handle (fp8, scale) tuple |
| fp8_utils.py | Handle pre-quantized (fp8, scale) tuple input in apply_fp8_ptpc_linear to skip redundant per_token_quant_hip |
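
The fp8_utils.py change can be sketched as follows (function names and the quant fallback are illustrative, not the actual SGLang signatures; 448 is the max finite value of FP8 e4m3):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max finite value of float8_e4m3

def per_token_quant(x):
    # Illustrative stand-in for per_token_quant_hip:
    # one dynamic scale per row (token), mapping onto the full FP8 range.
    scale = np.max(np.abs(x), axis=-1, keepdims=True) / FP8_E4M3_MAX
    return np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX), scale

def resolve_linear_input(x):
    """If prepare_attn already produced a (fp8, scale) tuple, reuse it
    instead of quantizing the activation a second time."""
    if isinstance(x, tuple):
        return x                 # pre-quantized by the fused kernel
    return per_token_quant(x)    # fallback for unfused callers
```

The tuple passthrough is what lets the fused prepare_attn path skip the redundant quantization kernel inside the linear layer.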

Scope

  • Fuses RMSNorm + FP8 per-token quant into a single kernel in prepare_attn, applied across all 92 decoder layers
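
Per token, the fused kernel computes roughly the following (a NumPy reference sketch of residual add + RMSNorm + dynamic per-token FP8 scaling, assuming e4m3 with max 448; the real aiter kernel does this in one pass without writing the normalized activations back to global memory):

```python
import numpy as np

FP8_E4M3_MAX = 448.0

def add_rmsnorm_quant_ref(x, residual, weight, eps=1e-6):
    h = x + residual                                    # fused residual add
    rms = np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)
    normed = (h / rms) * weight                         # RMSNorm
    # dynamic per-token scale: each row maps onto the full FP8 range
    scale = np.max(np.abs(normed), axis=-1, keepdims=True) / FP8_E4M3_MAX
    fp8 = np.clip(normed / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return fp8, scale, h                                # h is the new residual
```

Unfused, `normed` would be written to and re-read from global memory between the two kernels; the fusion removes that round-trip.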

Test plan

Accuracy (GSM8K on MI355X TP8 with GLM-4.7-FP8):

| | Before | After |
| --- | --- | --- |
| GSM8K | 0.948 | 0.943 |

Within margin of error.

Performance (InferenceMax config on MI355X TP8):

| Config | ITL Decode Speedup |
| --- | --- |
| ISL 1K/8K, OSL 1K, concurrency 4/8/16/32/64 | ~+1% |

Profiling:
The profile (right side of the figure) shows dynamic_per_token_scaled_quant_kernel fused into add_rmsnorm_quant_kernel.
(profiling screenshot)

🤖 Generated with Claude Code

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the performance of FP8 per-token quantization by fusing RMSNorm and quantization operations into a single kernel. This change reduces memory overhead and improves efficiency, particularly for models like GLM-4.7-FP8. The update also includes robust detection mechanisms for the appropriate quantization schemes and ensures correct routing to the optimized kernels, leading to 95 kernel fusions across attention and dense MLP layers.

Highlights

  • Kernel Fusion: Fused the add_rmsnorm_quant_kernel (RMSNorm) with dynamic_per_token_scaled_quant_kernel (FP8 quantization) into a single kernel call using aiter's add_rmsnorm_quant with FUSE_QUANT=true, which eliminates redundant global memory round-trips.
  • Automatic Detection: Implemented auto-detection for CompressedTensorsW8A8Fp8 with per-channel weight quantization (e.g., GLM-4.7-FP8) to enable the fused path via quant_format="fp8_per_token".
  • Quantization Format Handling: Fixed the logic for detecting the 'fp8' quantization format from "fp8" in quant_format to quant_format == "fp8" to prevent fp8_per_token from being incorrectly intercepted by the existing fused_rms_fp8_group_quant path.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces FP8 per-token quantization support using the aiter library for RMSNorm and linear layers within the SGLang framework, specifically for attention and MLP blocks. Key changes include adding new aiter RMSNorm quantization functions and logic in communicator.py to handle the fp8_per_token format, returning quantized tensors and scales as a tuple. The fp8_utils.py file is updated to process this (tensor, scale) tuple input for FP8 linear operations. Additionally, glm4_moe.py is modified to dynamically detect and pass the appropriate quantization format to the communicator layers, and its forward methods are adapted to correctly handle tuple inputs for hidden states.

Review comments suggest:

  • improving consistency in quant_format string matching
  • addressing a potential bug where the scale tensor might be discarded in GLM4MoEBlock.forward_prepare
  • refactoring duplicated code in communicator.py
  • restoring missing type hints in apply_fp8_ptpc_linear
  • investigating a potential dead code block in GLM4MoEBlock.forward related to tuple handling for MoE layers
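
The scale-discarding concern the review raises can be made concrete with a minimal sketch (hypothetical structure; the actual GLM4MoEBlock code differs): any hand-off that receives a (fp8, scale) tuple must keep both elements together, otherwise downstream dequantization loses the per-token scales.

```python
def forward_prepare(hidden_states):
    """Pass a pre-quantized (fp8, scale) tuple through intact.

    The bug pattern flagged in review is unpacking only the tensor,
    e.g. `hidden_states = hidden_states[0]`, which silently drops `scale`.
    """
    if isinstance(hidden_states, tuple):
        fp8, scale = hidden_states
        return (fp8, scale)   # keep scale alongside the quantized tensor
    return hidden_states      # plain tensor: unchanged
```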

@gemini-code-assist

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!


@Jacob0226 Jacob0226 force-pushed the jacob/fused_rmsnorm_quant branch 3 times, most recently from e784322 to 61b4a63 on March 27, 2026 04:37
@Jacob0226 Jacob0226 marked this pull request as ready for review March 27, 2026 04:40

Fuse `add_rmsnorm_quant_kernel` (RMSNorm) with
`dynamic_per_token_scaled_quant_kernel` (FP8 quantization) into a single
kernel call using aiter's `add_rmsnorm_quant` with FUSE_QUANT=true.

This eliminates redundant global memory round-trips between RMSNorm output
and FP8 quantization input for models using CompressedTensorsW8A8Fp8 with
per-channel weight quantization (e.g. GLM-4.7-FP8).

Changes:
- communicator.py: Add fp8_per_token path in prepare_attn using aiter
  add_rmsnorm_quant (group_size=0 for per-token)
- glm4_moe.py: Auto-detect FP8 per-token quant scheme on qkv_proj,
  pass quant_format to prepare_attn, handle (fp8, scale) tuple
- fp8_utils.py: Handle pre-quantized (fp8, scale) tuple input in
  apply_fp8_ptpc_linear to skip redundant per_token_quant_hip
- Fix "fp8" in quant_format -> quant_format == "fp8" to prevent
  fp8_per_token from being intercepted by the fp8 group-quant path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Jacob0226 Jacob0226 force-pushed the jacob/fused_rmsnorm_quant branch from 61b4a63 to 0845ef2 on March 30, 2026 01:47

HaiShaw commented Mar 31, 2026

/tag-and-rerun-ci


@HaiShaw HaiShaw left a comment


@Jacob0226 lint fix pls
