-
Notifications
You must be signed in to change notification settings - Fork 5.4k
[AllReduce] FlashInfer: add mnnvl backend selection and standalone TP path #19586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
mmangkad
wants to merge
4
commits into
sgl-project:main
Choose a base branch
from
mmangkad-dev:fi-ar-mnnvl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
206 changes: 206 additions & 0 deletions
206
python/sglang/srt/distributed/device_communicators/flashinfer_all_reduce.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| import logging | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch.distributed import ProcessGroup | ||
|
|
||
| from sglang.srt.distributed.device_communicators.flashinfer_utils import ( | ||
| create_mnnvl_comm_backend, | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _flashinfer_comm = None | ||
| _flashinfer_ar_available = False | ||
| try: | ||
| import flashinfer.comm as flashinfer_comm | ||
|
|
||
| if hasattr(flashinfer_comm, "allreduce_fusion") and hasattr( | ||
| flashinfer_comm, "create_allreduce_fusion_workspace" | ||
| ): | ||
| _flashinfer_comm = flashinfer_comm | ||
| _flashinfer_ar_available = True | ||
| except ImportError: | ||
| pass | ||
|
|
||
| MiB = 1024 * 1024 | ||
|
|
||
| # Max size of the communicated tensor by world size and GPU capability. | ||
| # Adopted from vLLM thresholds. | ||
| # TODO(mmangkad): Tune these thresholds for SGLang, since optimal values may | ||
| # differ from vLLM based on runtime/scheduling behavior. | ||
| _FI_ALLREDUCE_MAX_SIZE_MB: dict[int, dict[int, float]] = { | ||
| 90: { | ||
| 2: 64, | ||
| 4: 2, | ||
| 8: 0.5, | ||
| }, | ||
| 100: { | ||
| 2: 64, | ||
| 4: 32, | ||
| 8: 1, | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| def _get_device_capability() -> Optional[int]: | ||
| if not torch.cuda.is_available(): | ||
| return None | ||
| major, minor = torch.cuda.get_device_capability() | ||
| return major * 10 + minor | ||
|
|
||
|
|
||
| class FlashInferAllReduce: | ||
| def __init__( | ||
| self, | ||
| group: ProcessGroup, | ||
| device: torch.device, | ||
| backend: str = "auto", | ||
| ): | ||
| self.disabled = True | ||
| self.workspace = None | ||
| self.max_num_tokens = 0 | ||
| self.max_workspace_size = None | ||
| self.hidden_dim = None | ||
| self.dtype = None | ||
|
|
||
| if not _flashinfer_ar_available or _flashinfer_comm is None: | ||
| logger.info( | ||
| "FlashInfer allreduce disabled: flashinfer comm API unavailable." | ||
| ) | ||
| return | ||
|
|
||
| if not torch.cuda.is_available(): | ||
| logger.info("FlashInfer allreduce disabled: CUDA is unavailable.") | ||
| return | ||
|
|
||
| self.group = group | ||
| self.world_size = dist.get_world_size(group=self.group) | ||
| self.rank = dist.get_rank(group=self.group) | ||
| self.device = device | ||
| self.backend = backend | ||
|
|
||
| if self.world_size == 1: | ||
| return | ||
|
|
||
| capability = _get_device_capability() | ||
| self.max_workspace_size = _FI_ALLREDUCE_MAX_SIZE_MB.get(capability, {}).get( | ||
| self.world_size | ||
| ) | ||
| if self.max_workspace_size is None: | ||
| logger.warning( | ||
| "FlashInfer allreduce disabled: unsupported world_size=%d for SM=%s.", | ||
| self.world_size, | ||
| str(capability), | ||
| ) | ||
| return | ||
|
|
||
| self.max_workspace_size = int(self.max_workspace_size * MiB) | ||
| self.disabled = False | ||
|
|
||
| def _create_workspace( | ||
| self, | ||
| max_token_num: int, | ||
| hidden_dim: int, | ||
| dtype: torch.dtype, | ||
| ) -> bool: | ||
| assert _flashinfer_comm is not None | ||
|
|
||
| workspace_kwargs = dict( | ||
| backend=self.backend, | ||
| world_size=self.world_size, | ||
| rank=self.rank, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=hidden_dim, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| if self.backend in ("auto", "mnnvl"): | ||
| comm_backend = create_mnnvl_comm_backend(self.group) | ||
| if comm_backend is not None: | ||
| workspace_kwargs["comm_backend"] = comm_backend | ||
|
|
||
| self.workspace = _flashinfer_comm.create_allreduce_fusion_workspace( | ||
| **workspace_kwargs | ||
| ) | ||
| self.hidden_dim = hidden_dim | ||
| self.dtype = dtype | ||
| return self.workspace is not None | ||
|
|
||
| def _ensure_workspace( | ||
| self, | ||
| num_tokens: int, | ||
| hidden_dim: int, | ||
| dtype: torch.dtype, | ||
| ) -> bool: | ||
| if self.workspace is not None: | ||
| if self.hidden_dim == hidden_dim and self.dtype == dtype: | ||
| try: | ||
| if self.workspace.is_buffer_size_sufficient( | ||
| tp_size=self.world_size, | ||
| num_tokens=num_tokens, | ||
| hidden_dim=hidden_dim, | ||
| dtype=dtype, | ||
| ): | ||
| return True | ||
| except Exception as e: | ||
| logger.debug( | ||
| "FlashInfer workspace size check failed; recreating workspace: %s", | ||
| e, | ||
| ) | ||
| self.destroy() | ||
|
|
||
| assert self.max_workspace_size is not None | ||
| element_size = torch.tensor([], dtype=dtype, device="cpu").element_size() | ||
| max_tokens = self.max_workspace_size // (hidden_dim * element_size) | ||
| if max_tokens <= 0 or num_tokens > max_tokens: | ||
| return False | ||
|
|
||
| self.max_num_tokens = max_tokens | ||
| try: | ||
| return self._create_workspace( | ||
| max_token_num=max_tokens, | ||
| hidden_dim=hidden_dim, | ||
| dtype=dtype, | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| "Failed to initialize FlashInfer allreduce workspace: %s. " | ||
| "Disabling FlashInfer allreduce.", | ||
| e, | ||
| ) | ||
| self.disabled = True | ||
| self.workspace = None | ||
| return False | ||
|
|
||
| def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: | ||
| if self.disabled: | ||
| return False | ||
|
|
||
| if not input_tensor.is_cuda or not input_tensor.is_contiguous(): | ||
| return False | ||
|
|
||
| if len(input_tensor.shape) != 2: | ||
| return False | ||
|
|
||
| num_tokens, hidden_dim = input_tensor.shape | ||
| return self._ensure_workspace(num_tokens, hidden_dim, input_tensor.dtype) | ||
|
|
||
| def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor: | ||
| assert _flashinfer_comm is not None | ||
| return _flashinfer_comm.allreduce_fusion( | ||
| input=input_tensor, | ||
| workspace=self.workspace, | ||
| pattern=_flashinfer_comm.AllReduceFusionPattern.kAllReduce, | ||
| ) | ||
|
|
||
| def destroy(self): | ||
| if self.workspace is not None: | ||
| try: | ||
| self.workspace.destroy() | ||
| except Exception as e: | ||
| logger.debug("Failed to destroy FlashInfer workspace: %s", e) | ||
| self.workspace = None | ||
| self.hidden_dim = None | ||
| self.dtype = None |
53 changes: 53 additions & 0 deletions
53
python/sglang/srt/distributed/device_communicators/flashinfer_utils.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import torch.distributed as dist | ||
|
|
||
| from sglang.srt.utils import is_flashinfer_available | ||
|
|
||
| if is_flashinfer_available(): | ||
| try: | ||
| from flashinfer.comm.mnnvl import CommBackend | ||
| except ImportError: | ||
| CommBackend = object # type: ignore[assignment,misc] | ||
| else: | ||
|
|
||
| class CommBackend: | ||
| """Placeholder base class when flashinfer is not available.""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| def create_mnnvl_comm_backend(group: dist.ProcessGroup): | ||
| """Create a mnnvl comm backend backed by torch.distributed process group.""" | ||
| try: | ||
| from flashinfer.comm.mnnvl import TorchDistBackend | ||
|
|
||
| return TorchDistBackend(group=group) | ||
| except Exception: | ||
| pass | ||
|
|
||
| class TorchDistributedCommBackend(CommBackend): | ||
| def __init__(self, group_: dist.ProcessGroup): | ||
| self._group = group_ | ||
|
|
||
| def Get_rank(self) -> int: | ||
| return self._group.rank() | ||
|
|
||
| def Get_size(self) -> int: | ||
| return self._group.size() | ||
|
|
||
| def allgather(self, data: int): | ||
| gathered = [None] * self.Get_size() | ||
| dist.all_gather_object(gathered, data, group=self._group) | ||
| return gathered | ||
|
|
||
| def bcast(self, data, root: int = 0): | ||
| obj_list = [data] | ||
| dist.broadcast_object_list(obj_list, src=root, group=self._group) | ||
| return obj_list[0] | ||
|
|
||
| def Split(self, color: int, key: int): | ||
| return self | ||
|
|
||
| def barrier(self): | ||
| dist.barrier(group=self._group) | ||
|
|
||
| return TorchDistributedCommBackend(group) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.