Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions python/sglang/srt/checkpoint_engine_worker.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this file can be moved to the checkpoint-engine/example folder

Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Checkpoint-engine integration for SGLang.
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
"""
import gc
import logging
from typing import Callable, Dict, List, Optional, Tuple, TypedDict

import torch
import zmq

logger = logging.getLogger(__name__)


# Functional TypedDict form: metadata describing one named tensor carved out
# of the shared flattened IPC buffer.
FlattenedTensorMetadata = TypedDict(
    "FlattenedTensorMetadata",
    {
        # parameter name as known to the model loader
        "name": str,
        "shape": torch.Size,
        "dtype": torch.dtype,
        # start offset of this tensor within the shared ipc_buffer tensor
        "offset": int,
    },
)


def _rebuild_ipc(
handle: tuple[Callable, tuple], device_id: Optional[int] = None
) -> torch.Tensor:
"""Rebuild a tensor from IPC handle, adapting to current device."""
func, args = handle
list_args = list(args)
if device_id is not None:
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
buffer = func(*list_args)
return buffer


def _extract_weights(
    payload: List[FlattenedTensorMetadata], buffer: torch.Tensor
) -> List[Tuple[str, torch.Tensor]]:
    """Slice named tensors out of the flattened byte buffer using metadata.

    Each metadata entry gives a byte offset, dtype, and shape; the returned
    tensors are zero-copy views into ``buffer``.
    """
    assert buffer is not None
    named_tensors: List[Tuple[str, torch.Tensor]] = []
    for meta in payload:
        shape = meta["shape"]
        # The sender may serialize torch.Size as a plain list/tuple.
        if isinstance(shape, (list, tuple)):
            shape = torch.Size(shape)
        assert isinstance(shape, torch.Size)
        dtype = meta["dtype"]
        start = meta["offset"]
        nbytes = dtype.itemsize * shape.numel()
        # buffer is a flat byte tensor, so offset and size are in bytes.
        view = buffer[start : start + nbytes].view(dtype=dtype).view(shape)
        named_tensors.append((meta["name"], view))
    return named_tensors


def update_weights_from_ipc(
    zmq_ctx: zmq.Context,
    zmq_handle: str,
    device_id: int,
    *,
    run: Callable[[List[Tuple[str, torch.Tensor]]], None],
    post_hook: Optional[Callable[[], None]] = None,
):
    """
    Core IPC weight update loop for SGLang (checkpoint-engine protocol).

    Connects a REP socket to ``zmq_handle`` and serves messages until the
    sender signals completion. Three message kinds are handled:

    * ``tuple`` -- a CUDA IPC handle; rebuild the shared byte staging buffer.
    * ``list``  -- ``FlattenedTensorMetadata`` entries; slice the weights out
      of the buffer and pass them to ``run``.
    * ``None``  -- update finished; invoke ``post_hook`` (if any) and exit.

    Args:
        zmq_ctx: ZMQ context for communication
        zmq_handle: ZMQ socket path for this device
        device_id: Current device ID
        run: Function to apply weights to the model (model.load_weights)
        post_hook: Optional post-processing function
    """
    socket = zmq_ctx.socket(zmq.REP)
    socket.connect(zmq_handle)
    buffer: Optional[torch.Tensor] = None
    # Lazy %-style args avoid formatting when the log level is disabled.
    logger.info(
        "Starting IPC weight update on device %s, socket: %s", device_id, zmq_handle
    )
    try:
        while True:
            payload: tuple[Callable, tuple] | List[FlattenedTensorMetadata] | None = (
                socket.recv_pyobj()
            )
            if payload is None:
                # Sentinel: the update is done.
                logger.info("Weight update complete on device %s", device_id)
                if post_hook is not None:
                    post_hook()
                torch.cuda.synchronize()
                socket.send(b"")
                break
            if isinstance(payload, tuple):
                # An IPC handle we can use to rebuild the shared GPU buffer.
                logger.debug("Received IPC handle on device %s", device_id)
                buffer = _rebuild_ipc(payload, device_id)
                assert buffer.dtype == torch.uint8
                socket.send(b"")
                continue
            assert isinstance(payload, list)
            # Weight metadata list - extract views and load them into the model.
            logger.debug(
                "Received %d weight tensors on device %s", len(payload), device_id
            )
            weights = _extract_weights(payload, buffer)
            run(weights)
            # Ensure the load finished before acknowledging, so the sender
            # can safely reuse/free the shared buffer.
            torch.cuda.synchronize()
            socket.send(b"")
    except Exception:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in IPC weight update on device %s", device_id)
        raise
    finally:
        socket.close()
        # Drop our reference to the shared buffer and reclaim GPU memory.
        del buffer
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Cleaned up IPC weight update on device %s", device_id)


class SGLangCheckpointEngineWorkerExtension:
    """
    Worker extension for SGLang to support checkpoint-engine IPC weight updates.
    This class provides the interface needed for checkpoint-engine integration.

    The ``get_*`` accessors are abstract-by-convention: the SGLang worker
    integration overrides them to supply device identity and the model's
    weight-loading callable.
    """

    def __init__(self):
        # Lazily-created ZMQ context, shared across update calls.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Get the UUID of current device (overridden by SGLang integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Get the device ID (overridden by SGLang integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Get the model weight loader function (overridden by SGLang integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Get the post-processing hook after weight loading; None by default."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """
        Update weights from IPC communication.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path

        Raises:
            ValueError: if this worker's device UUID has no entry in zmq_handles.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        device_uuid = self.get_device_uuid()
        device_id = self.get_device_id()
        if device_uuid not in zmq_handles:
            raise ValueError(
                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            )
        update_weights_from_ipc(
            self._zmq_ctx,
            zmq_handles[device_uuid],
            device_id=device_id,
            run=self.get_model_loader(),
            post_hook=self.get_post_hook(),
        )
