Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/sglang/srt/configs/device_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

class DeviceConfig:
    """Validated device configuration for a model runtime.

    Holds the torch.device the model runs on plus the GPU ordinal assigned
    to this process.

    Raises:
        RuntimeError: if ``device`` is not one of the supported device types.
    """

    # Torch device object built from the validated device type string.
    device: Optional[torch.device]
    # GPU ordinal for this process; -1 means "no specific GPU assigned".
    gpu_id: Optional[int]

    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        # Validate against the closed set of supported backends before
        # constructing the torch.device, so a typo fails loudly here.
        if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
        self.gpu_id = gpu_id
1 change: 1 addition & 0 deletions python/sglang/srt/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
LAYERED = "layered"
JAX = "jax"
REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance"


@dataclass
Expand Down
19 changes: 19 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,28 @@ def __init__(
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None:
# Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.model_impl = model_impl
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)

self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
Expand Down Expand Up @@ -320,6 +336,9 @@ def from_server_args(
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)

Expand Down
9 changes: 8 additions & 1 deletion python/sglang/srt/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.remote_instance import RemoteInstanceConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

Expand All @@ -18,14 +19,17 @@
class ConnectorType(str, enum.Enum):
FS = "filesystem"
KV = "KV"
INSTANCE = "instance"


def create_remote_connector(url, **kwargs) -> BaseConnector:
def create_remote_connector(url, device, **kwargs) -> BaseConnector:
    """Build the connector implementation matching the URL's scheme.

    ``device`` is only consumed by the remote-instance connector; the other
    connector types ignore it.

    Raises:
        ValueError: if the URL's connector type is not recognized.
    """
    scheme = parse_connector_type(url)
    factories = {
        "redis": lambda: RedisConnector(url),
        "s3": lambda: S3Connector(url),
        "instance": lambda: RemoteInstanceConnector(url, device),
    }
    try:
        return factories[scheme]()
    except KeyError:
        raise ValueError(f"Invalid connector type: {url}") from None

Expand All @@ -35,6 +39,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType:
return ConnectorType.KV
if isinstance(client, BaseFileConnector):
return ConnectorType.FS
if isinstance(client, RemoteInstanceConnector):
return ConnectorType.INSTANCE

raise ValueError(f"Invalid connector type: {client}")

Expand All @@ -44,6 +50,7 @@ def get_connector_type(client: BaseConnector) -> ConnectorType:
"BaseFileConnector",
"BaseKVConnector",
"RedisConnector",
"RemoteInstanceConnector",
"S3Connector",
"ConnectorType",
"create_remote_connector",
Expand Down
82 changes: 82 additions & 0 deletions python/sglang/srt/connector/remote_instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist

from sglang.srt.connector import BaseConnector
from sglang.srt.utils import init_custom_process_group

logger = logging.getLogger(__name__)


