Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
536 changes: 536 additions & 0 deletions benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
| `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | `auto` | `normal`, `low_latency`, `auto` |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | `0` | Type: int |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | `None` | Type: str |
Expand Down
17 changes: 16 additions & 1 deletion python/sglang/srt/distributed/communication_op.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py

from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributed
Expand All @@ -13,6 +13,21 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
return get_tp_group().all_reduce(input_)


def tensor_model_parallel_fused_allreduce_rmsnorm(
    input_: torch.Tensor,
    residual_inp_: torch.Tensor,
    weight_: torch.Tensor,
    eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Run an all-reduce fused with residual RMSNorm over the TP group.

    Backend choice is delegated entirely to the TP GroupCoordinator: it may
    use a communicator-native fused API or a custom fused kernel, and it
    returns None when no fused path applies so the caller can fall back to
    the generic (unfused) all-reduce + RMSNorm sequence.
    """
    tp_group = get_tp_group()
    return tp_group.fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)


def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
Expand Down
52 changes: 52 additions & 0 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,58 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
inplace_all_reduce(input_, group_name=self.unique_name)
return input_

def fused_allreduce_rmsnorm(
    self,
    input_: torch.Tensor,
    residual_inp_: torch.Tensor,
    weight_: torch.Tensor,
    eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Try a fused all-reduce + RMSNorm on the custom all-reduce communicator.

    Returns the (normed, residual) pair on success, or None when no fused
    path is available so the caller can run the generic fallback.
    """
    comm = self.ca_comm
    if comm is None or getattr(comm, "disabled", True):
        return None

    # A communicator-native fused implementation takes precedence; on any
    # failure we silently fall through to the custom fused kernel below.
    if hasattr(comm, "fused_allreduce_rmsnorm"):
        try:
            return comm.fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
        except Exception:
            pass

    if not hasattr(comm, "custom_fused_ar_rms"):
        return None

    # Choose between 1-stage and multi-stage all-reduce:
    # explicit env override > deterministic inference > payload-size heuristic.
    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
        one_stage = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
    elif envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get():
        # 1-stage is forced for reproducibility.
        one_stage = True
    else:
        # AITER heuristic: only small payloads with supported hidden dims.
        payload_bytes = input_.numel() * input_.element_size()
        one_stage = payload_bytes <= 128 * 1024 and input_.shape[-1] in {
            512,
            1024,
            2048,
            4096,
        }

    return comm.custom_fused_ar_rms(input_, residual_inp_, weight_, eps, one_stage)

def _all_reduce_out_place(
self, input_: torch.Tensor, outplace_all_reduce_method: str
) -> torch.Tensor:
Expand Down
51 changes: 43 additions & 8 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ def apply_flashinfer_allreduce_fusion(batch_size: int):
)


def apply_aiter_all_reduce_fusion(input_tensor: torch.Tensor):
    """Decide whether the AITER fused all-reduce + RMSNorm path applies to
    ``input_tensor`` under the current parallel configuration.

    Checks are evaluated in the same short-circuit order as the gating
    expression they replace: hardware flag, payload size, hidden dim,
    parallel topology, and finally the server-arg opt-in.
    """
    hidden_dim = input_tensor.shape[-1]
    payload_bytes = input_tensor.numel() * input_tensor.element_size()
    if not _use_aiter:
        return False
    if payload_bytes <= 0:
        return False
    if hidden_dim > 16384:
        return False
    # Payload cap of 8 * 1024 * 8192 bytes (64 MiB).
    if payload_bytes >= 8 * 1024 * 8192:
        return False
    # TP world size 6 is excluded — presumably an unsupported kernel
    # configuration; confirm against the AITER kernel's constraints.
    if get_tensor_model_parallel_world_size() == 6:
        return False
    if is_dp_attention_enabled():
        return False
    return get_global_server_args().enable_aiter_allreduce_fusion


class ScatterMode(Enum):
"""
Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d
Expand Down Expand Up @@ -428,11 +442,20 @@ def prepare_attn(
and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
and hidden_states._sglang_needs_allreduce_fusion
):
hidden_states, residual = (
self.input_layernorm.forward_with_allreduce_fusion(
if (
apply_aiter_all_reduce_fusion(hidden_states)
or apply_flashinfer_allreduce_fusion(hidden_states.shape[0])
) and hasattr(self.input_layernorm, "forward_with_allreduce_fusion"):
hidden_states, residual = (
self.input_layernorm.forward_with_allreduce_fusion(
hidden_states, residual
)
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
)
else:
if residual is None:
residual = hidden_states
Expand Down Expand Up @@ -599,7 +622,15 @@ def should_fuse_mlp_allreduce_with_next_layer(
)

return (
apply_flashinfer_allreduce_fusion(batch_size)
(
apply_flashinfer_allreduce_fusion(batch_size)
or (
_use_aiter
and batch_size > 0
and get_tensor_model_parallel_world_size() != 6
and get_global_server_args().enable_aiter_allreduce_fusion
)
)
and (not self.is_last_layer)
and (self._context.tp_size > 1)
)
Expand Down Expand Up @@ -799,13 +830,17 @@ def _gather_hidden_states_and_residual(
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
else:
if apply_flashinfer_allreduce_fusion(hidden_states.shape[0]) and hasattr(
layernorm, "forward_with_allreduce_fusion"
):
handled = False
if (
apply_aiter_all_reduce_fusion(hidden_states)
or apply_flashinfer_allreduce_fusion(hidden_states.shape[0])
) and hasattr(layernorm, "forward_with_allreduce_fusion"):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual
)
else:
handled = True

if not handled:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if _is_npu and context.cache is not None:
_ = prepare_weight_cache(hidden_states, context.cache)
Expand Down
39 changes: 30 additions & 9 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,43 @@ def forward_with_allreduce_fusion(
Forward method with allreduce fusion, prioritizing flashinfer fused operations
"""
if residual is not None:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
tensor_model_parallel_fused_allreduce_rmsnorm,
)
from sglang.srt.layers.flashinfer_comm_fusion import (
flashinfer_allreduce_residual_rmsnorm,
)

if get_tensor_model_parallel_world_size() > 1:
if post_residual_addition is not None:
residual = residual + post_residual_addition
fused_result = flashinfer_allreduce_residual_rmsnorm(
input_tensor=x,
residual=residual,
weight=self.weight,
eps=self.variance_epsilon,
)
if fused_result[0] is not None:
return fused_result

# Prefer AITER fused AR+RMSNorm when enabled on AMD.
if _use_aiter:
fused_result = tensor_model_parallel_fused_allreduce_rmsnorm(
x, residual, self.weight, self.variance_epsilon
)
if fused_result is not None:
return fused_result
else:
fused_result = flashinfer_allreduce_residual_rmsnorm(
input_tensor=x,
residual=residual,
weight=self.weight,
eps=self.variance_epsilon,
)
if fused_result[0] is not None:
return fused_result

# For AITER route, preserve correctness when fused path is unavailable.
if (
_use_aiter
and get_global_server_args().enable_aiter_allreduce_fusion
):
x = tensor_model_parallel_all_reduce(x)
return self.forward(x, residual, None)

return self.forward(x, residual, post_residual_addition)

Expand Down
35 changes: 35 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ class ServerArgs:
moe_runner_backend: str = "auto"
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
enable_aiter_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
Expand Down Expand Up @@ -1303,6 +1304,13 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
)
elif is_hip():
if not self.enable_dp_attention and self.nnodes == 1:
# TODO (Hubert): Put this back later
# self.enable_aiter_allreduce_fusion = True
logger.info(
"Enable Aiter AllReduce Fusion for DeepseekV3ForCausalLM"
)

if (
self.quantization == "modelopt_fp4"
Expand Down Expand Up @@ -1358,6 +1366,22 @@ def _handle_model_specific_adjustments(self):

quant_method = get_quantization_config(hf_config)
is_mxfp4_quant_format = quant_method == "mxfp4"
if is_blackwell_supported():
# workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006
if not self.enable_dp_attention and self.nnodes == 1:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
)
if not self.enable_dp_attention and self.nnodes == 1 and is_hip():
# TODO (Hubert): Put this back later
# self.enable_aiter_allreduce_fusion = True
logger.info("Enable Aiter AllReduce Fusion for GptOssForCausalLM")
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
Expand Down Expand Up @@ -2644,6 +2668,12 @@ def _handle_deterministic_inference(self):
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"

if self.enable_deterministic_inference:
if self.enable_aiter_allreduce_fusion:
logger.warning(
"Disable --enable-aiter-allreduce-fusion because deterministic inference is enabled."
)
self.enable_aiter_allreduce_fusion = False

# Check sampling backend
self.sampling_backend = "pytorch"
logger.warning(
Expand Down Expand Up @@ -4022,6 +4052,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
)
parser.add_argument(
"--enable-aiter-allreduce-fusion",
action="store_true",
help="Enable Aiter AllReduce Fusion.",
)
parser.add_argument(
"--deepep-mode",
type=str,
Expand Down
Loading
Loading