diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 38e7fdcf5542..3cd52160dfb6 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -5,7 +5,7 @@ Benchmark for FlashInfer fused collective operations vs standard operations. This benchmark compares: -1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant) 2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations Usage with torchrun: @@ -24,7 +24,6 @@ from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.distributed import ( - get_tp_group, tensor_model_parallel_all_reduce, ) from vllm.distributed.parallel_state import ( @@ -52,11 +51,12 @@ try: import flashinfer.comm as flashinfer_comm # type: ignore - if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + if not ( + hasattr(flashinfer_comm, "allreduce_fusion") + and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace") + ): flashinfer_comm = None - logger.warning( - "FlashInfer comm module found but missing trtllm_allreduce_fusion" - ) + logger.warning("FlashInfer comm module found but missing allreduce_fusion API") except ImportError: flashinfer_comm = None logger.warning("FlashInfer not found, only benchmarking standard operations") @@ -75,7 +75,7 @@ } # Global workspace tensor for FlashInfer -_FI_WORKSPACE_TENSOR = None +_FI_WORKSPACE = None def setup_flashinfer_workspace( @@ -83,10 +83,10 @@ def setup_flashinfer_workspace( rank: int, hidden_dim: int, max_token_num: int, - use_fp32_lamport: bool = False, + dtype: torch.dtype, ): """Setup FlashInfer workspace for fused allreduce operations.""" - global _FI_WORKSPACE_TENSOR + global _FI_WORKSPACE if flashinfer_comm is None: - return None, None + return None @@ -96,33 +96,29 @@ def setup_flashinfer_workspace( - return
None, None + return None try: - # Create IPC workspace - ipc_handles, workspace_tensor = ( - flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - tp_rank=rank, - tp_size=world_size, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - group=get_tp_group().device_group, - use_fp32_lamport=use_fp32_lamport, - ) + workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, ) - _FI_WORKSPACE_TENSOR = workspace_tensor - return ipc_handles, workspace_tensor + _FI_WORKSPACE = workspace + return workspace except Exception as e: logger.error("Failed to setup FlashInfer workspace: %s", e) - return None, None + return None -def cleanup_flashinfer_workspace(ipc_handles): +def cleanup_flashinfer_workspace(workspace): """Cleanup FlashInfer workspace.""" - if flashinfer_comm is None or ipc_handles is None: + if flashinfer_comm is None or workspace is None: return try: - group = get_tp_group().device_group - flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + workspace.destroy() except Exception as e: logger.error("Failed to cleanup FlashInfer workspace: %s", e) @@ -132,25 +128,15 @@ class FlashInferFusedAllReduceParams: def __init__( self, - rank: int, - world_size: int, - use_fp32_lamport: bool = False, max_token_num: int = 1024, ): - self.rank = rank - self.world_size = world_size - self.use_fp32_lamport = use_fp32_lamport - self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True self.max_token_num = max_token_num def get_trtllm_fused_allreduce_kwargs(self): return { - "world_rank": self.rank, - "world_size": self.world_size, "launch_with_pdl": self.launch_with_pdl, - "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, } @@ -165,7 +151,7 @@ def flashinfer_fused_allreduce_rmsnorm( norm_out: torch.Tensor | None = None, ): """FlashInfer fused 
allreduce + rmsnorm operation.""" - if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + if flashinfer_comm is None or _FI_WORKSPACE is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -174,18 +160,15 @@ def flashinfer_fused_allreduce_rmsnorm( else: residual_out = input_tensor - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=input_tensor, - token_num=input_tensor.shape[0], + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=_FI_WORKSPACE, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, - hidden_dim=input_tensor.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, - allreduce_out=None, quant_out=None, scale_out=None, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, @@ -207,7 +190,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( quant_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + if flashinfer_comm is None or _FI_WORKSPACE is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -216,18 +199,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( else: residual_out = input_tensor - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=input_tensor, - token_num=input_tensor.shape[0], + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=_FI_WORKSPACE, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, - hidden_dim=input_tensor.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - 
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, - allreduce_out=None, quant_out=quant_out, scale_out=None, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, @@ -250,7 +230,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + if flashinfer_comm is None or _FI_WORKSPACE is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -259,18 +239,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( else: residual_out = input_tensor - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=input_tensor, - token_num=input_tensor.shape[0], + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=_FI_WORKSPACE, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, - hidden_dim=input_tensor.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, - allreduce_out=None, quant_out=quant_out, scale_out=output_scale, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, @@ -1040,23 +1017,31 @@ def main(): configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) # Setup FlashInfer workspace if available - ipc_handles = None + workspace = None allreduce_params = None if flashinfer_comm is not None: # Use the largest hidden dimension for workspace setup + max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes) + workspace_dtype = ( + torch.float32 + if max_element_size == 4 + else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16) + ) max_num_token = _FI_MAX_SIZES.get(world_size) // ( - args.hidden_dim * world_size * 2 + args.hidden_dim * 
max_element_size ) - ipc_handles, workspace_tensor = setup_flashinfer_workspace( - world_size, rank, args.hidden_dim, max_num_token + workspace = setup_flashinfer_workspace( + world_size, + rank, + args.hidden_dim, + max_num_token, + dtype=workspace_dtype, ) - if workspace_tensor is not None: + if workspace is not None: allreduce_params = FlashInferFusedAllReduceParams( - rank=rank, - world_size=world_size, max_token_num=max_num_token, ) @@ -1119,8 +1104,8 @@ def main(): finally: # Cleanup - if ipc_handles is not None: - cleanup_flashinfer_workspace(ipc_handles) + if workspace is not None: + cleanup_flashinfer_workspace(workspace) dist.barrier() diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index f13f49b67376..d48f22970313 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -202,9 +202,10 @@ def ops_in_model_before(self): @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") - or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), + or not has_module_attribute("flashinfer.comm", "allreduce_fusion") + or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"), reason="flashinfer is not found or flashinfer " - "is not compiled with trtllm_allreduce_fusion", + "is not compiled with allreduce_fusion", ) def test_all_reduce_fusion_pass_replace( test_model: torch.nn.Module, diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 0b343fd162b7..b613d4424ee3 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project +import contextlib from importlib.util import find_spec from types import ModuleType @@ -36,7 +37,9 @@ try: import flashinfer.comm as _flashinfer_comm - if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"): + if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr( + _flashinfer_comm, "create_allreduce_fusion_workspace" + ): flashinfer_comm = _flashinfer_comm except ImportError: pass @@ -79,7 +82,7 @@ if flashinfer_comm is not None: - _FI_WORKSPACE_TENSOR = None + _FI_WORKSPACE = None MiB = 1024 * 1024 def call_trtllm_fused_allreduce_norm( @@ -87,10 +90,8 @@ def call_trtllm_fused_allreduce_norm( residual: torch.Tensor, rms_gamma: torch.Tensor, rms_eps: float, - world_rank: int, world_size: int, launch_with_pdl: bool, - trigger_completion_at_end: bool, fp32_acc: bool, max_token_num: int, pattern_code: int, @@ -121,7 +122,7 @@ def call_trtllm_fused_allreduce_norm( max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB ) - assert _FI_WORKSPACE_TENSOR is not None, ( + assert _FI_WORKSPACE is not None, ( "Flashinfer must be enabled when using flashinfer" ) if norm_out is None: @@ -134,24 +135,18 @@ def call_trtllm_fused_allreduce_norm( residual_out = allreduce_in # For the sizes that are smaller than the max size, # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], + flashinfer_comm.allreduce_fusion( + input=allreduce_in, + workspace=_FI_WORKSPACE, + pattern=pattern_code, residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, quant_out=quant_out, scale_out=scale_out, # in vllm we 
only support swizzled layout @@ -164,10 +159,8 @@ def call_trtllm_fused_allreduce_norm_fake( residual: torch.Tensor, rms_gamma: torch.Tensor, rms_eps: float, - world_rank: int, world_size: int, launch_with_pdl: bool, - trigger_completion_at_end: bool, fp32_acc: bool, max_token_num: int, pattern_code: int, @@ -200,25 +193,18 @@ class FlashInferFusedAllReduceParams: def __init__( self, - rank: int, world_size: int, - use_fp32_lamport: bool = False, max_token_num: int = 1024, ) -> None: - self.rank = rank self.world_size = world_size - self.use_fp32_lamport = use_fp32_lamport - self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True self.max_token_num = max_token_num def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]: return { - "world_rank": self.rank, "world_size": self.world_size, "launch_with_pdl": self.launch_with_pdl, - "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, } @@ -712,7 +698,6 @@ def __init__(self, config: VllmConfig) -> None: self.hidden_dim = config.model_config.get_hidden_size() self.group = get_tp_group().device_group rank = get_tensor_model_parallel_rank() - use_fp32_lamport = self.model_dtype == torch.float32 if flashinfer_comm is None: logger.warning( "Flashinfer is not installed or comm module not found, " @@ -730,7 +715,7 @@ def __init__(self, config: VllmConfig) -> None: self.tp_size, ) return - element_size = 4 if use_fp32_lamport else 2 + element_size = torch.tensor([], dtype=self.model_dtype).element_size() self.max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways @@ -744,23 +729,19 @@ def __init__(self, config: VllmConfig) -> None: scope="global", ) - self.ipc_handles, workspace_tensor = ( - flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - tp_rank=rank, - tp_size=self.tp_size, - 
max_token_num=self.max_token_num, - hidden_dim=self.hidden_dim, - group=self.group, - use_fp32_lamport=use_fp32_lamport, - ) + self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=self.tp_size, + rank=rank, + max_token_num=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.model_dtype, ) - global _FI_WORKSPACE_TENSOR - _FI_WORKSPACE_TENSOR = workspace_tensor + global _FI_WORKSPACE + _FI_WORKSPACE = self.workspace self.allreduce_params = FlashInferFusedAllReduceParams( - rank=rank, world_size=self.tp_size, - use_fp32_lamport=use_fp32_lamport, max_token_num=self.max_token_num, ) @@ -832,7 +813,6 @@ def __call__(self, graph: fx.Graph) -> None: def __del__(self) -> None: if getattr(self, "disabled", True): return - if flashinfer_comm is not None: - flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( - self.ipc_handles, self.group - ) + if getattr(self, "workspace", None) is not None: + with contextlib.suppress(Exception): + self.workspace.destroy()