class RemoteInstanceConnector(BaseConnector):
    """Connector that links this instance to a remote (seed) instance so
    model weights can be transferred over a torch.distributed group.

    Unlike file/KV connectors, nothing is fetched through this object
    directly: ``pull_files`` and ``weight_iterator`` are deliberate no-ops,
    and the actual weight transfer happens over the process group created by
    :meth:`build_group`.
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        # Normalize so callers may pass either a torch.device or a device
        # string; the original code would raise AttributeError on a str
        # before reaching the intended assertion message.
        device = torch.device(device)
        assert (
            device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        super().__init__(url)
        self.url = url
        self.device = device

    def build_group(
        self,
        gpu_id: int = -1,
        tp_rank: int = -1,
        instance_ip: Optional[str] = None,
        group_rank: int = 1,
        world_size: int = 2,
        backend: str = "nccl",
    ):
        """Initialize the custom process group used to receive weights.

        Args:
            gpu_id: local GPU ordinal to bind the group to (required).
            tp_rank: tensor-parallel rank, used to name the group (required).
            instance_ip: IP of the peer instance, used to name the group.
            group_rank: this process's rank inside the new group.
            world_size: total number of ranks in the new group.
            backend: torch.distributed backend for the group ("nccl" by
                default; made configurable per review feedback).

        Returns:
            Tuple[bool, str]: (success flag, human-readable message).
        """
        assert (
            self.device.type == "cuda"
        ), "RemoteInstanceConnector only supports cuda device."
        assert (
            gpu_id != -1 and tp_rank != -1
        ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. "

        self.device_id = torch.device(self.device.type, gpu_id)

        # The connector URL carries the rendezvous endpoint (host:port).
        parsed_url = urlparse(self.url)
        master_address = parsed_url.hostname
        master_port = parsed_url.port
        group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=self.device_id,
            )
            # Barrier ensures every rank has joined before we report success.
            dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    # Implemented as a no-op to make BaseConnector interface consistent.
    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        return

    # Implemented as a no-op to make BaseConnector interface consistent.
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # `yield from ()` (instead of a bare `return`) makes this a real
        # generator, so `for name, tensor in conn.weight_iterator(): ...`
        # iterates zero times rather than crashing on a None return.
        yield from ()
34 changes: 34 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,15 @@
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
SendWeightsToRemoteInstanceReqInput,
SeparateReasoningReqInput,
SetInternalStateReq,
SlowDownReqInput,
Expand Down Expand Up @@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)


@app.post("/init_weights_send_group_for_remote_instance")
async def init_weights_send_group_for_remote_instance(
obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request
):
success, message = (
await _global_state.tokenizer_manager.init_weights_send_group_for_remote_instance(
obj, request
)
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/send_weights_to_remote_instance")
async def send_weights_to_remote_instance(
obj: SendWeightsToRemoteInstanceReqInput, request: Request
):
success, message = (
await _global_state.tokenizer_manager.send_weights_to_remote_instance(
obj, request
)
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/init_weights_update_group")
async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput, request: Request
Expand Down
38 changes: 38 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,44 @@ class UpdateWeightsFromTensorReqOutput:
message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
    """Request payload for /init_weights_send_group_for_remote_instance:
    parameters for creating the weights-send communication group between a
    seed instance and a remote instance."""

    # Master (rendezvous) address for the communication group.
    master_address: str
    # Ports for each rank's communication group, encoded as a string
    # (presumably a delimited list — TODO confirm format against callers).
    ports: str
    # This instance's rank within the communication group.
    group_rank: int
    # Total number of ranks in the communication group.
    world_size: int
    # Name identifying the communication group.
    group_name: str = "weight_send_group"
    # torch.distributed backend used for the group.
    backend: str = "nccl"


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput:
    """Result of initializing the weights-send communication group."""

    # Whether group initialization succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput:
    """Request payload for /send_weights_to_remote_instance: identifies the
    already-initialized communication group to send weights over."""

    # Master (rendezvous) address of the communication group.
    master_address: str
    # Ports for each rank's communication group, encoded as a string
    # (format determined by callers — TODO confirm).
    ports: str
    # Name identifying the communication group.
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput:
    """Result of sending weights to the remote instance."""

    # Whether the weight transfer succeeded.
    success: bool
    # Human-readable status or error detail.
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
# The master address
Expand Down
28 changes: 28 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
HealthCheckOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsSendGroupForRemoteInstanceReqOutput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
Expand All @@ -93,6 +95,8 @@
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
SendWeightsToRemoteInstanceReqInput,
SendWeightsToRemoteInstanceReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
Expand Down Expand Up @@ -538,6 +542,14 @@ def __init__(
(CloseSessionReqInput, self.close_session),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(
InitWeightsSendGroupForRemoteInstanceReqInput,
self.init_weights_send_group_for_remote_instance,
),
(
SendWeightsToRemoteInstanceReqInput,
self.send_weights_to_remote_instance,
),
(
UpdateWeightsFromDistributedReqInput,
self.update_weights_from_distributed,
Expand Down Expand Up @@ -2424,6 +2436,22 @@ def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
self.send_to_detokenizer.send_pyobj(recv_req)
return recv_req

def init_weights_send_group_for_remote_instance(
    self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
    """Init the seed and client instance communication group."""
    # Delegate to the TP worker, which returns a (success, message) pair.
    result = self.tp_worker.init_weights_send_group_for_remote_instance(recv_req)
    success, message = result
    return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

def send_weights_to_remote_instance(
    self, recv_req: SendWeightsToRemoteInstanceReqInput
):
    """Send the seed instance weights to the destination instance."""
    # Delegate to the TP worker, which returns a (success, message) pair.
    success, message = self.tp_worker.send_weights_to_remote_instance(
        recv_req
    )
    return SendWeightsToRemoteInstanceReqOutput(success, message)

def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
Expand Down
Loading
Loading