-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[b200] support trt-llm allreduce fuse rms_norm_add kernel #7621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
110d3b9
init
BBuf b603c2c
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
BBuf b603958
upd
BBuf 9af8348
refine
BBuf bd04bfa
upd
BBuf b640fb4
upd
BBuf b224bee
lint
BBuf 02f1a98
tune max_token_num
BBuf 32f812f
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
zhyncs 16632a3
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
BBuf 3dbbd25
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
zhyncs 810a21d
refine
BBuf bf6d477
Merge branch 'support_allreduce_rmsnorm_add_fusion' of github.com:sgl…
BBuf a60580d
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
BBuf 8adbb74
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
zhyncs 890dadd
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
zhyncs cab45c3
fix comment
BBuf 2088b16
fix ci
BBuf 8a97504
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
BBuf fa8f93e
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
zhyncs 3908aa5
Merge branch 'main' into support_allreduce_rmsnorm_add_fusion
BBuf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| import logging | ||
| from typing import Tuple | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from sglang.srt.distributed import get_tensor_model_parallel_world_size | ||
| from sglang.srt.utils import is_flashinfer_available | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _flashinfer_comm = None | ||
| _workspace_manager = None | ||
|
|
||
| if is_flashinfer_available(): | ||
| try: | ||
| import flashinfer.comm as comm | ||
|
|
||
| _flashinfer_comm = comm | ||
| except ImportError: | ||
| logger.warning( | ||
| "flashinfer.comm is not available, falling back to standard " | ||
| "implementation" | ||
| ) | ||
|
|
||
|
|
||
| class FlashInferWorkspaceManager: | ||
| def __init__(self): | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.world_size = None | ||
| self.rank = None | ||
| self.initialized = False | ||
|
|
||
| def initialize( | ||
| self, | ||
| world_size: int, | ||
| rank: int, | ||
| max_token_num: int, | ||
| hidden_dim: int, | ||
| group=None, | ||
| use_fp32_lamport: bool = False, | ||
| ): | ||
| """Initialize workspace""" | ||
| if self.initialized and self.world_size == world_size: | ||
| return | ||
|
|
||
| if _flashinfer_comm is None: | ||
| logger.warning( | ||
| "FlashInfer comm not available, skipping workspace " "initialization" | ||
| ) | ||
| return | ||
|
|
||
| self.cleanup() | ||
|
|
||
| self.ipc_handles, self.workspace_tensor = ( | ||
| comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( | ||
| rank, | ||
| world_size, | ||
| max_token_num, | ||
| hidden_dim, | ||
| group=group, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
| ) | ||
|
|
||
| self.world_size = world_size | ||
| self.rank = rank | ||
| self.initialized = True | ||
|
|
||
| logger.info( | ||
| f"FlashInfer workspace initialized for rank {rank}, " | ||
| f"world_size {world_size}" | ||
| ) | ||
|
|
||
| def cleanup(self): | ||
| """Clean up workspace""" | ||
| if self.initialized and self.ipc_handles is not None: | ||
| try: | ||
| _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( | ||
| self.ipc_handles, group=dist.group.WORLD | ||
| ) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") | ||
| finally: | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.initialized = False | ||
|
|
||
|
|
||
| _workspace_manager = FlashInferWorkspaceManager() | ||
|
|
||
|
|
||
| def ensure_workspace_initialized( | ||
| max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False | ||
| ): | ||
| """Ensure workspace is initialized""" | ||
| if not is_flashinfer_available() or _flashinfer_comm is None: | ||
| return False | ||
|
|
||
| world_size = get_tensor_model_parallel_world_size() | ||
| if world_size <= 1: | ||
| return False | ||
|
|
||
| rank = dist.get_rank() | ||
|
|
||
| if ( | ||
| not _workspace_manager.initialized | ||
| or _workspace_manager.world_size != world_size | ||
| ): | ||
| _workspace_manager.initialize( | ||
| world_size=world_size, | ||
| rank=rank, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=hidden_dim, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
|
|
||
| return _workspace_manager.initialized | ||
|
|
||
|
|
||
| def flashinfer_allreduce_add_rmsnorm( | ||
| input_tensor: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| eps: float = 1e-6, | ||
| max_token_num: int = 1024, | ||
| use_oneshot: bool = True, | ||
| trigger_completion_at_end: bool = False, | ||
| fp32_acc: bool = False, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Use FlashInfer's fused allreduce + residual + RMS norm operation | ||
|
|
||
| Args: | ||
| input_tensor: Input tensor that needs allreduce | ||
| residual: Residual tensor | ||
| weight: RMS norm weight | ||
| eps: RMS norm epsilon | ||
| max_token_num: Maximum token number | ||
| use_oneshot: Whether to use oneshot mode | ||
| trigger_completion_at_end: Whether to trigger completion at end | ||
| fp32_acc: Whether to use fp32 precision | ||
|
|
||
| Returns: | ||
| Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output) | ||
| """ | ||
| if not is_flashinfer_available() or _flashinfer_comm is None: | ||
| logger.debug( | ||
| "FlashInfer not available, falling back to standard " "implementation" | ||
| ) | ||
| return None, None | ||
|
|
||
| world_size = get_tensor_model_parallel_world_size() | ||
| if world_size <= 1: | ||
| logger.debug("Single GPU, no need for allreduce fusion") | ||
| return None, None | ||
|
|
||
| if not ensure_workspace_initialized( | ||
| max_token_num=max_token_num, | ||
| hidden_dim=input_tensor.shape[-1], | ||
| use_fp32_lamport=(input_tensor.dtype == torch.float32), | ||
| ): | ||
| logger.debug("FlashInfer workspace not available") | ||
| return None, None | ||
|
|
||
| token_num, hidden_dim = input_tensor.shape | ||
|
|
||
| residual_out = torch.empty_like(residual) | ||
| norm_out = torch.empty_like(input_tensor) | ||
|
|
||
| _flashinfer_comm.trtllm_allreduce_fusion( | ||
| allreduce_in=input_tensor, | ||
| world_size=world_size, | ||
| world_rank=dist.get_rank(), | ||
| token_num=token_num, | ||
| hidden_dim=hidden_dim, | ||
| workspace_ptrs=_workspace_manager.workspace_tensor, | ||
| launch_with_pdl=True, | ||
| use_oneshot=use_oneshot, | ||
| trigger_completion_at_end=trigger_completion_at_end, | ||
| fp32_acc=fp32_acc, | ||
| pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), | ||
| allreduce_out=None, | ||
| residual_in=residual, | ||
| residual_out=residual_out, | ||
| norm_out=norm_out, | ||
| quant_out=None, | ||
| scale_out=None, | ||
| rms_gamma=weight, | ||
| rms_eps=eps, | ||
| scale_factor=None, | ||
| layout_code=None, | ||
| ) | ||
|
|
||
| return norm_out, residual_out | ||
|
|
||
|
|
||
| def cleanup_flashinfer_workspace(): | ||
| global _workspace_manager | ||
| if _workspace_manager is not None: | ||
| _workspace_manager.cleanup() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hidden_state.numel() * hidden_state.element_size() < THRESHOLD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense, I'll update it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this feature could also be applied to other models. However, since different models may have varying hidden_size values, directly checking
max_token_numseems more general. Regarding the token number 1024: This value was estimated based on the workspace threshold in trt_llm, using ds-v3's hidden_size as a reference. With 1024 tokens, it only allocates an additional ~10MB buffer, so the memory overhead is minimal.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
I think the max_workspace settings in trt-llm are not for these flashinfer kernels (allreduce_fusion_xxx).
The max_token_number should be set as the actual max token number.
The default value of use_oneshot should be False. In these two cases it could be set as True. (https://github.com/NVIDIA/TensorRT-LLM/blob/a1235ee9781050e562bbce2c86f714c38d434dbe/cpp/tensorrt_llm/thop/allreduceOp.cpp#L416-L429)