-
Notifications
You must be signed in to change notification settings - Fork 993
Allreduce auto backend improvements #2239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
5c2dcff
Check fabric handle support
d421a67
one_shot in workspace creation, removed topology
nvmbreughe 5285e32
Backward compatibility, test integration
nvmbreughe 0f83701
Added documentation
nvmbreughe 9dd38fe
Addressed bot reviewer comments
nvmbreughe 1e42bb6
Remove MPI dependency from MNNVL allreduce tests
nvmbreughe c7a6ead
Update flashinfer/comm/trtllm_ar.py
nvmbreughe f6b0edc
Update flashinfer/comm/trtllm_ar.py
nvmbreughe 4c433fd
Merge branch 'main' into unified_ar_auto
nvmbreughe 9296de1
Added comment on how to run the test
nvmbreughe 1cb0870
Typo
nvmbreughe File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -55,17 +55,17 @@ | |||||
|
|
||||||
| from .trtllm_ar import trtllm_allreduce_fusion | ||||||
| from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion | ||||||
| from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion | ||||||
| from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata | ||||||
|
|
||||||
| from .mapping import Mapping | ||||||
|
|
||||||
| from .mnnvl import CommBackend | ||||||
| from .mnnvl import CommBackend, SymmDeviceMemory | ||||||
|
|
||||||
| # Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) | ||||||
| # Import them for runtime use but type hint as int for mypy compatibility | ||||||
| from .trtllm_ar import AllReduceFusionPattern | ||||||
| from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace | ||||||
| from .trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy | ||||||
| from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce | ||||||
| from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm | ||||||
|
|
||||||
|
|
@@ -95,7 +95,7 @@ def __init__( | |||||
| max_token_num: int, | ||||||
| hidden_dim: int, | ||||||
| dtype: torch.dtype = torch.float16, | ||||||
| process_group: Optional["torch.distributed.ProcessGroup"] = None, | ||||||
| comm_backend: Optional[CommBackend] = None, | ||||||
| ): | ||||||
| """ | ||||||
| Create TensorRT-LLM AllReduce fusion workspace. | ||||||
|
|
@@ -106,7 +106,7 @@ def __init__( | |||||
| max_token_num: Maximum number of tokens | ||||||
| hidden_dim: Hidden dimension size | ||||||
| dtype: Data type | ||||||
| process_group: PyTorch distributed process group | ||||||
| comm_backend: Communication backend | ||||||
| **kwargs: Additional arguments for workspace creation | ||||||
| """ | ||||||
| super().__init__(tp_size, tp_rank) | ||||||
|
|
@@ -117,19 +117,22 @@ def __init__( | |||||
| tp_size=tp_size, | ||||||
| max_token_num=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| group=process_group, | ||||||
| comm_backend=comm_backend, | ||||||
| create_metadata=True, | ||||||
| use_fp32_lamport=dtype == torch.float32, | ||||||
| use_symm_dev_mem=True, | ||||||
| ) | ||||||
|
|
||||||
| # Store essential attributes for easy access | ||||||
| # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True | ||||||
| workspace_tuple = cast( | ||||||
| Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace | ||||||
| Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], | ||||||
| self._internal_workspace, | ||||||
| ) | ||||||
| self.ipc_handles = workspace_tuple[0] | ||||||
| self.workspace_tensor = workspace_tuple[1] | ||||||
| self.metadata = workspace_tuple[2] | ||||||
| self.mem_handles = workspace_tuple[2] | ||||||
| self.metadata = workspace_tuple[3] | ||||||
|
|
||||||
| @property | ||||||
| def backend(self) -> str: | ||||||
|
|
@@ -165,7 +168,10 @@ def destroy(self) -> None: | |||||
| if getattr(self, "_destroyed", False): | ||||||
| return # Already destroyed, nothing to do | ||||||
|
|
||||||
| trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) | ||||||
| del self.ipc_handles | ||||||
| del self.workspace_tensor | ||||||
| del self.mem_handles | ||||||
| del self.metadata | ||||||
| self._destroyed = True | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -181,20 +187,14 @@ def _trtllm_workspace_check( | |||||
| max_token_num: int, | ||||||
| hidden_dim: int, | ||||||
| dtype: torch.dtype, | ||||||
| topology: Literal["single_node", "multi_node"], | ||||||
| ) -> bool: | ||||||
| """ | ||||||
| Check if trtllm backend CAN be used for workspace creation. | ||||||
|
|
||||||
| Hard requirements: | ||||||
| - Single-node topology (multi-node not supported) | ||||||
|
|
||||||
| - Up to 16 ranks supported. | ||||||
| """ | ||||||
| # trtllm is optimized for single-node | ||||||
| if topology == "multi_node": | ||||||
| return False | ||||||
|
|
||||||
| return True | ||||||
| return world_size <= 16 | ||||||
|
|
||||||
|
|
||||||
| def _mnnvl_workspace_check( | ||||||
|
|
@@ -204,16 +204,12 @@ def _mnnvl_workspace_check( | |||||
| max_token_num: int, | ||||||
| hidden_dim: int, | ||||||
| dtype: torch.dtype, | ||||||
| topology: Literal["single_node", "multi_node"], | ||||||
| ) -> bool: | ||||||
| """ | ||||||
| Check if mnnvl backend CAN be used for workspace creation. | ||||||
|
|
||||||
| """ | ||||||
|
|
||||||
| if topology == "multi_node": | ||||||
| return True | ||||||
|
|
||||||
| return True | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -230,7 +226,6 @@ def _workspace_creation_heuristic( | |||||
| max_token_num: int, | ||||||
| hidden_dim: int, | ||||||
| dtype: torch.dtype, | ||||||
| topology: Literal["single_node", "multi_node"], | ||||||
| ) -> list[str]: | ||||||
| """ | ||||||
| Select best backend for workspace creation based on performance. | ||||||
|
|
@@ -246,7 +241,6 @@ def _workspace_creation_heuristic( | |||||
| max_token_num: Maximum number of tokens | ||||||
| hidden_dim: Hidden dimension size | ||||||
| dtype: Data type | ||||||
| topology: Network topology ("single_node" or "multi_node") | ||||||
| **kwargs: Additional arguments | ||||||
|
|
||||||
| Note that at this point, the backend selection does not take "runtime parameters" into account, such as layout_code, and fusion pattern. | ||||||
|
|
@@ -262,13 +256,6 @@ def _workspace_creation_heuristic( | |||||
|
|
||||||
| # Decision tree based on benchmark data | ||||||
|
|
||||||
| # Multi-node: MNNVL is designed for this | ||||||
| if topology == "multi_node": | ||||||
| if "mnnvl" in suitable_backends: | ||||||
| return ["mnnvl"] | ||||||
| else: | ||||||
| return [suitable_backends[0]] | ||||||
|
|
||||||
| # Single-node scenarios | ||||||
| # From benchmarking data, we can see that MNNVL is either on par (smaller problem sizes) or significantly faster than TRTLLM (larger problem sizes such as hidden_dim=8192, token_num=64 for TP=4), for single-node scenarios. | ||||||
| # However, trtllm has a larger support surface (more fusion patterns, more quantization support, etc.) | ||||||
|
|
@@ -290,10 +277,9 @@ def create_allreduce_fusion_workspace( | |||||
| max_token_num: int = None, | ||||||
| hidden_dim: int = None, | ||||||
| dtype: torch.dtype = None, | ||||||
| topology: Literal["single_node", "multi_node"] = "single_node", | ||||||
| process_group: Optional["torch.distributed.ProcessGroup"] = None, | ||||||
| gpus_per_node: int = None, | ||||||
| comm_backend: Optional[CommBackend] = None, | ||||||
| force_oneshot_support: bool = False, | ||||||
| ) -> AllReduceFusionWorkspace: | ||||||
| """ | ||||||
| Create workspace for AllReduce fusion operations. | ||||||
|
|
@@ -315,19 +301,19 @@ def create_allreduce_fusion_workspace( | |||||
|
|
||||||
| Args: | ||||||
| backend: Backend to use ("trtllm", "mnnvl", or "auto") | ||||||
| "auto" uses heuristic to select best backend based on topology | ||||||
| and problem size | ||||||
| "auto" uses heuristic to select best backend | ||||||
| world_size: Number of ranks in the process group | ||||||
| rank: Current rank ID | ||||||
| max_token_num: Maximum number of tokens to support | ||||||
| hidden_dim: Hidden dimension size | ||||||
| dtype: Data type for communication tensors | ||||||
| topology: Network topology hint for backend selection | ||||||
| "single_node" - All ranks on one node (default) | ||||||
| "multi_node" - Ranks span multiple nodes | ||||||
| process_group: PyTorch distributed process group (for trtllm backend). | ||||||
| gpus_per_node: Number of GPUs per node (for multi-node topology). | ||||||
| comm_backend: Communication backend to use (for multi-node topology). | ||||||
| comm_backend: Communication backend to use. | ||||||
| force_oneshot_support: Allocate workspace for oneshot strategy vs twoshot | ||||||
| True: Allocate workspace for oneshot strategy up to the largest problem size requested | ||||||
| False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. | ||||||
| Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. | ||||||
| The trtllm backend will be sufficient for both strategies. | ||||||
|
|
||||||
| Returns: | ||||||
| Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) | ||||||
|
|
@@ -338,15 +324,14 @@ def create_allreduce_fusion_workspace( | |||||
| ValueError: If problem size not supported for the specified backend | ||||||
|
|
||||||
| Examples: | ||||||
| >>> # Auto-select best backend based on topology | ||||||
| >>> # Auto-select best backend | ||||||
| >>> workspace = create_allreduce_fusion_workspace( | ||||||
| ... backend="auto", | ||||||
| ... world_size=8, | ||||||
| ... rank=0, | ||||||
| ... max_token_num=2048, | ||||||
| ... hidden_dim=4096, | ||||||
| ... dtype=torch.bfloat16, | ||||||
| ... topology="single_node" | ||||||
| ... ) | ||||||
| >>> print(workspace.backend) # "trtllm" | ||||||
| >>> print(workspace.get_workspace_capacity()) # 8388608 elements | ||||||
|
|
@@ -363,15 +348,14 @@ def create_allreduce_fusion_workspace( | |||||
| ... max_token_num=2048, | ||||||
| ... hidden_dim=4096, | ||||||
| ... dtype=torch.bfloat16, | ||||||
| ... topology="multi_node" | ||||||
| ... ) | ||||||
| >>> print(workspace.backend) # "mnnvl" | ||||||
| """ | ||||||
| if gpus_per_node is None: | ||||||
| gpus_per_node = min(torch.cuda.device_count(), world_size) | ||||||
| # Determine the actual backend to use | ||||||
| if backend == "auto": | ||||||
| # Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point) | ||||||
| # Find suitable backends (any compute capability check needs to be checked at kernel runtime, since there are no tensor available at this point) | ||||||
| suitable_backends = [] | ||||||
| if _trtllm_workspace_check( | ||||||
| backend=backend, | ||||||
|
|
@@ -380,7 +364,6 @@ def create_allreduce_fusion_workspace( | |||||
| max_token_num=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| dtype=dtype, | ||||||
| topology=topology, | ||||||
| ): | ||||||
| suitable_backends.append("trtllm") | ||||||
| if _mnnvl_workspace_check( | ||||||
|
|
@@ -390,15 +373,11 @@ def create_allreduce_fusion_workspace( | |||||
| max_token_num=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| dtype=dtype, | ||||||
| topology=topology, | ||||||
| ): | ||||||
| suitable_backends.append("mnnvl") | ||||||
|
|
||||||
| if not suitable_backends: | ||||||
| raise ValueError( | ||||||
| f"No suitable backend found for topology={topology}. " | ||||||
| f"trtllm requires single_node topology, mnnvl works with both." | ||||||
| ) | ||||||
| raise ValueError("No suitable backend found. ") | ||||||
|
|
||||||
| # Apply heuristic to select best backend | ||||||
| selected = _workspace_creation_heuristic( | ||||||
|
|
@@ -409,7 +388,6 @@ def create_allreduce_fusion_workspace( | |||||
| max_token_num=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| dtype=dtype, | ||||||
| topology=topology, | ||||||
| ) | ||||||
| actual_backend = selected[0] | ||||||
| else: | ||||||
|
|
@@ -423,7 +401,7 @@ def create_allreduce_fusion_workspace( | |||||
| max_token_num=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| dtype=dtype, | ||||||
| process_group=process_group, | ||||||
| comm_backend=comm_backend, | ||||||
| ) | ||||||
|
|
||||||
| elif actual_backend == "mnnvl": | ||||||
|
|
@@ -433,12 +411,25 @@ def create_allreduce_fusion_workspace( | |||||
| gpus_per_node=gpus_per_node, | ||||||
| tp_size=world_size, | ||||||
| ) | ||||||
| buffer_size_in_bytes = None | ||||||
| if force_oneshot_support: | ||||||
| buffer_size_in_bytes = ( | ||||||
| MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes( | ||||||
| world_size, | ||||||
| max_token_num, | ||||||
| hidden_dim, | ||||||
| dtype, | ||||||
| MNNVLAllreduceFusionStrategy.ONESHOT, | ||||||
| ) | ||||||
| ) | ||||||
|
|
||||||
| return MNNVLAllReduceFusionWorkspace( | ||||||
| mapping=mapping, | ||||||
| max_num_tokens=max_token_num, | ||||||
| hidden_dim=hidden_dim, | ||||||
| dtype=dtype, | ||||||
| comm_backend=comm_backend, | ||||||
| buffer_size_in_bytes=buffer_size_in_bytes, | ||||||
| ) | ||||||
| else: | ||||||
| raise RuntimeError(f"Unknown backend: {actual_backend}") | ||||||
|
|
@@ -514,7 +505,7 @@ def allreduce_fusion( | |||||
| # ===== Control parameters ===== | ||||||
| use_oneshot: Use oneshot strategy vs twoshot | ||||||
| If None, uses internal heuristics. | ||||||
| Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used. | ||||||
| Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This note is a bit vague. To improve clarity, consider explicitly mentioning the
Suggested change
|
||||||
| fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce | ||||||
|
|
||||||
| Returns: | ||||||
|
|
@@ -529,7 +520,6 @@ def allreduce_fusion( | |||||
| ... max_token_num=2048, | ||||||
| ... hidden_dim=4096, | ||||||
| ... dtype=torch.bfloat16, | ||||||
| ... topology="single_node" | ||||||
| ... ) | ||||||
| >>> | ||||||
| >>> # Pre-allocate output tensors | ||||||
|
|
||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.