[b200] support trt-llm allreduce fuse rms_norm_add kernel #7621
@@ -0,0 +1,215 @@ (new module, imported below as sglang.srt.layers.flashinfer_fusion)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FlashInfer fusion operations for allreduce + RMS norm."""

import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
    try:
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning(
            "flashinfer.comm is not available, falling back to standard "
            "implementation"
        )


class FlashInferWorkspaceManager:
    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Initialize workspace"""
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning(
                "FlashInfer comm not available, skipping workspace "
                "initialization"
            )
            return

        # Clean up previous workspace
        self.cleanup()

        # Create new workspace
        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                max_token_num,
                hidden_dim,
                group=group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(
            f"FlashInfer workspace initialized for rank {rank}, "
            f"world_size {world_size}"
        )

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles, group=dist.group.WORLD
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False
Contributor review comment on `cleanup()`:
Consider restructuring the `if self.initialized and self.ipc_handles is not None:` guard as an early return, e.g.

    if not self.initialized:
        return
    try:
        ...
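A minimal sketch of `cleanup()` reworked along the lines of this suggestion (same module context as the file above; this is not the code the PR actually adds):

```python
def cleanup(self):
    """Clean up workspace (sketch with the suggested early-return guard)."""
    if not self.initialized:
        return
    try:
        if self.ipc_handles is not None:
            _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, group=dist.group.WORLD
            )
    except Exception as e:
        logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
    finally:
        # Always reset state, even if tearing down the IPC workspace failed.
        self.workspace_tensor = None
        self.ipc_handles = None
        self.initialized = False
```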
_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    max_token_num: int = 8192,
    hidden_dim: int = 4096,
    use_fp32_lamport: bool = False,
):
    """Ensure workspace is initialized"""
    if not is_flashinfer_available() or _flashinfer_comm is None:
        return False

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    if (
        not _workspace_manager.initialized
        or _workspace_manager.world_size != world_size
    ):
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 8192,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight
        eps: RMS norm epsilon
        max_token_num: Maximum token number
        use_oneshot: Whether to use oneshot mode
        trigger_completion_at_end: Whether to trigger completion at end
        fp32_acc: Whether to use fp32 precision

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
    """
    if not is_flashinfer_available() or _flashinfer_comm is None:
        logger.debug(
            "FlashInfer not available, falling back to standard "
            "implementation"
        )
        return None, None

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if not ensure_workspace_initialized(
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == torch.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = torch.empty_like(residual)
    norm_out = torch.empty_like(input_tensor)

    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(
            _flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm
        ),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
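For context, a minimal usage sketch of the fused path and its fallback contract (both outputs are `None` when the fusion cannot be used); the shapes, dtype, and caller here are illustrative assumptions, not part of the diff:

```python
# Hypothetical caller, e.g. a post-attention norm in a tensor-parallel layer.
# Assumes world_size > 1 and a [token_num, hidden_dim] input layout.
hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(hidden_states)
rms_weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
    input_tensor=hidden_states,
    residual=residual,
    weight=rms_weight,
    eps=1e-6,
)
if norm_out is None:
    # Fusion unavailable (no FlashInfer, single GPU, or no workspace):
    # fall back to the unfused allreduce + add + RMSNorm path.
    ...
```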
@@ -27,6 +27,7 @@
    is_cuda,
    is_hip,
    is_npu,
    is_flashinfer_available
)

_is_cuda = is_cuda()

@@ -163,6 +164,31 @@ def forward_cpu(
        else:
            return self.forward_native(x, residual)

    def forward_with_allreduce_fusion(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward method with allreduce fusion, prioritizing flashinfer fused operations
        """
        if is_flashinfer_available() and residual is not None:
            from sglang.srt.layers.flashinfer_fusion import (
                flashinfer_allreduce_add_rmsnorm,
            )
            from sglang.srt.distributed import get_tensor_model_parallel_world_size

            # Only use fusion operation in multi-GPU environment
            if get_tensor_model_parallel_world_size() > 1:
                fused_result = flashinfer_allreduce_add_rmsnorm(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                if fused_result[0] is not None:
                    return fused_result

        return self.forward(x, residual)


class GemmaRMSNorm(CustomOp):
    def __init__(
Review comment:
Consider using `getattr` with a default value to avoid checking for the `forward_with_allreduce_fusion` attribute before accessing it. This can simplify the code and make it more readable.
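A minimal sketch of the suggested `getattr` pattern; the surrounding names (`layer`, `hidden_states`, `residual`) are hypothetical and not taken from the diff:

```python
# Instead of: if hasattr(layer.norm, "forward_with_allreduce_fusion"): ...
# Resolve the fused path once, falling back to the regular forward call.
norm_forward = getattr(
    layer.norm, "forward_with_allreduce_fusion", layer.norm.forward
)
hidden_states, residual = norm_forward(hidden_states, residual)
```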