def _fused_rmsnorm_fp8_per_token_quant(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    residual: Optional[torch.Tensor] = None,
):
    """Run aiter's fused RMSNorm + FP8 per-token quantization kernel.

    When ``residual`` is supplied, the residual-add variant of the kernel is
    used instead and a freshly allocated residual output buffer is returned
    alongside the quantized activation.

    Args:
        hidden_states: 2-D activation tensor; its shape is unpacked as
            ``(M, N)``, so any other rank raises ``ValueError``.
        weight: RMSNorm weight forwarded unchanged to the aiter kernel.
        epsilon: RMSNorm epsilon forwarded unchanged to the aiter kernel.
        residual: optional residual tensor; selects the add+norm kernel.

    Returns:
        If ``residual`` is None: ``(out_fp8, scale)``.
        Otherwise: ``((out_fp8, scale), residual_out)``.
        ``out_fp8`` has shape ``(M, N)`` and dtype ``_aiter_fp8_dtype``;
        ``scale`` is float32 with shape ``(M, 1)``.
    """
    num_tokens, hidden_dim = hidden_states.shape
    device = hidden_states.device

    # Output buffers handed to the kernel: quantized activations plus one
    # float32 scale slot per token (row).
    quantized = torch.empty(
        (num_tokens, hidden_dim), dtype=_aiter_fp8_dtype, device=device
    )
    token_scales = torch.empty(num_tokens, dtype=torch.float32, device=device)

    # group_size=0 → per-token quantization (same value for both kernels).
    per_token_group_size = 0

    if residual is None:
        _aiter_rmsnorm_quant(
            quantized,
            hidden_states,
            token_scales,
            weight,
            epsilon,
            per_token_group_size,
        )
        return (quantized, token_scales.unsqueeze(1))

    # Residual path: the kernel receives an extra buffer for the updated
    # residual, which is returned as the second element.
    updated_residual = torch.empty_like(hidden_states)
    _aiter_add_rmsnorm_quant(
        quantized,
        hidden_states,
        residual,
        updated_residual,
        token_scales,
        weight,
        epsilon,
        per_token_group_size,
    )
    return (quantized, token_scales.unsqueeze(1)), updated_residual
not get_moe_a2a_backend().is_none() or should_use_flashinfer_cutlass_moe_fp4_allgather() ) @@ -482,8 +533,7 @@ def prepare_attn( None, None, ) - elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format): - + elif _use_aiter and _is_gfx95_supported and (quant_format == "fp8"): hidden_states, _, _, _res = fused_rms_fp8_group_quant( hidden_states, self.input_layernorm.weight, @@ -497,10 +547,16 @@ def prepare_attn( output_unquantized_inp1=False, ) + elif _use_aiter and (quant_format == "fp8_per_token"): + hidden_states = _fused_rmsnorm_fp8_per_token_quant( + hidden_states, + self.input_layernorm.weight.data, + self.input_layernorm.variance_epsilon, + ) + else: hidden_states = self.input_layernorm(hidden_states) else: - if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): hidden_states, *_, residual = fused_rms_mxfp4_quant( hidden_states, @@ -511,7 +567,7 @@ def prepare_attn( None, residual, ) - elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format): + elif _use_aiter and _is_gfx95_supported and (quant_format == "fp8"): # RMSNorm + FP8 per-group quant # return hidden_states: # out_fp8 : FP8 activation → a8w8 GEMM @@ -528,6 +584,15 @@ def prepare_attn( res1=residual, output_unquantized_inp1=False, ) + elif _use_aiter and (quant_format == "fp8_per_token"): + if post_residual_addition is not None: + residual = residual + post_residual_addition + hidden_states, residual = _fused_rmsnorm_fp8_per_token_quant( + hidden_states, + self.input_layernorm.weight.data, + self.input_layernorm.variance_epsilon, + residual=residual, + ) else: hidden_states, residual = self.input_layernorm( hidden_states, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index f5bcf975186e..4b15ac181383 100755 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -3,7 +3,7 @@ import logging from enum import Enum from functools import 
lru_cache -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union import torch @@ -1523,7 +1523,7 @@ def can_auto_enable_marlin_fp8() -> bool: def apply_fp8_ptpc_linear( - input: torch.Tensor, + input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, @@ -1534,6 +1534,18 @@ def apply_fp8_ptpc_linear( pad_output: Optional[bool] = None, compressed_tensor_quant: bool = False, ) -> torch.Tensor: + # Handle pre-quantized (fp8_tensor, scale) tuple from fused RMSNorm+Quant + if isinstance(input, tuple): + q_input, x_scale = input + q_input = q_input.view(-1, q_input.shape[-1]) + output_shape = [*q_input.shape[:-1], weight.shape[0]] + output = aiter.gemm_a8w8_bpreshuffle( + q_input, weight, x_scale, weight_scale, None, torch.bfloat16 + ) + if bias is not None: + output = output + bias + return output.view(*output_shape) + # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index fcf62efc8727..7280b0bdb167 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -275,7 +275,9 @@ def forward_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ): - if hidden_states.shape[0] == 0: + # hidden_states can be a (fp8_tensor, scale) tuple from fused RMSNorm+Quant + hs = hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states + if hs.shape[0] == 0: return hidden_states, forward_batch, None qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -772,6 +774,51 @@ def __init__( ), ) + # Detect if QKV uses aiter FP8 per-token quant so we can fuse + # RMSNorm + FP8 quant into a single kernel in prepare_attn + self.attn_quant_format = "" + 
self._detect_attn_quant_format() + + def _detect_fp8_per_token_quant(self, linear_layer, label: str) -> str: + """Check if a linear layer uses aiter FP8 per-token quantization.""" + from sglang.srt.utils import get_bool_env_var, is_hip + + if not (get_bool_env_var("SGLANG_USE_AITER") and is_hip()): + return "" + if not hasattr(linear_layer, "quant_method"): + return "" + scheme = getattr(linear_layer, "scheme", None) or getattr( + linear_layer.quant_method, "scheme", None + ) + if scheme is not None: + from compressed_tensors.quantization import QuantizationStrategy + + from sglang.srt.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import ( + CompressedTensorsW8A8Fp8, + ) + + if ( + isinstance(scheme, CompressedTensorsW8A8Fp8) + and scheme.strategy == QuantizationStrategy.CHANNEL + ): + logger.info( + "layer_%d Fused RMSNorm+Quant %s: ENABLED (fp8_per_token)", + self.layer_id, + label, + ) + return "fp8_per_token" + logger.info( + "layer_%d Fused RMSNorm+Quant %s: skipped", + self.layer_id, + label, + ) + return "" + + def _detect_attn_quant_format(self): + self.attn_quant_format = self._detect_fp8_per_token_quant( + self.self_attn.qkv_proj, "attn" + ) + def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: return is_nextn or ( self.config.n_routed_experts is not None @@ -787,7 +834,10 @@ def forward( ) -> torch.Tensor: hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + hidden_states, + residual, + forward_batch, + quant_format=self.attn_quant_format, ) hidden_states = self.self_attn( @@ -834,7 +884,12 @@ def op_comm_prepare_attn( tbo_subbatch_index: Optional[int] = None, ): state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = ( - self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch) + self.layer_communicator.prepare_attn( + hidden_states, + residual, + forward_batch, + quant_format=self.attn_quant_format, + ) ) state.update( 
dict(