16 changes: 16 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
Expand Down Expand Up @@ -653,6 +654,21 @@ async def async_score(
request=None,
)

def update_weights_from_ipc(
    self,
    zmq_handles: Dict[str, str],
    flush_cache: bool = True,
):
    """Update weights from IPC for checkpoint-engine integration.

    Args:
        zmq_handles: Mapping of device UUID -> ZMQ socket path.
        flush_cache: Whether to flush the cache after the weight update.
    """
    req = UpdateWeightsFromIPCReqInput(
        zmq_handles=zmq_handles,
        flush_cache=flush_cache,
    )
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(
        self.tokenizer_manager.update_weights_from_ipc(req, None)
    )


def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
Expand Down
66 changes: 62 additions & 4 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
Expand Down Expand Up @@ -778,6 +779,19 @@ async def update_weights_from_distributed(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/update_weights_from_ipc")
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
    """Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
    success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
        obj, request
    )
    content = {"success": success, "message": message}
    # Report failures as HTTP 400 so clients can distinguish them.
    if not success:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
    return ORJSONResponse(content)


@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
"""Update the weight version. This operation requires no active requests."""
Expand Down Expand Up @@ -1174,10 +1188,54 @@ def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
_global_state.tokenizer_manager.server_args.weight_version = weight_version


def _create_error_response(e):
    """Wrap an exception as a JSON error body with HTTP 400 status."""
    payload = {"error": {"message": str(e)}}
    return ORJSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/collective_rpc")
async def collective_rpc(request: Request):
    """Collective RPC endpoint for compatibility with checkpoint-engine (similar to vLLM).

    Expects a JSON body of the form ``{"method": ..., "args": [...]}``. Only
    ``update_weights_from_ipc`` is implemented; its first positional arg must
    be a dict mapping device UUID to ZMQ socket path.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError as e:
        # Chain the cause so the original parse error is preserved in logs.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST, detail=f"JSON decode error: {e}"
        ) from e
    method = body.get("method")
    if method is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail="Missing 'method' in request body",
        )
    # Handle the update_weights_from_ipc method specifically.
    if method == "update_weights_from_ipc":
        args = body.get("args", [])
        if not args or not isinstance(args[0], dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST,
                detail="Invalid args for update_weights_from_ipc",
            )
        zmq_handles = args[0]
        success, message = (
            await _global_state.tokenizer_manager.update_weights_from_ipc(
                UpdateWeightsFromIPCReqInput(zmq_handles=zmq_handles), request
            )
        )
        if success:
            # Mirror vLLM's collective_rpc response shape: a "results" list.
            return ORJSONResponse(
                content={"results": [{"success": True, "message": message}]}
            )
        return ORJSONResponse(
            content={"results": [{"success": False, "message": message}]},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        )
    raise HTTPException(
        status_code=HTTPStatus.NOT_IMPLEMENTED,
        detail=f"Method '{method}' not implemented in SGLang collective_rpc",
    )


async def http_health(request: Request):
    """Check the health of the http server."""
    # Liveness probe only: an empty 200 response, no readiness checks.
    return Response(status_code=200)


def launch_server(
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,20 @@ class UpdateWeightsFromTensorReqOutput:
message: str


@dataclass
class UpdateWeightsFromIPCReqInput:
    """Request to update model weights via IPC (checkpoint-engine integration)."""

    # ZMQ socket paths for each device UUID
    zmq_handles: Dict[str, str]
    # Whether to flush cache after weight update
    flush_cache: bool = True


@dataclass
class UpdateWeightsFromIPCReqOutput:
    """Result of an IPC weight-update request."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status or error detail for the client
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
# The master address
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
Expand Down Expand Up @@ -579,6 +580,7 @@ def __init__(
self.update_weights_from_distributed,
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/scheduler_update_weights_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
Expand Down Expand Up @@ -68,6 +70,18 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)

def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
    """Update the online model parameter from IPC for checkpoint-engine integration."""
    success, message = self.tp_worker.update_weights_from_ipc(recv_req)
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        # Weights changed, so any cached state derived from them is stale.
        flush_ok = self.flush_cache()
        assert flush_ok, "Cache flush failed after updating weights"
    # Keep all TP ranks in lockstep before replying.
    torch.distributed.barrier(group=self.tp_cpu_group)
    return UpdateWeightsFromIPCReqOutput(success, message)

def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Look up a parameter by name via the TP worker and wrap it in a response."""
    return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
Expand Down
Loading
Loading