From 110d3b921c66a9356457ddb635baf131c1560b29 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 28 Jun 2025 10:44:24 +0000 Subject: [PATCH 01/10] init --- python/sglang/srt/layers/communicator.py | 7 +- python/sglang/srt/layers/flashinfer_fusion.py | 215 ++++++++++++++++++ python/sglang/srt/layers/layernorm.py | 26 +++ 3 files changed, 246 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/flashinfer_fusion.py diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 42d2ec2a3f3..72527b38c8a 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -397,8 +397,11 @@ def _gather_hidden_states_and_residual( if hidden_states.shape[0] != 0: hidden_states = layernorm(hidden_states) else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - hidden_states, residual = layernorm(hidden_states, residual) + if hasattr(layernorm, 'forward_with_allreduce_fusion'): + hidden_states, residual = layernorm.forward_with_allreduce_fusion(hidden_states, residual) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual @staticmethod diff --git a/python/sglang/srt/layers/flashinfer_fusion.py b/python/sglang/srt/layers/flashinfer_fusion.py new file mode 100644 index 00000000000..b93ea07419e --- /dev/null +++ b/python/sglang/srt/layers/flashinfer_fusion.py @@ -0,0 +1,215 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""FlashInfer fusion operations for allreduce + RMS norm.""" + +import logging +from typing import Tuple + +import torch +import torch.distributed as dist + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.utils import is_flashinfer_available + +logger = logging.getLogger(__name__) + +_flashinfer_comm = None +_workspace_manager = None + +if is_flashinfer_available(): + try: + import flashinfer.comm as comm + _flashinfer_comm = comm + except ImportError: + logger.warning( + "flashinfer.comm is not available, falling back to standard " + "implementation" + ) + + +class FlashInferWorkspaceManager: + def __init__(self): + self.workspace_tensor = None + self.ipc_handles = None + self.world_size = None + self.rank = None + self.initialized = False + + def initialize(self, world_size: int, rank: int, max_token_num: int, + hidden_dim: int, group=None, + use_fp32_lamport: bool = False): + """Initialize workspace""" + if self.initialized and self.world_size == world_size: + return + + if _flashinfer_comm is None: + logger.warning( + "FlashInfer comm not available, skipping workspace " + "initialization" + ) + return + + # try: + # Clean up previous workspace + self.cleanup() + + # Create new workspace + self.ipc_handles, self.workspace_tensor = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + max_token_num, + hidden_dim, + group=group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + self.world_size = world_size + self.rank = rank + self.initialized = True + + logger.info( + f"FlashInfer workspace initialized for rank {rank}, " + f"world_size {world_size}" + ) + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.ipc_handles is not None: + try: + _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( + self.ipc_handles, + group=dist.group.WORLD + ) + except Exception as e: + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") + finally: + self.workspace_tensor = None + self.ipc_handles = None + self.initialized = False + + +_workspace_manager = FlashInferWorkspaceManager() + + +def ensure_workspace_initialized(max_token_num: int = 8192, + hidden_dim: int = 4096, + use_fp32_lamport: bool = False): + """Ensure workspace is initialized""" + if not is_flashinfer_available() or _flashinfer_comm is None: + return False + + world_size = get_tensor_model_parallel_world_size() + if world_size <= 1: + return False + + rank = dist.get_rank() + + if (not _workspace_manager.initialized or + _workspace_manager.world_size != world_size): + _workspace_manager.initialize( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + use_fp32_lamport=use_fp32_lamport + ) + + return _workspace_manager.initialized + +def flashinfer_allreduce_add_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + max_token_num: int = 8192, + use_oneshot: bool = True, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Use FlashInfer's fused allreduce + residual + RMS norm operation + + Args: + input_tensor: Input tensor that needs allreduce + residual: Residual tensor + weight: RMS norm weight + eps: RMS norm epsilon + max_token_num: Maximum token number + use_oneshot: Whether to use oneshot mode + trigger_completion_at_end: Whether to trigger completion at end + fp32_acc: 
Whether to use fp32 precision + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output) + """ + if not is_flashinfer_available() or _flashinfer_comm is None: + logger.debug( + "FlashInfer not available, falling back to standard " + "implementation" + ) + return None, None + + world_size = get_tensor_model_parallel_world_size() + if world_size <= 1: + logger.debug("Single GPU, no need for allreduce fusion") + return None, None + + if not ensure_workspace_initialized( + max_token_num=max_token_num, + hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == torch.float32) + ): + logger.debug("FlashInfer workspace not available") + return None, None + + token_num, hidden_dim = input_tensor.shape + + residual_out = torch.empty_like(residual) + norm_out = torch.empty_like(input_tensor) + + _flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + world_size=world_size, + world_rank=dist.get_rank(), + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=_workspace_manager.workspace_tensor, + launch_with_pdl=True, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=( + _flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm + ), + allreduce_out=None, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=eps, + scale_factor=None, + layout_code=None, + ) + + return norm_out, residual_out + + +def cleanup_flashinfer_workspace(): + global _workspace_manager + if _workspace_manager is not None: + _workspace_manager.cleanup() diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 5d8106f17f4..f9231e85268 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -27,6 +27,7 @@ is_cuda, is_hip, is_npu, + is_flashinfer_available ) _is_cuda = is_cuda() @@ -163,6 +164,31 @@ def forward_cpu( else: return self.forward_native(x, residual) + def forward_with_allreduce_fusion( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward method with allreduce fusion, prioritizing flashinfer fused operations + """ + if is_flashinfer_available() and residual is not None: + from sglang.srt.layers.flashinfer_fusion import flashinfer_allreduce_add_rmsnorm + from sglang.srt.distributed import get_tensor_model_parallel_world_size + + # Only use fusion operation in multi-GPU environment + if get_tensor_model_parallel_world_size() > 1: + fused_result = flashinfer_allreduce_add_rmsnorm( + input_tensor=x, + residual=residual, + weight=self.weight, + eps=self.variance_epsilon, + ) + if fused_result[0] is not None: + return fused_result + + return self.forward(x, residual) + class GemmaRMSNorm(CustomOp): def __init__( From b6039583c41643788daa3f1712ba5c311d88d48d Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 28 Jun 2025 14:01:53 +0000 Subject: [PATCH 02/10] upd --- .../layers/{flashinfer_fusion.py => flashinfer_comm_fusion.py} | 0 python/sglang/srt/layers/layernorm.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename python/sglang/srt/layers/{flashinfer_fusion.py => flashinfer_comm_fusion.py} (100%) diff --git a/python/sglang/srt/layers/flashinfer_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py similarity index 100% rename from python/sglang/srt/layers/flashinfer_fusion.py rename to 
python/sglang/srt/layers/flashinfer_comm_fusion.py diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index f9231e85268..7c0f621a28f 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -173,7 +173,7 @@ def forward_with_allreduce_fusion( Forward method with allreduce fusion, prioritizing flashinfer fused operations """ if is_flashinfer_available() and residual is not None: - from sglang.srt.layers.flashinfer_fusion import flashinfer_allreduce_add_rmsnorm + from sglang.python.sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm from sglang.srt.distributed import get_tensor_model_parallel_world_size # Only use fusion operation in multi-GPU environment From 9af8348a731dd7be264279b38cba99ddee991a57 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 28 Jun 2025 14:06:23 +0000 Subject: [PATCH 03/10] refine --- python/sglang/srt/layers/layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7c0f621a28f..374d33d6ce2 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -173,7 +173,7 @@ def forward_with_allreduce_fusion( Forward method with allreduce fusion, prioritizing flashinfer fused operations """ if is_flashinfer_available() and residual is not None: - from sglang.python.sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm + from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm from sglang.srt.distributed import get_tensor_model_parallel_world_size # Only use fusion operation in multi-GPU environment From bd04bfaa7be87b667bb6f7594581616dcc1da632 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 28 Jun 2025 17:13:28 +0000 Subject: [PATCH 04/10] upd --- python/sglang/srt/layers/communicator.py | 4 +++- python/sglang/srt/layers/flashinfer_comm_fusion.py | 2 +- python/sglang/srt/layers/layernorm.py | 4 +--- python/sglang/srt/server_args.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 72527b38c8a..38dba8b024c 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -19,6 +19,8 @@ import torch.distributed +from sglang.srt.utils import is_flashinfer_available +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -397,7 +399,7 @@ def _gather_hidden_states_and_residual( if hidden_states.shape[0] != 0: hidden_states = layernorm(hidden_states) else: - if hasattr(layernorm, 'forward_with_allreduce_fusion'): + if is_sm100_supported() and is_flashinfer_available() and hasattr(layernorm, 'forward_with_allreduce_fusion'): hidden_states, residual = layernorm.forward_with_allreduce_fusion(hidden_states, residual) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index b93ea07419e..371272e164e 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -136,7 +136,7 @@ def flashinfer_allreduce_add_rmsnorm( eps: float = 1e-6, max_token_num: int = 8192, use_oneshot: bool = True, - 
trigger_completion_at_end: bool = False, + trigger_completion_at_end: bool = True, fp32_acc: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 374d33d6ce2..ab6ca314d9d 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -27,7 +27,6 @@ is_cuda, is_hip, is_npu, - is_flashinfer_available ) _is_cuda = is_cuda() @@ -172,11 +171,10 @@ def forward_with_allreduce_fusion( """ Forward method with allreduce fusion, prioritizing flashinfer fused operations """ - if is_flashinfer_available() and residual is not None: + if residual is not None: from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm from sglang.srt.distributed import get_tensor_model_parallel_world_size - # Only use fusion operation in multi-GPU environment if get_tensor_model_parallel_world_size() > 1: fused_result = flashinfer_allreduce_add_rmsnorm( input_tensor=x, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 835e6e88851..d2c93aa6df3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -324,7 +324,7 @@ def __post_init__(self): elif gpu_mem < 160 * 1024: # H100, H200, A100, H20 self.chunked_prefill_size = 8192 else: # B200, MI300 - self.chunked_prefill_size = 16384 + self.chunked_prefill_size = 8192 else: self.chunked_prefill_size = 4096 assert self.chunked_prefill_size % self.page_size == 0 From b640fb4324b769c4a611abb8687ab0e0dfa6bb1d Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Mon, 30 Jun 2025 16:05:31 +0000 Subject: [PATCH 05/10] upd --- python/sglang/srt/layers/communicator.py | 2 +- .../srt/layers/flashinfer_comm_fusion.py | 24 +++---------------- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/server_args.py | 8 ++++++- 4 files changed, 12 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 38dba8b024c..059933fa211 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -399,7 +399,7 @@ def _gather_hidden_states_and_residual( if hidden_states.shape[0] != 0: hidden_states = layernorm(hidden_states) else: - if is_sm100_supported() and is_flashinfer_available() and hasattr(layernorm, 'forward_with_allreduce_fusion'): + if is_sm100_supported() and is_flashinfer_available() and hasattr(layernorm, 'forward_with_allreduce_fusion') and global_server_args_dict['enable_flashinfer_allreduce_fusion'] and hidden_states.shape[0] <= 4096: hidden_states, residual = layernorm.forward_with_allreduce_fusion(hidden_states, residual) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index 371272e164e..2c771e3b0a3 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -1,18 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FlashInfer fusion operations for allreduce + RMS norm.""" - import logging from typing import Tuple @@ -59,12 +44,9 @@ def initialize(self, world_size: int, rank: int, max_token_num: int, "initialization" ) return - - # try: - # Clean up previous workspace + self.cleanup() - # Create new workspace self.ipc_handles, self.workspace_tensor = ( comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( rank, @@ -104,7 +86,7 @@ def cleanup(self): _workspace_manager = FlashInferWorkspaceManager() -def ensure_workspace_initialized(max_token_num: int = 8192, +def ensure_workspace_initialized(max_token_num: int = 4096, hidden_dim: int = 4096, use_fp32_lamport: bool = False): """Ensure workspace is initialized""" @@ -134,7 +116,7 @@ def flashinfer_allreduce_add_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - max_token_num: int = 8192, + max_token_num: int = 4096, use_oneshot: bool = True, trigger_completion_at_end: bool = True, fp32_acc: bool = False, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6728f885235..2dc518373a2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -87,6 +87,7 @@ "deepep_mode", "enable_ep_moe", "enable_flashinfer_moe", + "enable_flashinfer_allreduce_fusion", "moe_dense_tp_size", "ep_dispatch_algorithm", "deepep_config", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d2c93aa6df3..1aa0b37b52e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -155,6 +155,7 @@ class ServerArgs: enable_ep_moe: bool = False enable_deepep_moe: bool = False enable_flashinfer_moe: bool = False + enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None @@ -324,7 +325,7 @@ def __post_init__(self): elif gpu_mem < 160 * 1024: # H100, H200, A100, H20 self.chunked_prefill_size = 8192 else: # B200, MI300 - self.chunked_prefill_size = 8192 + self.chunked_prefill_size = 16384 else: self.chunked_prefill_size = 4096 assert self.chunked_prefill_size % self.page_size == 0 @@ -1199,6 +1200,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. 
Supports MoE-EP with --enable-ep-moe", ) + parser.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + help="Enable FlashInfer allreduce fusion for Add_RMSNorm.", + ) parser.add_argument( "--enable-deepep-moe", action="store_true", From b224beeda3ee0e3ee7e44b43754d0a07699a2983 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 1 Jul 2025 00:07:18 +0800 Subject: [PATCH 06/10] lint --- python/sglang/srt/layers/communicator.py | 16 +++- .../srt/layers/flashinfer_comm_fusion.py | 79 ++++++++++--------- python/sglang/srt/layers/layernorm.py | 8 +- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 059933fa211..802ba77a9b0 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -19,8 +19,6 @@ import torch.distributed -from sglang.srt.utils import is_flashinfer_available -from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -34,8 +32,10 @@ get_attention_tp_rank, get_attention_tp_size, ) +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.utils import is_flashinfer_available class ScatterMode(Enum): @@ -399,8 +399,16 @@ def _gather_hidden_states_and_residual( if hidden_states.shape[0] != 0: hidden_states = layernorm(hidden_states) else: - if is_sm100_supported() and is_flashinfer_available() and hasattr(layernorm, 'forward_with_allreduce_fusion') and global_server_args_dict['enable_flashinfer_allreduce_fusion'] and hidden_states.shape[0] <= 4096: - hidden_states, residual = layernorm.forward_with_allreduce_fusion(hidden_states, residual) + if ( + is_sm100_supported() + and is_flashinfer_available() + and hasattr(layernorm, "forward_with_allreduce_fusion") + and global_server_args_dict["enable_flashinfer_allreduce_fusion"] + and hidden_states.shape[0] <= 4096 + ): + hidden_states, residual = layernorm.forward_with_allreduce_fusion( + hidden_states, residual + ) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) hidden_states, residual = layernorm(hidden_states, residual) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index 2c771e3b0a3..df0e92b568c 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -15,6 +15,7 @@ if is_flashinfer_available(): try: import flashinfer.comm as comm + _flashinfer_comm = comm except ImportError: logger.warning( @@ -30,23 +31,28 @@ def __init__(self): self.world_size = None self.rank = None self.initialized = False - - def initialize(self, world_size: int, rank: int, max_token_num: int, - hidden_dim: int, group=None, - use_fp32_lamport: bool = False): + + def initialize( + self, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + group=None, + use_fp32_lamport: bool = False, + ): """Initialize workspace""" if self.initialized and self.world_size == world_size: return - + if _flashinfer_comm is None: logger.warning( - "FlashInfer comm not available, skipping workspace " - "initialization" + "FlashInfer comm not available, skipping workspace " "initialization" ) return self.cleanup() - + self.ipc_handles, self.workspace_tensor = ( 
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( rank, @@ -57,23 +63,22 @@ def initialize(self, world_size: int, rank: int, max_token_num: int, use_fp32_lamport=use_fp32_lamport, ) ) - + self.world_size = world_size self.rank = rank self.initialized = True - + logger.info( f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}" ) - + def cleanup(self): """Clean up workspace""" if self.initialized and self.ipc_handles is not None: try: _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( - self.ipc_handles, - group=dist.group.WORLD + self.ipc_handles, group=dist.group.WORLD ) except Exception as e: logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") @@ -86,31 +91,34 @@ def cleanup(self): _workspace_manager = FlashInferWorkspaceManager() -def ensure_workspace_initialized(max_token_num: int = 4096, - hidden_dim: int = 4096, - use_fp32_lamport: bool = False): +def ensure_workspace_initialized( + max_token_num: int = 4096, hidden_dim: int = 4096, use_fp32_lamport: bool = False +): """Ensure workspace is initialized""" if not is_flashinfer_available() or _flashinfer_comm is None: return False - + world_size = get_tensor_model_parallel_world_size() if world_size <= 1: return False - + rank = dist.get_rank() - - if (not _workspace_manager.initialized or - _workspace_manager.world_size != world_size): + + if ( + not _workspace_manager.initialized + or _workspace_manager.world_size != world_size + ): _workspace_manager.initialize( world_size=world_size, rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, - use_fp32_lamport=use_fp32_lamport + use_fp32_lamport=use_fp32_lamport, ) - + return _workspace_manager.initialized + def flashinfer_allreduce_add_rmsnorm( input_tensor: torch.Tensor, residual: torch.Tensor, @@ -123,7 +131,7 @@ def flashinfer_allreduce_add_rmsnorm( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Use FlashInfer's fused allreduce + residual + RMS norm operation - + Args: input_tensor: Input tensor that needs allreduce residual: Residual tensor @@ -133,32 +141,31 @@ def flashinfer_allreduce_add_rmsnorm( use_oneshot: Whether to use oneshot mode trigger_completion_at_end: Whether to trigger completion at end fp32_acc: Whether to use fp32 precision - + Returns: Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output) """ if not is_flashinfer_available() or _flashinfer_comm is None: logger.debug( - "FlashInfer not available, falling back to standard " - "implementation" + "FlashInfer not available, falling back to standard " "implementation" ) return None, None - + world_size = get_tensor_model_parallel_world_size() if world_size <= 1: logger.debug("Single GPU, no need for allreduce fusion") return None, None - + if not ensure_workspace_initialized( - max_token_num=max_token_num, + max_token_num=max_token_num, hidden_dim=input_tensor.shape[-1], - use_fp32_lamport=(input_tensor.dtype == torch.float32) + use_fp32_lamport=(input_tensor.dtype == torch.float32), ): logger.debug("FlashInfer workspace not available") return None, None - + token_num, hidden_dim = input_tensor.shape - + residual_out = torch.empty_like(residual) norm_out = torch.empty_like(input_tensor) @@ -173,9 +180,7 @@ def flashinfer_allreduce_add_rmsnorm( use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, - pattern_code=( - _flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm - ), + pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), allreduce_out=None, residual_in=residual, 
residual_out=residual_out, @@ -187,7 +192,7 @@ def flashinfer_allreduce_add_rmsnorm( scale_factor=None, layout_code=None, ) - + return norm_out, residual_out diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index ab6ca314d9d..78b4a05139d 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -172,9 +172,11 @@ def forward_with_allreduce_fusion( Forward method with allreduce fusion, prioritizing flashinfer fused operations """ if residual is not None: - from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm from sglang.srt.distributed import get_tensor_model_parallel_world_size - + from sglang.srt.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_add_rmsnorm, + ) + if get_tensor_model_parallel_world_size() > 1: fused_result = flashinfer_allreduce_add_rmsnorm( input_tensor=x, @@ -184,7 +186,7 @@ def forward_with_allreduce_fusion( ) if fused_result[0] is not None: return fused_result - + return self.forward(x, residual) From 02f1a982fbecf7d35e100af08608776db7ec4e76 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Mon, 30 Jun 2025 16:49:03 +0000 Subject: [PATCH 07/10] tune max_token_num --- python/sglang/srt/layers/communicator.py | 2 +- python/sglang/srt/layers/flashinfer_comm_fusion.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 802ba77a9b0..1068112e196 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -404,7 +404,7 @@ def _gather_hidden_states_and_residual( and is_flashinfer_available() and hasattr(layernorm, "forward_with_allreduce_fusion") and global_server_args_dict["enable_flashinfer_allreduce_fusion"] - and hidden_states.shape[0] <= 4096 + and hidden_states.shape[0] <= 1024 ): hidden_states, residual = layernorm.forward_with_allreduce_fusion( hidden_states, residual diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index df0e92b568c..a36e98a8ff4 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -92,7 +92,7 @@ def cleanup(self): def ensure_workspace_initialized( - max_token_num: int = 4096, hidden_dim: int = 4096, use_fp32_lamport: bool = False + max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False ): """Ensure workspace is initialized""" if not is_flashinfer_available() or _flashinfer_comm is None: @@ -124,7 +124,7 @@ def flashinfer_allreduce_add_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - max_token_num: int = 4096, + max_token_num: int = 1024, use_oneshot: bool = True, trigger_completion_at_end: bool = True, fp32_acc: bool = False, From 810a21d527873aa2a1ce43c2512c08ae3b9c13d5 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 1 Jul 2025 10:23:24 +0000 Subject: [PATCH 08/10] refine --- python/sglang/srt/layers/flashinfer_comm_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index a36e98a8ff4..fb78218c3d7 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -126,7 +126,7 @@ def flashinfer_allreduce_add_rmsnorm( eps: float = 1e-6, max_token_num: int = 1024, use_oneshot: 
bool = True, - trigger_completion_at_end: bool = True, + trigger_completion_at_end: bool = False, fp32_acc: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ From cab45c3d1a66f26b15e41a2f2178fa52a576957d Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 2 Jul 2025 14:15:59 +0800 Subject: [PATCH 09/10] fix comment --- python/sglang/srt/layers/communicator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 1068112e196..0c9b9fea58e 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -37,6 +37,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available +_is_flashinfer_available = is_flashinfer_available() +_is_sm100_supported = is_sm100_supported() + class ScatterMode(Enum): """ @@ -400,8 +403,8 @@ def _gather_hidden_states_and_residual( hidden_states = layernorm(hidden_states) else: if ( - is_sm100_supported() - and is_flashinfer_available() + _is_sm100_supported + and _is_flashinfer_available and hasattr(layernorm, "forward_with_allreduce_fusion") and global_server_args_dict["enable_flashinfer_allreduce_fusion"] and hidden_states.shape[0] <= 1024 From 2088b161e183e768a1d11ad915c6613b56e08d10 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 2 Jul 2025 15:53:49 +0800 Subject: [PATCH 10/10] fix ci --- python/sglang/srt/layers/communicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 0c9b9fea58e..4af27ad6930 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -35,10 +35,10 @@ from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda, is_flashinfer_available _is_flashinfer_available = is_flashinfer_available() -_is_sm100_supported = is_sm100_supported() +_is_sm100_supported = is_cuda() and is_sm100_supported() class ScatterMode(Enum):
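
For context on what this series wires up: the fused kernel called in python/sglang/srt/layers/flashinfer_comm_fusion.py is expected to be numerically equivalent to the unfused fallback branch in _gather_hidden_states_and_residual, i.e. an allreduce across tensor-parallel ranks, a residual add, and an RMSNorm, with the post-add sum returned as the new residual. The sketch below is a minimal reference for that computation, assuming a torch.distributed process group is already initialized; the function name and the standalone RMSNorm math are illustrative and not part of the patches.

import torch
import torch.distributed as dist


def unfused_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,  # per-rank partial output, shape [token_num, hidden_dim]
    residual: torch.Tensor,      # shape [token_num, hidden_dim]
    weight: torch.Tensor,        # RMSNorm gamma, shape [hidden_dim]
    eps: float = 1e-6,
):
    # 1) Sum the partial outputs across tensor-parallel ranks (in place).
    dist.all_reduce(input_tensor)
    # 2) Add the residual; this sum is also what the fused op returns as residual_out.
    residual_out = input_tensor + residual
    # 3) RMSNorm over the hidden dimension, computed in fp32 for stability,
    #    then cast back and scaled by the learned weight.
    x = residual_out.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    norm_out = (x * torch.rsqrt(variance + eps)).to(residual_out.dtype) * weight
    return norm_out, residual_out

The fused path is opt-in via the new --enable-flashinfer-allreduce-fusion server flag and only engages when FlashInfer is installed, the GPU is SM100-class, tensor parallelism is active, and the batch has at most 1024 tokens (the default workspace size); in all other cases the code keeps the existing tensor_model_parallel_all_reduce followed by the regular fused-add RMSNorm, so behavior without the flag is unchanged.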