20 changes: 18 additions & 2 deletions python/sglang/srt/layers/communicator.py
@@ -32,8 +32,13 @@
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda, is_flashinfer_available

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()


class ScatterMode(Enum):
@@ -397,8 +402,19 @@ def _gather_hidden_states_and_residual(
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = layernorm(hidden_states, residual)
if (
_is_sm100_supported
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 1024
Review thread on the line above (hidden_states.shape[0] <= 1024):

Collaborator:
hidden_state.numel() * hidden_state.element_size() < THRESHOLD

Collaborator (Author):
Makes sense, I'll update it.

Collaborator (Author):
I think this feature could also be applied to other models. However, since different models may have varying hidden_size values, directly checking max_token_num seems more general. Regarding the token number 1024: this value was estimated based on the workspace threshold in trt-llm, using ds-v3's hidden_size as a reference. With 1024 tokens, it only allocates an additional ~10MB buffer, so the memory overhead is minimal.

Collaborator:
Hi, I think the max_workspace settings in trt-llm are not for these flashinfer kernels (allreduce_fusion_xxx).
- The max_token_num should be set to the actual maximum token number.
- The default value of use_oneshot should be False; in the two cases shown here it could be set to True: https://github.com/NVIDIA/TensorRT-LLM/blob/a1235ee9781050e562bbce2c86f714c38d434dbe/cpp/tensorrt_llm/thop/allreduceOp.cpp#L416-L429

):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual

@staticmethod
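
For reference, a sketch of the payload-size gate suggested in the first review comment above. This is not part of the PR, and THRESHOLD_BYTES is a hypothetical placeholder; the merged code gates on hidden_states.shape[0] <= 1024 instead.

import torch

# Hypothetical variant of the fusion gate that bounds the allreduce payload in
# bytes rather than in tokens, following the reviewer's suggestion.
# THRESHOLD_BYTES is a placeholder constant, not defined anywhere in this PR.
THRESHOLD_BYTES = 16 * 1024 * 1024  # ~16 MiB, illustrative only


def within_fusion_threshold(hidden_states: torch.Tensor) -> bool:
    # numel() * element_size() is the tensor's size in bytes, so the check is
    # independent of a particular model's hidden_size or dtype.
    return hidden_states.numel() * hidden_states.element_size() < THRESHOLD_BYTES
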
202 changes: 202 additions & 0 deletions python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -0,0 +1,202 @@
import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
try:
import flashinfer.comm as comm

_flashinfer_comm = comm
except ImportError:
logger.warning(
"flashinfer.comm is not available, falling back to standard "
"implementation"
)


class FlashInferWorkspaceManager:
def __init__(self):
self.workspace_tensor = None
self.ipc_handles = None
self.world_size = None
self.rank = None
self.initialized = False

def initialize(
self,
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
group=None,
use_fp32_lamport: bool = False,
):
"""Initialize workspace"""
if self.initialized and self.world_size == world_size:
return

if _flashinfer_comm is None:
logger.warning(
"FlashInfer comm not available, skipping workspace " "initialization"
)
return

self.cleanup()

self.ipc_handles, self.workspace_tensor = (
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
rank,
world_size,
max_token_num,
hidden_dim,
group=group,
use_fp32_lamport=use_fp32_lamport,
)
)

self.world_size = world_size
self.rank = rank
self.initialized = True

logger.info(
f"FlashInfer workspace initialized for rank {rank}, "
f"world_size {world_size}"
)

def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.ipc_handles is not None:
try:
_flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
self.ipc_handles, group=dist.group.WORLD
)
except Exception as e:
logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
finally:
self.workspace_tensor = None
self.ipc_handles = None
self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
"""Ensure workspace is initialized"""
if not is_flashinfer_available() or _flashinfer_comm is None:
return False

world_size = get_tensor_model_parallel_world_size()
if world_size <= 1:
return False

rank = dist.get_rank()

if (
not _workspace_manager.initialized
or _workspace_manager.world_size != world_size
):
_workspace_manager.initialize(
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
use_fp32_lamport=use_fp32_lamport,
)

return _workspace_manager.initialized


def flashinfer_allreduce_add_rmsnorm(
input_tensor: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
max_token_num: int = 1024,
use_oneshot: bool = True,
trigger_completion_at_end: bool = False,
fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Use FlashInfer's fused allreduce + residual + RMS norm operation

Args:
input_tensor: Input tensor that needs allreduce
residual: Residual tensor
weight: RMS norm weight
eps: RMS norm epsilon
max_token_num: Maximum token number
use_oneshot: Whether to use oneshot mode
trigger_completion_at_end: Whether to trigger completion at end
fp32_acc: Whether to use fp32 precision

Returns:
Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
"""
if not is_flashinfer_available() or _flashinfer_comm is None:
logger.debug(
"FlashInfer not available, falling back to standard " "implementation"
)
return None, None

world_size = get_tensor_model_parallel_world_size()
if world_size <= 1:
logger.debug("Single GPU, no need for allreduce fusion")
return None, None

if not ensure_workspace_initialized(
max_token_num=max_token_num,
hidden_dim=input_tensor.shape[-1],
use_fp32_lamport=(input_tensor.dtype == torch.float32),
):
logger.debug("FlashInfer workspace not available")
return None, None

token_num, hidden_dim = input_tensor.shape

residual_out = torch.empty_like(residual)
norm_out = torch.empty_like(input_tensor)

_flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
world_size=world_size,
world_rank=dist.get_rank(),
token_num=token_num,
hidden_dim=hidden_dim,
workspace_ptrs=_workspace_manager.workspace_tensor,
launch_with_pdl=True,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
allreduce_out=None,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=eps,
scale_factor=None,
layout_code=None,
)

return norm_out, residual_out


def cleanup_flashinfer_workspace():
global _workspace_manager
if _workspace_manager is not None:
_workspace_manager.cleanup()
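
A minimal usage sketch for flashinfer_allreduce_add_rmsnorm as defined above. It assumes torch.distributed and SGLang's tensor-parallel group are already initialized and that tensor_model_parallel_all_reduce is importable from sglang.srt.distributed; the fallback below is a plain reference add + RMSNorm, not the kernel SGLang actually dispatches to.

import torch

from sglang.srt.distributed import tensor_model_parallel_all_reduce
from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm


def fused_or_fallback_add_rmsnorm(hidden, residual, weight, eps=1e-6):
    # The fused path returns (None, None) when FlashInfer, the IPC workspace,
    # or multi-GPU tensor parallelism is unavailable.
    norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
        input_tensor=hidden, residual=residual, weight=weight, eps=eps
    )
    if norm_out is not None:
        return norm_out, residual_out
    # Unfused fallback: allreduce, then add + RMSNorm computed in fp32.
    hidden = tensor_model_parallel_all_reduce(hidden)
    residual_out = hidden + residual
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    norm_out = residual_out.float() * torch.rsqrt(variance + eps)
    return (norm_out * weight.float()).to(hidden.dtype), residual_out
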
26 changes: 26 additions & 0 deletions python/sglang/srt/layers/layernorm.py
@@ -163,6 +163,32 @@ def forward_cpu(
else:
return self.forward_native(x, residual)

def forward_with_allreduce_fusion(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Forward method with allreduce fusion, prioritizing flashinfer fused operations
"""
if residual is not None:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.flashinfer_comm_fusion import (
flashinfer_allreduce_add_rmsnorm,
)

if get_tensor_model_parallel_world_size() > 1:
fused_result = flashinfer_allreduce_add_rmsnorm(
input_tensor=x,
residual=residual,
weight=self.weight,
eps=self.variance_epsilon,
)
if fused_result[0] is not None:
return fused_result

return self.forward(x, residual)


class GemmaRMSNorm(CustomOp):
def __init__(
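A minimal call-site sketch for the new method; the hidden size, dtype, and device are illustrative. With a single GPU, or when the fused kernel is unavailable, the call simply falls through to the regular forward.

import torch

from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096, eps=1e-6).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)

# Uses the FlashInfer fused allreduce + add + RMSNorm when TP > 1 and the kernel
# is available; otherwise this is equivalent to norm(x, residual).
x, residual = norm.forward_with_allreduce_fusion(x, residual)
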
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -85,6 +85,7 @@
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_moe",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -156,6 +156,7 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_moe: bool = False
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -1200,6 +1201,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
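
At launch time the new --enable-flashinfer-allreduce-fusion flag would be passed like any other boolean server argument, for example (model path and tensor-parallel size are placeholders, not part of this PR): python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp-size 8 --enable-flashinfer-allreduce-fusion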