diff --git a/python/sglang/srt/checkpoint_engine_worker.py b/python/sglang/srt/checkpoint_engine_worker.py new file mode 100644 index 000000000000..3dd4922b1115 --- /dev/null +++ b/python/sglang/srt/checkpoint_engine_worker.py @@ -0,0 +1,184 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Checkpoint-engine integration for SGLang. +This module provides weight update functionality via IPC for checkpoint-engine compatibility. 
+""" +import gc +import logging +from typing import Callable, Dict, List, Optional, Tuple, TypedDict + +import torch +import zmq + +logger = logging.getLogger(__name__) + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # specify the start offset of this tensor in shared ipc_buffer tensor + offset: int + + +def _rebuild_ipc( + handle: tuple[Callable, tuple], device_id: Optional[int] = None +) -> torch.Tensor: + """Rebuild a tensor from IPC handle, adapting to current device.""" + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +def _extract_weights( + payload: List[FlattenedTensorMetadata], buffer: torch.Tensor +) -> List[Tuple[str, torch.Tensor]]: + """Extract named weights from flattened buffer based on metadata.""" + assert buffer is not None + weights: List[Tuple[str, torch.Tensor]] = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + return weights + + +def update_weights_from_ipc( + zmq_ctx: zmq.Context, + zmq_handle: str, + device_id: int, + *, + run: Callable[[List[Tuple[str, torch.Tensor]]], None], + post_hook: Callable[[], None] = None, +): + """ + Core IPC weight update logic for SGLang. 
+    Args:
+        zmq_ctx: ZMQ context for communication
+        zmq_handle: ZMQ socket path for this device
+        device_id: Current device ID
+        run: Function to apply weights to the model (model.load_weights)
+        post_hook: Optional post-processing function
+    """
+    socket = zmq_ctx.socket(zmq.REP)
+    socket.connect(zmq_handle)
+    buffer: Optional[torch.Tensor] = None
+    logger.info(
+        f"Starting IPC weight update on device {device_id}, socket: {zmq_handle}"
+    )
+    try:
+        while True:
+            payload: tuple[Callable, tuple] | List[FlattenedTensorMetadata] | None = (
+                socket.recv_pyobj()
+            )
+            if payload is None:
+                # None payload is the end-of-update sentinel from the sender
+                logger.info(f"Weight update complete on device {device_id}")
+                if post_hook is not None:
+                    post_hook()
+                torch.cuda.synchronize()
+                socket.send(b"")
+                break
+            if isinstance(payload, tuple):
+                # tuple payload is a CUDA IPC handle used to map the shared staging buffer
+                logger.debug(f"Received IPC handle on device {device_id}")
+                buffer = _rebuild_ipc(payload, device_id)
+                assert buffer.dtype == torch.uint8
+                socket.send(b"")
+                continue
+            assert isinstance(payload, list)
+            # list payload is flattened-tensor metadata: slice views out of the buffer and load
+            logger.debug(
+                f"Received {len(payload)} weight tensors on device {device_id}"
+            )
+            weights = _extract_weights(payload, buffer)
+            run(weights)
+            torch.cuda.synchronize()
+            socket.send(b"")
+    except Exception as e:
+        logger.error(f"Error in IPC weight update on device {device_id}: {e}")
+        raise
+    finally:
+        socket.close()
+        del buffer
+        gc.collect()
+        torch.cuda.empty_cache()
+        logger.info(f"Cleaned up IPC weight update on device {device_id}")
+
+
+class SGLangCheckpointEngineWorkerExtension:
+    """
+    Worker extension for SGLang to support checkpoint-engine IPC weight updates.
+    This class provides the interface needed for checkpoint-engine integration.
+ """ + + def __init__(self): + self._zmq_ctx: Optional[zmq.Context] = None + + def get_device_uuid(self) -> str: + """Get the UUID of current device.""" + # We need to implement this to get the device UUID + # This will be overridden when integrated into SGLang's worker + raise NotImplementedError( + "This method should be overridden by SGLang integration" + ) + + def get_device_id(self) -> int: + """Get the device ID.""" + raise NotImplementedError( + "This method should be overridden by SGLang integration" + ) + + def get_model_loader(self) -> Callable: + """Get the model weight loader function.""" + raise NotImplementedError( + "This method should be overridden by SGLang integration" + ) + + def get_post_hook(self) -> Optional[Callable]: + """Get the post-processing hook after weight loading.""" + return None + + def update_weights_from_ipc(self, zmq_handles: Dict[str, str]): + """ + Update weights from IPC communication. + Args: + zmq_handles: Dict mapping device UUID to ZMQ socket path + """ + if self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + device_uuid = self.get_device_uuid() + device_id = self.get_device_id() + if device_uuid not in zmq_handles: + raise ValueError( + f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}" + ) + update_weights_from_ipc( + self._zmq_ctx, + zmq_handles[device_uuid], + device_id=device_id, + run=self.get_model_loader(), + post_hook=self.get_post_hook(), + ) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f86e9a751bf1..2c6c69837ab9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -60,6 +60,7 @@ UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter @@ -653,6 +654,21 @@ async def async_score( request=None, ) + 
def update_weights_from_ipc(
+        self,
+        zmq_handles: Dict[str, str],
+        flush_cache: bool = True,
+    ):
+        """Blocking wrapper that drives TokenizerManager.update_weights_from_ipc to completion."""
+        obj = UpdateWeightsFromIPCReqInput(
+            zmq_handles=zmq_handles,
+            flush_cache=flush_cache,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.update_weights_from_ipc(obj, None)
+        )
+
 
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 3be69159c527..40a4e4a0436d 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -88,6 +88,7 @@
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
     UpdateWeightVersionReqInput,
     VertexGenerateReqInput,
@@ -778,6 +779,19 @@
     return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/update_weights_from_ipc")
+async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
+    """Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
+    success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
 @app.post("/update_weight_version")
 async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
     """Update the weight version.
This operation requires no active requests."""
@@ -1174,10 +1188,59 @@
 def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
     _global_state.tokenizer_manager.server_args.weight_version = weight_version
 
 
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
     )
 
 
