[b200] support trt-llm allreduce fuse rms_norm_add kernel #7621
Changes from 17 commits
@@ -32,8 +32,13 @@
    get_attention_tp_rank,
    get_attention_tp_size,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_sm100_supported()


class ScatterMode(Enum):
@@ -397,8 +402,19 @@ def _gather_hidden_states_and_residual(
            if hidden_states.shape[0] != 0:
                hidden_states = layernorm(hidden_states)
        else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
Collaborator:
hidden_state.numel() * hidden_state.element_size() < THRESHOLD

Author:
Makes sense, I'll update it.

Author:
I think this feature could also be applied to other models. However, since different models may have varying hidden_size values, directly checking

Collaborator:
Hi, I think the max_workspace settings in trt-llm are not for these flashinfer kernels (allreduce_fusion_xxx). The max_token_num should be set as the actual max token number.
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual

    @staticmethod
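The suggestion in the review thread above is to gate the fused path on the byte size of hidden_states rather than on a fixed row count, so the cutoff stays meaningful across models with different hidden sizes. A minimal sketch of such a gate, assuming a placeholder threshold (the helper name and the 16 MiB value are illustrative, not part of this PR):

    # Hypothetical gate: compare the tensor's byte size against a tunable
    # threshold instead of checking hidden_states.shape[0] <= 1024.
    _FUSION_MAX_BYTES = 16 * 1024 * 1024  # assumed placeholder value

    def _should_use_allreduce_fusion(hidden_states: torch.Tensor) -> bool:
        return (
            hidden_states.numel() * hidden_states.element_size()
            < _FUSION_MAX_BYTES
        )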
@@ -0,0 +1,202 @@
import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
    try:
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning(
            "flashinfer.comm is not available, falling back to standard "
            "implementation"
        )

class FlashInferWorkspaceManager:
    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Initialize workspace"""
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning(
                "FlashInfer comm not available, skipping workspace "
                "initialization"
            )
            return

        self.cleanup()

        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                max_token_num,
                hidden_dim,
                group=group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(
            f"FlashInfer workspace initialized for rank {rank}, "
            f"world_size {world_size}"
        )

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles, group=dist.group.WORLD
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()

def ensure_workspace_initialized(
    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Ensure workspace is initialized"""
    if not is_flashinfer_available() or _flashinfer_comm is None:
        return False

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    if (
        not _workspace_manager.initialized
        or _workspace_manager.world_size != world_size
    ):
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized

def flashinfer_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 1024,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation

    Args:
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight
        eps: RMS norm epsilon
        max_token_num: Maximum token number
        use_oneshot: Whether to use oneshot mode
        trigger_completion_at_end: Whether to trigger completion at end
        fp32_acc: Whether to accumulate in fp32

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
    """
    if not is_flashinfer_available() or _flashinfer_comm is None:
        logger.debug(
            "FlashInfer not available, falling back to standard implementation"
        )
        return None, None

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if not ensure_workspace_initialized(
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == torch.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = torch.empty_like(residual)
    norm_out = torch.empty_like(input_tensor)

    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out

def cleanup_flashinfer_workspace():
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
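For context, here is a sketch of how a layernorm module might expose the forward_with_allreduce_fusion hook that the communicator probes with hasattr above. Only flashinfer_allreduce_add_rmsnorm and its (None, None) fallback contract come from this PR; the wrapper class, its constructor, and the eager fallback below are assumptions for illustration:

    import torch
    from sglang.srt.distributed import tensor_model_parallel_all_reduce  # assumed import path

    class RMSNormWithFusion(torch.nn.Module):  # hypothetical wrapper, not part of this PR
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward_with_allreduce_fusion(self, hidden_states, residual):
            # Try the fused allreduce + residual add + RMSNorm kernel first.
            norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
                input_tensor=hidden_states,
                residual=residual,
                weight=self.weight,
                eps=self.variance_epsilon,
            )
            if norm_out is not None and residual_out is not None:
                return norm_out, residual_out
            # Fallback: plain allreduce, then residual add + RMSNorm in eager PyTorch.
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = hidden_states + residual
            residual = hidden_states
            variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
            normed = hidden_states.float() * torch.rsqrt(variance + self.variance_epsilon)
            return (normed * self.weight.float()).to(residual.dtype), residual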