-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Support loading weights from remote instance #8215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
hnyls2002
merged 2 commits into
sgl-project:main
from
amysaq2023:amy/support-loading-weights-from-remote-instance
Sep 12, 2025
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import logging | ||
| from typing import Generator, List, Optional, Tuple | ||
| from urllib.parse import urlparse | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from sglang.srt.connector import BaseConnector | ||
| from sglang.srt.utils import init_custom_process_group | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class RemoteInstanceConnector(BaseConnector): | ||
|
|
||
| def __init__(self, url: str, device: torch.device = "cpu"): | ||
| assert ( | ||
| device.type == "cuda" | ||
| ), "RemoteInstanceConnector only supports cuda device." | ||
| super().__init__(url) | ||
| self.url = url | ||
| self.device = device | ||
|
|
||
| def build_group( | ||
| self, | ||
| gpu_id: int = -1, | ||
| tp_rank: int = -1, | ||
| instance_ip: str = None, | ||
| group_rank: int = 1, | ||
| world_size: int = 2, | ||
| ): | ||
| assert ( | ||
| self.device.type == "cuda" | ||
| ), "RemoteInstanceConnector only supports cuda device." | ||
| assert ( | ||
| gpu_id != -1 and tp_rank != -1 | ||
| ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. " | ||
|
|
||
| self.device_id = torch.device(self.device.type, gpu_id) | ||
|
|
||
| parsed_url = urlparse(self.url) | ||
| master_address = parsed_url.hostname | ||
| master_port = parsed_url.port | ||
| group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}" | ||
| backend = "nccl" | ||
|
|
||
| logger.info( | ||
| f"init custom process group: master_address={master_address}, master_port={master_port}, " | ||
| f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}" | ||
| ) | ||
|
|
||
| try: | ||
| self._model_update_group = init_custom_process_group( | ||
| backend=backend, | ||
| init_method=f"tcp://{master_address}:{master_port}", | ||
| world_size=world_size, | ||
| rank=group_rank, | ||
| group_name=group_name, | ||
| device_id=self.device_id, | ||
| ) | ||
| dist.barrier(group=self._model_update_group) | ||
| return True, "Succeeded to initialize custom process group." | ||
| except Exception as e: | ||
| message = f"Failed to initialize custom process group: {e}." | ||
| logger.error(message) | ||
| return False, message | ||
|
|
||
| # Implemented as a no-op to make BaseConnector interface consistent. | ||
| def pull_files( | ||
| self, | ||
| allow_pattern: Optional[list[str]] = None, | ||
| ignore_pattern: Optional[list[str]] = None, | ||
| ) -> None: | ||
| return | ||
|
|
||
| # Implemented as a no-op to make BaseConnector interface consistent. | ||
| def weight_iterator( | ||
| self, rank: int = 0 | ||
| ) -> Generator[Tuple[str, torch.Tensor], None, None]: | ||
| return | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the nccl backend be configurable via a new parameter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reasonable suggestion!