+@app.post("/collective_rpc")
+async def collective_rpc(request: Request):
+    """Collective RPC endpoint for compatibility with checkpoint-engine (similar to vLLM)."""
+    try:
+        body = await request.json()
+    except json.JSONDecodeError as e:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST, detail=f"JSON decode error: {e}"
+        )
+    method = body.get("method")
+    if method is None:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail="Missing 'method' in request body",
+        )
+    # Handle the update_weights_from_ipc method specifically
+    if method == "update_weights_from_ipc":
+        args = body.get("args", [])
+        if not args or not isinstance(args[0], dict):
+            raise HTTPException(
+                status_code=HTTPStatus.BAD_REQUEST,
+                detail="Invalid args for update_weights_from_ipc",
+            )
+        zmq_handles = args[0]
+        success, message = (
+            await _global_state.tokenizer_manager.update_weights_from_ipc(
+                UpdateWeightsFromIPCReqInput(zmq_handles=zmq_handles), request
+            )
+        )
+        if success:
+            return ORJSONResponse(
+                content={"results": [{"success": True, "message": message}]}
+            )
+        else:
+            return ORJSONResponse(
+                content={"results": [{"success": False, "message": message}]},
+                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            )
+    else:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_IMPLEMENTED,
+            detail=f"Method '{method}' not implemented in SGLang collective_rpc",
+        )
+
+
+# NOTE(review): no @app.get/@app.post decorator here — this handler is unreachable as written; confirm intent.
+async def http_health(request: Request):
+    """Check the health of the http server."""
+    return Response(status_code=200)
 
 
 def launch_server(
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index
c479f6d549ef..e0eb757bbe41 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1034,6 +1034,20 @@
 class UpdateWeightsFromTensorReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightsFromIPCReqInput:
+    # Maps device UUID -> ZMQ socket path used for the IPC weight transfer
+    zmq_handles: Dict[str, str]
+    # Whether to flush the radix/KV cache after the weight update completes
+    flush_cache: bool = True
+
+
+@dataclass
+class UpdateWeightsFromIPCReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsSendGroupForRemoteInstanceReqInput:
     # The master address
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a246534cb41e..eaec1f5cbb32 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -109,6 +109,7 @@
     UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
@@ -579,6 +580,7 @@
                 self.update_weights_from_distributed,
             ),
             (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+            (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
             (GetWeightsByNameReqInput, self.get_weights_by_name),
             (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
             (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index fdae2142cd3d..4bd7eb55b07b 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -17,6 +17,8 @@
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromIPCReqInput,
+    UpdateWeightsFromIPCReqOutput,
     UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
 )
@@ -68,6 +70,18 @@
         torch.distributed.barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
+    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+        """Update the online model parameter from IPC for checkpoint-engine integration; mirrors update_weights_from_tensor (flush cache on success, barrier across TP)."""
+        success, message = self.tp_worker.update_weights_from_ipc(recv_req)
+        if success:
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        torch.distributed.barrier(group=self.tp_cpu_group)
+        return UpdateWeightsFromIPCReqOutput(success, message)
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 8970d5ad50b9..5c86fe889740 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -59,6 +59,8 @@
     UnloadLoRAAdapterReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromIPCReqInput,
+    UpdateWeightsFromIPCReqOutput,
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
@@ -161,6 +163,9 @@
         self.update_weights_from_tensor_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_weights_from_ipc_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -223,6 +228,10 @@ def _get_communicator_dispatcher(self: TokenizerManager):
UpdateWeightsFromTensorReqOutput,
                 self.update_weights_from_tensor_communicator.handle_recv,
             ),
+            (
+                UpdateWeightsFromIPCReqOutput,
+                self.update_weights_from_ipc_communicator.handle_recv,
+            ),
             (
                 GetWeightsByNameReqOutput,
                 self.get_weights_by_name_communicator.handle_recv,
@@ -411,6 +420,28 @@
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
 
+    async def update_weights_from_ipc(
+        self,
+        obj: UpdateWeightsFromIPCReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        """Update weights via IPC for checkpoint-engine integration; returns (success, message) instead of raising."""
+        self.auto_create_handle_loop()
+        try:
+            # For now, only a single data-parallel instance (or dp attention) is supported
+            assert (
+                self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+            ), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
+            logger.info("Starting IPC weight update")
+            # Holding the writer lock means weight sync cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                result = (await self.update_weights_from_ipc_communicator(obj))[0]
+            return result.success, result.message
+        except Exception as e:
+            error_msg = f"IPC weight update failed: {str(e)}"
+            logger.error(error_msg)
+            return False, error_msg
+
     async def load_lora_adapter(
         self: TokenizerManager,
         obj: LoadLoRAAdapterReqInput,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 7bed87592719..4db39b46d7f1 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -37,6 +37,7 @@
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
@@ -362,5 +363,14 @@
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
 
+    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+        """Update weights from IPC for checkpoint-engine integration; never raises — returns (success, message)."""
+        try:
+            success, message = self.model_runner.update_weights_from_ipc(recv_req)
+            return success, message
+        except Exception as e:
+            logger.error(f"IPC weight update failed: {e}")
+            return False, str(e)
+
     def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
         return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index e34399a41de2..d96471dcd896 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -33,6 +33,7 @@
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
 from
sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -311,6 +312,10 @@ def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
 
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)
 
+    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
+        """Update weights from IPC for checkpoint-engine integration (delegates to the wrapped worker)."""
+        return self.worker.update_weights_from_ipc(recv_req)
+
     def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
         return self.worker.can_run_lora_batch(lora_ids)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7835c3fa1755..013f7d142fce 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -2228,6 +2228,83 @@
             )
         )
         ShardedStateLoader.save_model(self.model, path, pattern, max_size)
 
+    def update_weights_from_ipc(self, recv_req):
+        """Update weights from IPC for checkpoint-engine integration."""
+        try:
+            from sglang.srt.checkpoint_engine_worker import (
+                SGLangCheckpointEngineWorkerExtension,
+            )
+
+            # Create a worker extension that integrates with SGLang's model
+            class SGLangWorkerImpl(SGLangCheckpointEngineWorkerExtension):
+                def __init__(self, model_runner):
+                    super().__init__()
+                    self.model_runner = model_runner
+
+                def get_device_uuid(self) -> str:
+                    # Get device UUID for current device
+                    import subprocess
+
+                    device_id = torch.cuda.current_device()
+                    # NOTE(review): `nvidia-smi -L` lists GPUs by physical index;
+                    # with CUDA_VISIBLE_DEVICES set, torch's device index may not
+                    # match that index — confirm, or prefer NVML /
+                    # torch.cuda.get_device_properties(device_id).uuid instead.
+                    result = subprocess.run(
+                        ["nvidia-smi", "-L"], capture_output=True, text=True
+                    )
+                    if result.returncode != 0:
+                        raise RuntimeError(f"Failed to get GPU UUID: {result.stderr}")
+                    lines = result.stdout.strip().split("\n")
+                    for line in lines:
+                        if f"GPU {device_id}:" in line:
+                            uuid = line.split("UUID: ")[1].strip(")")
+                            return uuid
+                    raise RuntimeError(f"Could not find UUID for GPU {device_id}")
+
+                def get_device_id(self) -> int:
+                    return
torch.cuda.current_device()
+
+                def get_model_loader(self):
+                    return self.model_runner.model.load_weights
+
+                def get_post_hook(self):
+                    def post_hook():
+                        # Perform post-processing after weight loading similar to DefaultModelLoader
+                        try:
+                            from sglang.srt.model_loader.loader import (
+                                device_loading_context,
+                            )
+
+                            # Process quantization methods after loading weights
+                            for _, module in self.model_runner.model.named_modules():
+                                quant_method = getattr(module, "quant_method", None)
+                                if quant_method is not None:
+                                    # Move parameters to device if needed for quantization processing
+                                    target_device = torch.device(
+                                        "cuda", torch.cuda.current_device()
+                                    )
+                                    with device_loading_context(module, target_device):
+                                        quant_method.process_weights_after_loading(
+                                            module
+                                        )
+                            # Call model-specific post-loading hook if available
+                            if hasattr(self.model_runner.model, "post_load_weights"):
+                                self.model_runner.model.post_load_weights()
+                        except Exception as e:
+                            logger.warning(f"Post-hook processing failed: {e}")
+
+                    return post_hook
+            # Create the worker instance and perform the IPC weight update
+            worker = SGLangWorkerImpl(self)
+            worker.update_weights_from_ipc(recv_req.zmq_handles)
+            return True, "IPC weight update completed successfully"
+        except ImportError as e:
+            return False, f"IPC weight update failed: checkpoint_engine_worker not importable: {e}"
+        except Exception as e:
+            logger.error(f"IPC weight update failed: {e}")
+            return False, str(e)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())