Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 74 additions & 9 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,66 @@
_is_npu = is_npu()
_use_ag_after_qlora = envs.SGLANG_USE_AG_AFTER_QLORA.get()

if _use_aiter and _is_gfx95_supported:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
if _use_aiter:
from aiter.ops.rmsnorm import add_rmsnorm_quant as _aiter_add_rmsnorm_quant
from aiter.ops.rmsnorm import rmsnorm_quant as _aiter_rmsnorm_quant

from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as _aiter_fp8_dtype

if _is_gfx95_supported:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
fused_rms_mxfp4_quant,
)
elif _is_npu:
from sglang.srt.hardware_backend.npu.cmo import prepare_weight_cache


def _fused_rmsnorm_fp8_per_token_quant(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    residual: Optional[torch.Tensor] = None,
):
    """Fused (optional residual-add +) RMSNorm + FP8 per-token quantization.

    Args:
        residual: if provided, computes hidden_states + residual before RMSNorm
            and returns the updated residual as the second element.

    Returns:
        If residual is None: (out_fp8, scale)
        If residual provided: ((out_fp8, scale), residual_out)
    """
    num_tokens, hidden_dim = hidden_states.shape
    quantized = torch.empty(
        (num_tokens, hidden_dim),
        dtype=_aiter_fp8_dtype,
        device=hidden_states.device,
    )
    # One float32 scale per token; unsqueezed to (M, 1) before returning.
    per_token_scale = torch.empty(
        num_tokens, dtype=torch.float32, device=hidden_states.device
    )

    if residual is None:
        # group_size=0 selects per-token quantization in the aiter kernel.
        _aiter_rmsnorm_quant(
            quantized,
            hidden_states,
            per_token_scale,
            weight,
            epsilon,
            0,
        )
        return (quantized, per_token_scale.unsqueeze(1))

    updated_residual = torch.empty_like(hidden_states)
    # Fused residual-add + RMSNorm + quant; group_size=0 → per-token.
    _aiter_add_rmsnorm_quant(
        quantized,
        hidden_states,
        residual,
        updated_residual,
        per_token_scale,
        weight,
        epsilon,
        0,
    )
    return (quantized, per_token_scale.unsqueeze(1)), updated_residual


# TODO: According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
# We set the max token num to 128 for allreduce fusion with the min-latency case (use_oneshot=True).
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
Expand Down Expand Up @@ -147,7 +199,6 @@ def model_input_output():


class AttentionInputs:

def __init__(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -309,8 +360,8 @@ def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
if (
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
not get_moe_a2a_backend().is_none()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
)
Expand Down Expand Up @@ -482,8 +533,7 @@ def prepare_attn(
None,
None,
)
elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format):

elif _use_aiter and _is_gfx95_supported and (quant_format == "fp8"):
hidden_states, _, _, _res = fused_rms_fp8_group_quant(
hidden_states,
self.input_layernorm.weight,
Expand All @@ -497,10 +547,16 @@ def prepare_attn(
output_unquantized_inp1=False,
)

elif _use_aiter and (quant_format == "fp8_per_token"):
hidden_states = _fused_rmsnorm_fp8_per_token_quant(
hidden_states,
self.input_layernorm.weight.data,
self.input_layernorm.variance_epsilon,
)

else:
hidden_states = self.input_layernorm(hidden_states)
else:

if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
hidden_states, *_, residual = fused_rms_mxfp4_quant(
hidden_states,
Expand All @@ -511,7 +567,7 @@ def prepare_attn(
None,
residual,
)
elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format):
elif _use_aiter and _is_gfx95_supported and (quant_format == "fp8"):
# RMSNorm + FP8 per-group quant
# return hidden_states:
# out_fp8 : FP8 activation → a8w8 GEMM
Expand All @@ -528,6 +584,15 @@ def prepare_attn(
res1=residual,
output_unquantized_inp1=False,
)
elif _use_aiter and (quant_format == "fp8_per_token"):
if post_residual_addition is not None:
residual = residual + post_residual_addition
hidden_states, residual = _fused_rmsnorm_fp8_per_token_quant(
hidden_states,
self.input_layernorm.weight.data,
self.input_layernorm.variance_epsilon,
residual=residual,
)
else:
hidden_states, residual = self.input_layernorm(
hidden_states,
Expand Down
16 changes: 14 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -1523,7 +1523,7 @@ def can_auto_enable_marlin_fp8() -> bool:


def apply_fp8_ptpc_linear(
input: torch.Tensor,
input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
Expand All @@ -1534,6 +1534,18 @@ def apply_fp8_ptpc_linear(
pad_output: Optional[bool] = None,
compressed_tensor_quant: bool = False,
) -> torch.Tensor:
# Handle pre-quantized (fp8_tensor, scale) tuple from fused RMSNorm+Quant
if isinstance(input, tuple):
q_input, x_scale = input
q_input = q_input.view(-1, q_input.shape[-1])
output_shape = [*q_input.shape[:-1], weight.shape[0]]
output = aiter.gemm_a8w8_bpreshuffle(
q_input, weight, x_scale, weight_scale, None, torch.bfloat16
)
if bias is not None:
output = output + bias
return output.view(*output_shape)

# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])

Expand Down
61 changes: 58 additions & 3 deletions python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def forward_prepare(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if hidden_states.shape[0] == 0:
# hidden_states can be a (fp8_tensor, scale) tuple from fused RMSNorm+Quant
hs = hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states
if hs.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
Expand Down Expand Up @@ -772,6 +774,51 @@ def __init__(
),
)

# Detect if QKV uses aiter FP8 per-token quant so we can fuse
# RMSNorm + FP8 quant into a single kernel in prepare_attn
self.attn_quant_format = ""
self._detect_attn_quant_format()

def _detect_fp8_per_token_quant(self, linear_layer, label: str) -> str:
    """Check if a linear layer uses aiter FP8 per-token quantization."""
    from sglang.srt.utils import get_bool_env_var, is_hip

    # The fused path only exists for the aiter backend on ROCm.
    aiter_on_rocm = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
    if not aiter_on_rocm:
        return ""
    if not hasattr(linear_layer, "quant_method"):
        return ""

    # Scheme may live on the layer itself or on its quant_method.
    scheme = getattr(linear_layer, "scheme", None) or getattr(
        linear_layer.quant_method, "scheme", None
    )
    if scheme is not None:
        from compressed_tensors.quantization import QuantizationStrategy

        from sglang.srt.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
            CompressedTensorsW8A8Fp8,
        )

        is_channel_w8a8_fp8 = isinstance(scheme, CompressedTensorsW8A8Fp8) and (
            scheme.strategy == QuantizationStrategy.CHANNEL
        )
        if is_channel_w8a8_fp8:
            logger.info(
                "layer_%d Fused RMSNorm+Quant %s: ENABLED (fp8_per_token)",
                self.layer_id,
                label,
            )
            return "fp8_per_token"

    logger.info(
        "layer_%d Fused RMSNorm+Quant %s: skipped",
        self.layer_id,
        label,
    )
    return ""

def _detect_attn_quant_format(self):
    """Detect and cache the quant format of the attention QKV projection."""
    detected_format = self._detect_fp8_per_token_quant(
        self.self_attn.qkv_proj, "attn"
    )
    self.attn_quant_format = detected_format

def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
return is_nextn or (
self.config.n_routed_experts is not None
Expand All @@ -787,7 +834,10 @@ def forward(
) -> torch.Tensor:

hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
hidden_states,
residual,
forward_batch,
quant_format=self.attn_quant_format,
)

hidden_states = self.self_attn(
Expand Down Expand Up @@ -834,7 +884,12 @@ def op_comm_prepare_attn(
tbo_subbatch_index: Optional[int] = None,
):
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
self.layer_communicator.prepare_attn(
hidden_states,
residual,
forward_batch,
quant_format=self.attn_quant_format,
)
)
state.update(
dict(
Expand Down
Loading