diff --git a/docs/advanced_features/checkpoint_engine.md b/docs/advanced_features/checkpoint_engine.md new file mode 100644 index 000000000000..df563f57962b --- /dev/null +++ b/docs/advanced_features/checkpoint_engine.md @@ -0,0 +1,250 @@ +# Checkpoint Engine Design Documentation + +## Overview + +ckpt-engine is a lightweight library specifically designed to accelerate weight synchronization in large-scale distributed training. It operates on a parameter server architecture (ps.py, worker.py). It supports two deployment methods: co-location and disaggregation. Its core mechanism is to establish an asynchronous, pipelined data transfer process based on the mooncake transfer engine. This allows the SGLang inference engine to offload the weight update task to background workers, effectively hiding the I/O and communication latency. + +Two key scenarios can benefit from this ckpt-engine: + +- Reinforcement Learning (RL) Workloads – including RLHF, DPO, and continual pre-training – where model weights are updated frequently. Current methods for synchronizing these updates into the inference engine introduce significant latency, creating a bottleneck. This underutilizes GPUs during weight updates and slows the overall training-inference loop. +- Bulk Deployment – The boot time is a performance bottleneck when launching multiple SGLang instances. + +## Use Cases + +Prerequisites: installing checkpoint-engine +```bash +pip install 'checkpoint-engine[p2p]' # install checkpoint engine +``` + +Running Methods: + +- sglang +```bash +python3 -m sglang.launch_server --model /opt/models/Qwen/Qwen3-8b --tp 8 --load-format ckpt_engine --port 30001 +``` + +- checkpoint engine +```bash +torchrun --nproc-per-node 8 ckptengine_update.py --update-method all --checkpoint-path /opt/models/Qwen/Qwen3-8b/ +``` + +## Architecture + +### Core Components + +The checkpoint engine consists of several key components: + +1. 
**CkptEngineConnector** - The main connector that handles checkpoint engine communication +2. **CkptEngineModelLoader** - Specialized model loader for checkpoint engine format +3. **CkptEngineUpdate** - Standalone script for updating weights via checkpoint engine +4. **IPC-based Weight Transfer** - Efficient inter-process communication for weight updates + +### System Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ SGLang Server │ +├─────────────────────────────────────────────────────────────────┤ +│ HTTP API │ +│ ├── /update_weights_from_ckpt_engine │ +│ └── /update_weights_from_distributed │ +├─────────────────────────────────────────────────────────────────┤ +│ Scheduler │ +│ ├── SchedulerUpdateWeightsMixin │ +│ └── Request Dispatcher │ +├─────────────────────────────────────────────────────────────────┤ +│ Model Runner │ +│ ├── ModelRunner.update_weights_from_ckpt_engine() │ +│ └── ModelRunner.update_weights_from_distributed() │ +├─────────────────────────────────────────────────────────────────┤ +│ Connector │ +│ ├── CkptEngineConnector │ +│ └── BaseConnector Interface │ +├─────────────────────────────────────────────────────────────────┤ +│ Model Loader. │ +│ ├── CkptEngineModelLoader │ +│ └── get_model_loader() │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Key Features + +### 1. Checkpoint Engine Format Support + +The system supports a new load format called `ckpt_engine` that enables: + +- **Efficient Weight Loading**: Load models from checkpoint engine format +- **Distributed Loading**: Support for tensor parallel weight distribution +- **Memory Optimization**: Optimized memory usage during weight loading + +### 2. 
In-Place Weight Updates + +The checkpoint engine enables updating model weights without restarting the server: + +- **Hot Swapping**: Update weights while the server is running +- **Rollback Support**: Automatic rollback on update failures +- **Memory Safety**: Safe memory management during updates + +### 3. Inter-Process Communication (IPC) + +Efficient IPC-based weight transfer: + +- **Shared Memory**: Utilizes shared memory for efficient tensor transfer +- **Metadata Management**: Handles tensor metadata for proper reconstruction +- **Error Handling**: Robust error handling and cleanup + +### 4. Distributed Weight Synchronization + +Supports distributed weight updates across tensor parallel workers: + +- **Broadcast Updates**: Broadcast weight updates to all workers +- **P2P Updates**: Point-to-point weight updates for specific workers +- **Synchronization**: Proper synchronization barriers for consistency + +## Implementation Details + +### CkptEngineConnector + +The `CkptEngineConnector` class implements the core checkpoint engine functionality: + +```python +class CkptEngineConnector(BaseConnector): + def __init__(self, url: str, device: torch.device = "cpu"): + super().__init__(url) + self.url = url + self.device = device + self.zmq_handle = None + self.zmq_ctx = None + self.device_uuid = None + self.socket = None + self.buffer: Optional[torch.Tensor] = None + self.local_rank = None + self.final_state_dict = OrderedDict() + self.pending_weights: Dict[str, torch.Tensor] = {} +``` + +Key methods: +- `get_zmq_handle()`: Establishes ZMQ connection for weight transfer +- `update_weights_from_ipc()`: Handles IPC-based weight updates +- `_extract_weights()`: Extracts individual tensors from shared buffer + +### CkptEngineModelLoader + +The `CkptEngineModelLoader` handles loading models from checkpoint engine format: + +```python +class CkptEngineModelLoader(BaseModelLoader): + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig) -> 
nn.Module: + """Load model using checkpoint engine format.""" + logger.info("Loading weights from checkpoint engine format ...") + + model_weights = f"ckptengine://" + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config) + + with create_remote_connector(model_weights, device_config.device) as client: + connector_type = get_connector_type(client) + if connector_type == ConnectorType.CKPTENGINE: + self.load_model_from_ckpt_engine( + model, client, model_config, device_config + ) + else: + raise ValueError(f"Unsupported connector type {connector_type}") + + return model.eval() +``` + +### Weight Update Process + +The weight update process involves several steps: + +1. **Initialization**: Set up ZMQ connections and shared memory +2. **Metadata Transfer**: Send tensor metadata (shapes, dtypes, offsets) +3. **Buffer Transfer**: Transfer shared memory buffer containing weights +4. **Weight Loading**: Load weights into model using standard load_weights method +5. 
**Cleanup**: Clean up resources and synchronize + +### IPC Protocol + +The IPC protocol uses ZMQ for communication: + +- **Port Assignment**: Dynamic port assignment (base port 33001 + rank) +- **Message Types**: Support for tensor metadata, buffer handles, and termination signals +- **Error Handling**: Robust error handling with proper cleanup + +## API Integration + +### HTTP Endpoints + +The system exposes HTTP endpoints for weight updates: + +```python +@app.post("/update_weights_from_ckpt_engine") +async def update_weights_from_ckpt_engine( + obj: UpdateWeightsFromCkptEngineReqInput, request: Request +): + """Update the weights from disk inplace without re-launching the server.""" +``` + +### Request Structure + +Weight update requests include: +- `model_path`: Path to the new model weights +- `load_format`: Format of the weights (e.g., "ckpt_engine") + +## Configuration + +### Load Format Configuration + +The checkpoint engine format is registered in the load configuration: + +```python +class LoadFormat(str, enum.Enum): + # ... existing formats ... + CKPT_ENGINE = "ckpt_engine" +``` + +### Server Arguments + +The system supports configuration through server arguments: +- `--load-format ckpt_engine`: Use checkpoint engine format for initial loading +- Custom weight loader support for extensibility + + + +## Use Cases + +### 1. Online Model Updates + +Update model weights without server downtime: +```bash +curl -X POST http://localhost:30000/update_weights_from_ckpt_engine \ + -H "Content-Type: application/json" \ + -d '{"model_path": "/path/to/new/checkpoint", "load_format": "ckpt_engine"}' +``` + +### 2. Distributed Training Integration + +Integrate with distributed training systems for seamless model updates. + + +### 3. Model Serving at Scale + +Efficient weight management for large-scale model serving deployments. + +## Future Enhancements + +### Planned Features + +1. **Incremental Updates**: Support for incremental weight updates +2. 
**Compression**: Advanced compression algorithms for weight transfer +3. **Caching**: Intelligent caching for frequently used weights +4. **Monitoring**: Enhanced monitoring and metrics for weight updates + +### Performance Optimizations + +1. **Parallel Transfer**: Parallel weight transfer for large models +2. **Streaming**: Streaming weight updates for very large models +3. **GPU Direct**: GPU-direct memory transfer for improved performance diff --git a/docs/index.rst b/docs/index.rst index 691bc8524d74..5e60f8e04986 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,6 +49,7 @@ The core features include: advanced_features/router.md advanced_features/observability.md advanced_features/attention_backend.md + advanced_features/checkpoint_engine.md advanced_features/hicache.rst .. toctree:: diff --git a/python/sglang/ckptengine_update.py b/python/sglang/ckptengine_update.py new file mode 100644 index 000000000000..87786d3c748c --- /dev/null +++ b/python/sglang/ckptengine_update.py @@ -0,0 +1,193 @@ +import argparse +import json +import os +import pickle +import threading +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Callable, Literal + +import requests +import torch +import torch.distributed as dist +import zmq +from checkpoint_engine.ps import ( + ParameterServer, + _gen_h2d_buckets, + _to_named_tensor, + request_inference_to_update, +) +from loguru import logger +from safetensors import safe_open +from torch.multiprocessing.reductions import reduce_tensor + +CKPTENGINE_PORT = 33001 + + +@contextmanager +def timer(msg: str): + start = time.perf_counter() + yield + end = time.perf_counter() + logger.info(f"{msg} duration: {end - start:.2f} seconds") + + +def request_inference_to_update( + port, socket_paths: dict[str, str], host="localhost", timeout: float = 300.0 +): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(f"tcp://{host}:{port}") + message = 
json.dumps(socket_paths).encode("utf-8") + socket.send(message) + + +def split_checkpoint_files( + checkpoint_path: str, rank: int, world_size: int +) -> list[str]: + checkpoint_files = [ + os.path.join(checkpoint_path, f) + for f in filter( + lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path) + ) + ] + files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size + return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank] + + +def split_tensors( + checkpoint_path: str, rank: int, world_size: int +) -> dict[str, torch.Tensor]: + index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json") + with open(index_fn, "r") as f: + weight_map: dict[str, str] = json.load(f)["weight_map"] + weights_per_rank = (len(weight_map) + world_size - 1) // world_size + fn_tensors: dict[str, list[str]] = defaultdict(list) + weight_keys = list(weight_map.items()) + for name, file in weight_keys[ + rank * weights_per_rank : (rank + 1) * weights_per_rank + ]: + fn_tensors[file].append(name) + named_tensors = {} + for file, names in fn_tensors.items(): + with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f: + for name in names: + named_tensors[name] = f.get_tensor(name) + return named_tensors + + +def req_inference(inference_parallel_size: int, port: int): + rank = int(os.getenv("RANK", None)) + src = rank // inference_parallel_size * inference_parallel_size + + def req_func(socket_paths: list[tuple[str, str]]): + request_inference_to_update( + port, + dict(socket_paths[src : src + inference_parallel_size]), + ) + + return req_func + + +def update_weights( + ps: ParameterServer, + checkpoint_name: str, + checkpoint_files: list[str], + named_tensors: dict[str, torch.Tensor], + req_func: Callable[[list[tuple[str, str]]], None], + inference_parallel_size: int, + save_metas_file: str | None = None, + update_method: Literal["broadcast", "p2p", "all"] = "broadcast", +): + ps.register_checkpoint( + checkpoint_name, 
files=checkpoint_files, named_tensors=named_tensors + ) + ps.init_process_group() + dist.barrier() + with timer("Gather metas"): + ps.gather_metas(checkpoint_name) + if save_metas_file and int(os.getenv("RANK")) == 0: + with open(save_metas_file, "wb") as f: + pickle.dump(ps.get_metas(), f) + + if update_method == "broadcast" or update_method == "all": + with timer("Update weights without setting ranks"): + ps.update(checkpoint_name, req_func) + + if update_method == "p2p" or update_method == "all": + if update_method: + # sleep 2s to wait destroy process group + time.sleep(2) + with timer("Update weights with setting ranks"): + ps.update( + checkpoint_name, req_func, ranks=list(range(inference_parallel_size)) + ) + + +def join( + ps: ParameterServer, + checkpoint_name: str, + save_metas_file: str, + req_func: Callable[[list[tuple[str, str]]], None], + inference_parallel_size: int, +): + assert save_metas_file, "save_metas_file is required" + with open(save_metas_file, "rb") as f: + metas = pickle.load(f) + ps.init_process_group() + dist.barrier() + with timer("Gather metas before join"): + ps.gather_metas(checkpoint_name) + ps.load_metas(metas) + with timer( + f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p" + ): + ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Update weights example") + parser.add_argument("--checkpoint-path", type=str, default=None) + parser.add_argument("--save-metas-file", type=str, default=None) + parser.add_argument("--load-metas-file", type=str, default=None) + parser.add_argument("--sleep-time", type=int, default=0) + parser.add_argument("--inference-parallel-size", type=int, default=8) + parser.add_argument("--checkpoint-name", type=str, default="sglang-ckpt-iter-0") + parser.add_argument("--update-method", type=str, default="broadcast") + parser.add_argument("--ckpt-setup-port", 
type=int, default=CKPTENGINE_PORT) + args = parser.parse_args() + rank = int(os.getenv("RANK")) + world_size = int(os.getenv("WORLD_SIZE")) + req_func = req_inference(args.inference_parallel_size, args.ckpt_setup_port) + ps = ParameterServer(auto_pg=True) + ps._gpu_count = args.inference_parallel_size + if args.load_metas_file: + join( + ps, + args.checkpoint_name, + args.load_metas_file, + req_func, + args.inference_parallel_size, + ) + else: + if os.path.exists( + os.path.join(args.checkpoint_path, "model.safetensors.index.json") + ): + named_tensors = split_tensors(args.checkpoint_path, rank, world_size) + checkpoint_files = [] + else: + checkpoint_files = split_checkpoint_files( + args.checkpoint_path, rank, world_size + ) + named_tensors = {} + update_weights( + ps, + args.checkpoint_name, + checkpoint_files, + named_tensors, + req_func, + args.inference_parallel_size, + args.save_metas_file, + args.update_method, + ) + time.sleep(args.sleep_time) diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 7059fd95a32f..7b80eeec99e5 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -27,6 +27,7 @@ class LoadFormat(str, enum.Enum): REMOTE_INSTANCE = "remote_instance" RDMA = "rdma" LOCAL_CACHED = "local_cached" + CKPT_ENGINE = "ckpt_engine" @dataclass @@ -63,6 +64,7 @@ class LoadConfig: remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None + ckpt_engine_port: int = 33001 def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index c9663a836d14..182986b3edf0 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -8,6 +8,7 
@@ BaseFileConnector, BaseKVConnector, ) +from sglang.srt.connector.ckpt_engine import CkptEngineConnector from sglang.srt.connector.redis import RedisConnector from sglang.srt.connector.remote_instance import RemoteInstanceConnector from sglang.srt.connector.s3 import S3Connector @@ -20,6 +21,7 @@ class ConnectorType(str, enum.Enum): FS = "filesystem" KV = "KV" INSTANCE = "instance" + CKPTENGINE = "ckptengine" def create_remote_connector(url, device, **kwargs) -> BaseConnector: @@ -30,6 +32,9 @@ def create_remote_connector(url, device, **kwargs) -> BaseConnector: return S3Connector(url) elif connector_type == "instance": return RemoteInstanceConnector(url, device) + elif connector_type == "ckptengine": + ckpt_engine_port = kwargs.get("ckpt_engine_port", 33001) + return CkptEngineConnector(url, device, ckpt_engine_port) else: raise ValueError(f"Invalid connector type: {url}") @@ -41,6 +46,8 @@ def get_connector_type(client: BaseConnector) -> ConnectorType: return ConnectorType.FS if isinstance(client, RemoteInstanceConnector): return ConnectorType.INSTANCE + if isinstance(client, CkptEngineConnector): + return ConnectorType.CKPTENGINE raise ValueError(f"Invalid connector type: {client}") @@ -51,6 +58,7 @@ def get_connector_type(client: BaseConnector) -> ConnectorType: "BaseKVConnector", "RedisConnector", "RemoteInstanceConnector", + "CkptEngineConnector", "S3Connector", "ConnectorType", "create_remote_connector", diff --git a/python/sglang/srt/connector/ckpt_engine.py b/python/sglang/srt/connector/ckpt_engine.py new file mode 100644 index 000000000000..293efda37615 --- /dev/null +++ b/python/sglang/srt/connector/ckpt_engine.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +import gc +import json +import logging +import subprocess +from collections import OrderedDict +from typing import Callable, Dict, Generator, List, Optional, Tuple, TypedDict +from urllib.parse import urlparse + +import torch +import torch.distributed as dist +import zmq + +from 
sglang.srt.connector import BaseConnector + +logger = logging.getLogger(__name__) + + +def _get_physical_gpu_id(device_index: int | None = None) -> str: + try: + return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}" + except AssertionError as e: + raise ValueError(f"fail to get physical gpu id {device_index}") from e + + +def _resolve_zmq_handle( + device_uuid: str, + all_device_uuids: List[str], + received_handles: Dict[str, str], +) -> str: + if device_uuid in received_handles: + logger.info(f"Rank for UUID {device_uuid}: Found direct ZMQ handle match.") + return received_handles[device_uuid] + + logger.warning( + f"Rank for UUID {device_uuid}: Direct match failed. Attempting fallback mapping for unshared GPUs." + ) + + device_uuids_set = set(all_device_uuids) + sender_uuids_set = set(received_handles.keys()) + + unmatched_my_uuids = sorted(list(device_uuids_set - sender_uuids_set)) + unmatched_sender_uuids = sorted(list(sender_uuids_set - device_uuids_set)) + + if len(unmatched_my_uuids) != len(unmatched_sender_uuids): + raise RuntimeError( + f"Unmatched GPU count mismatch. My unmatched: {len(unmatched_my_uuids)} " + f"({unmatched_my_uuids}), Sender's unmatched: {len(unmatched_sender_uuids)} " + f"({unmatched_sender_uuids}). Cannot establish a 1-to-1 mapping." + ) + + if not unmatched_my_uuids: + raise RuntimeError( + f"UUID {device_uuid} not found in received handles, but there are no " + "unmatched GPUs to perform fallback mapping. This indicates a logic error." + ) + + mapping = dict(zip(unmatched_my_uuids, unmatched_sender_uuids)) + + target_sender_uuid = mapping.get(device_uuid) + if not target_sender_uuid: + raise RuntimeError( + f"Failed to find UUID {device_uuid} in the fallback mapping. Mapping: {mapping}" + ) + + handle = received_handles[target_sender_uuid] + logger.info( + f"Rank for UUID {device_uuid}: Mapped to sender's UUID {target_sender_uuid} via fallback." 
+ ) + return handle + + +def _rebuild_ipc( + handle: tuple[Callable, tuple], device_id: Optional[int] = None +) -> torch.Tensor: + """ + Rebuilds a tensor from a shared memory IPC handle on the correct GPU device. + """ + func, args = handle + list_args = list(args) + if device_id is not None: + # This ensures the tensor is mapped to the current process's specific GPU. + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # The starting offset of this tensor's data in the shared buffer. + offset: int + + +class CkptEngineConnector(BaseConnector): + + def __init__( + self, url: str, device: torch.device = "cpu", ckpt_engine_port: int = 33001 + ): + super().__init__(url) + self.url = url + self.device = device + self.ckpt_engine_port = ckpt_engine_port + self.zmq_handle = None + self.zmq_ctx = None + self.device_uuid = None + self.socket = None + self.buffer: Optional[torch.Tensor] = None + self.local_rank = None + + def get_zmq_handle(self, tp_rank: int): + # FIXME: There needs a local rank + self.device_uuid = _get_physical_gpu_id(tp_rank) + + data_container = [None] + if tp_rank == 0: + socket = zmq.Context().socket(zmq.PULL) + socket.bind(f"tcp://*:{self.ckpt_engine_port}") + + data = None + try: + raw_message = socket.recv() + + try: + data = json.loads(raw_message.decode("utf-8")) + + if not isinstance(data, dict): + logger.warning("CKPTENGINE: Not exactly the socket handle.") + + except (json.JSONDecodeError, UnicodeDecodeError) as e: + logger.error(f"can not parse the socket raw message: {e}") + + except KeyboardInterrupt: + logger.info("\n shutting down the server.") + finally: + socket.close() + + if data is None: + raise RuntimeError( + "Rank 0 failed to receive or parse the ZMQ handle data." + ) + + data_container[0] = data + logger.info("Rank 0: Received handle data. 
Broadcasting to other ranks...") + + dist.broadcast_object_list(data_container, src=0) + received_data = data_container[0] + world_size = dist.get_world_size() + all_device_uuids = [None] * world_size + dist.all_gather_object(all_device_uuids, self.device_uuid) + self.zmq_handle = _resolve_zmq_handle( + self.device_uuid, all_device_uuids, received_data + ) + + def get_socket_handle(self, tp_rank: int): + # FIXME: local_rank is not tp_rank + if self.zmq_handle is not None: + return + self.local_rank = tp_rank + self.get_zmq_handle(tp_rank) + self.zmq_ctx = zmq.Context() + self.socket = self.zmq_ctx.socket(zmq.REP) + self.socket.connect(self.zmq_handle) + + # Implemented as a no-op to make BaseConnector interface consistent. + def pull_files( + self, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, + ) -> None: + return + + def _extract_weights( + self, payload: list[FlattenedTensorMetadata], buffer: torch.Tensor + ) -> list[tuple[str, torch.Tensor]]: + """ + Extracts individual weight tensors from a shared buffer using metadata. + """ + assert buffer is not None + weights: List[Tuple[str, torch.Tensor]] = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + return weights + + # Implemented as a no-op to make BaseConnector interface consistent. 
+ def weight_iterator( + self, rank: int = 0 + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return + + def update_weights_from_ipc( + self, model, rank: int = 0, post_hook: Callable[[], None] = None + ): + self.get_socket_handle(rank) + try: + while True: + payload: tuple | list | None = self.socket.recv_pyobj() + + # Handle termination signal + if payload is None: + if post_hook is not None: + post_hook() + torch.cuda.synchronize() + self.socket.send(b"") + break + + # Handle IPC buffer setup + if isinstance(payload, tuple): + buffer = _rebuild_ipc(payload, self.local_rank) + assert buffer.dtype == torch.uint8 + self.socket.send(b"") + continue + + # Handle weight metadata payload + assert isinstance(payload, list) + + model.load_weights(self._extract_weights(payload, buffer)) + + torch.cuda.synchronize() + self.socket.send(b"") + except Exception as e: + logger.error(f"Error in IPC weight update on device {rank}: {e}") + raise + finally: + self.socket.close() + del self.buffer + gc.collect() + torch.cuda.empty_cache() + logger.info(f"Cleaned up IPC weight update on device {rank}") diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 4da8e880e9ad..6b2071fd52eb 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -94,6 +94,7 @@ SlowDownReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightsFromCkptEngineReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, UpdateWeightVersionReqInput, @@ -832,6 +833,33 @@ async def update_weights_from_distributed( return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +@app.post("/update_weights_from_ckpt_engine") +async def update_weights_from_ckpt_engine( + obj: UpdateWeightsFromCkptEngineReqInput, request: Request +): + """Update the weights from disk inplace without re-launching the server.""" + success, message = ( + await 
_global_state.tokenizer_manager.update_weights_from_ckpt_engine( + obj, request + ) + ) + + # Update weight version if provided and weights update was successful + if success and obj.weight_version is not None: + _update_weight_version_if_provided(obj.weight_version) + message += f" Weight version updated to {obj.weight_version}." + + content = { + "success": success, + "message": message, + } + status_code = HTTPStatus.OK if success else HTTPStatus.BAD_REQUEST + return ORJSONResponse( + content, + status_code=status_code, + ) + + @app.post("/update_weight_version") async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): """Update the weight version. This operation requires no active requests.""" diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index bb542b7bd19d..207c36f9fed9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1059,6 +1059,27 @@ class UpdateWeightsFromTensorReqOutput(BaseReq): @dataclass +class UpdateWeightsFromCkptEngineReqInput(BaseReq): + # The model path with the new weights + model_path: str + # The format to load the weights + load_format: Optional[str] = None + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + # Whether to flush the cache after updating weights + flush_cache: bool = True + + +@dataclass +class UpdateWeightsFromCkptEngineReqOutput(BaseReq): + success: bool + message: str + # Number of paused requests during weight sync. 
+ num_paused_requests: Optional[int] = 0 + + class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq): # The master address master_address: str diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ea7b8222b974..4f8769339c0c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -110,6 +110,7 @@ UnloadLoRAAdapterReqInput, UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, + UpdateWeightsFromCkptEngineReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) @@ -638,6 +639,10 @@ def __init__( self.update_weights_from_distributed, ), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + ( + UpdateWeightsFromCkptEngineReqInput, + self.update_weights_from_ckpt_engine, + ), (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index fdb7acd64419..162170312c24 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -17,6 +17,8 @@ ResumeMemoryOccupationReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, + UpdateWeightsFromCkptEngineReqInput, + UpdateWeightsFromCkptEngineReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, @@ -75,6 +77,20 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): torch.distributed.barrier(group=self.tp_cpu_group) return UpdateWeightsFromTensorReqOutput(success, message) + def update_weights_from_ckpt_engine( + self, recv_req: UpdateWeightsFromCkptEngineReqInput + ): + """Update the online model parameter from tensors.""" + success, 
message = self.tp_worker.update_weights_from_ckpt_engine(recv_req) + if success: + if recv_req.flush_cache: + flush_cache_success = self.flush_cache() + assert flush_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + torch.distributed.barrier(group=self.tp_cpu_group) + return UpdateWeightsFromCkptEngineReqOutput(success, message) + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index cc929e5a780d..9e36d117aa0b 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -22,6 +22,7 @@ import fastapi import zmq +from sglang.srt.configs.load_config import LoadFormat from sglang.srt.managers.io_struct import ( ClearHiCacheReqInput, ClearHiCacheReqOutput, @@ -63,6 +64,8 @@ SlowDownReqOutput, UnloadLoRAAdapterReqInput, UnloadLoRAAdapterReqOutput, + UpdateWeightsFromCkptEngineReqInput, + UpdateWeightsFromCkptEngineReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, @@ -170,6 +173,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.update_weights_from_tensor_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.update_weights_from_ckpt_engine_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -236,6 +242,10 @@ def _get_communicator_dispatcher(self: TokenizerManager): UpdateWeightsFromTensorReqOutput, self.update_weights_from_tensor_communicator.handle_recv, ), + ( + UpdateWeightsFromCkptEngineReqOutput, + 
self.update_weights_from_ckpt_engine_communicator.handle_recv, + ), ( GetWeightsByNameReqOutput, self.get_weights_by_name_communicator.handle_recv, @@ -392,6 +402,30 @@ async def update_weights_from_distributed( result = (await self.update_weights_from_distributed_communicator(obj))[0] return result.success, result.message + async def update_weights_from_ckpt_engine( + self, + obj: UpdateWeightsFromCkptEngineReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + + # default the load format to ckpt_engine + if obj.load_format is None: + obj.load_format = LoadFormat.CKPT_ENGINE + logger.info("Start update_weights. Load format=%s", obj.load_format) + + if obj.abort_all_requests: + self.abort_request(abort_all=True) + + if True: # Keep this redundant check to simplify some internal code sync + # Hold the lock if it is not async. This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_ckpt_engine_communicator(obj))[ + 0 + ] + return result.success, result.message + async def init_weights_send_group_for_remote_instance( self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 52a40a37122e..8acde825f5de 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -30,6 +30,8 @@ SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightsFromCkptEngineReqInput, + UpdateWeightsFromCkptEngineReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) @@ -381,6 +383,14 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): ) return success, message + def update_weights_from_ckpt_engine( + self, recv_req: UpdateWeightsFromCkptEngineReqInput + ): + success, message = 
def update_weights_from_ckpt_engine(
    self, model_path: str, load_format
) -> tuple[bool, str]:
    """Update engine weights in-place from the checkpoint engine.

    Args:
        model_path: Path of the new checkpoint to load.
        load_format: Load format forwarded to the model loader
            (e.g. ``ckpt_engine``).

    Returns:
        ``(success, message)``. On failure the original weights are
        reloaded and ``success`` is ``False``.
    """
    logger.info(
        f"Update engine weights online from checkpoint engine begin. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )
    # Remember the previous checkpoint path so a failed load can roll back
    # to the *original* weights. The original code overwrote
    # model_config.model_path before attempting the load, so the rollback
    # path re-read the new (broken) checkpoint instead of the old one.
    old_model_path = self.model_config.model_path
    self.model_config.model_path = model_path
    load_config = LoadConfig(load_format=load_format)
    loader = get_model_loader(load_config)
    device_config = DeviceConfig(self.device, self.gpu_id)

    def get_weight_iter(config):
        # Iterator over (name, tensor) pairs for the given model config.
        # NOTE(review): assumes `loader` exposes _get_weights_iterator for
        # this load format — confirm for the ckpt_engine loader.
        return loader._get_weights_iterator(
            DefaultModelLoader.Source.init_new(config, self.model)
        )

    def model_load_weights(model, weight_iter):
        # Load weights into `model` in-place and run loader post-processing.
        DefaultModelLoader.load_weights_and_postprocess(
            model, weight_iter, device_config.device
        )
        return model

    with set_default_torch_dtype(self.model_config.dtype):
        try:
            model = loader.load_model(
                model_config=self.model_config, device_config=device_config
            )
        except Exception as e:
            message = (
                f"Failed to update weights: {e}.\nRolling back to original weights."
            )
            # Roll back: point the config at the old checkpoint again and
            # reload the original weights into the existing model.
            self.model_config.model_path = old_model_path
            gc.collect()
            weight_iter = get_weight_iter(self.model_config)
            self.model = model_load_weights(self.model, weight_iter)
            return False, message

    self.model = model
    self.server_args.model_path = model_path
    self.server_args.load_format = load_format
    self.load_config = load_config

    logger.info("Update weights from ckpt engine end.")
    return True, "Succeeded to update model weights."
class CkptEngineModelLoader(BaseModelLoader):
    """Model loader for the checkpoint-engine (``ckpt_engine``) load format.

    Initializes the model architecture locally, then receives the actual
    weights over IPC from a checkpoint-engine worker via a remote connector.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def download_model(self, model_config: ModelConfig) -> None:
        # Weights are pushed over IPC by the checkpoint engine; there is
        # nothing to download ahead of time.
        raise NotImplementedError

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        """Load model using checkpoint engine format."""
        logger.info("Loading weights from checkpoint engine format ...")
        # Plain string: the original used an f-string with no placeholders.
        model_weights = "ckptengine://"

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)

            with create_remote_connector(
                model_weights,
                device_config.device,
                ckpt_engine_port=self.load_config.ckpt_engine_port,
            ) as client:
                connector_type = get_connector_type(client)
                if connector_type != ConnectorType.CKPTENGINE:
                    raise ValueError(
                        f"Unsupported connector type {connector_type} for "
                        f"remote tensor model loading."
                    )
                self.load_model_from_ckpt_engine(
                    model, client, model_config, device_config
                )

        return model.eval()

    def load_model_from_ckpt_engine(
        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        """Receive weights over IPC and finalize quantized modules.

        Returns the in-place-updated model, matching the annotated return
        type (the original returned ``None`` despite ``-> nn.Module``).
        """
        # NOTE(review): quant post-processing runs before the IPC weight
        # transfer here — confirm this ordering is intended.
        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)

        def post_hook():
            post_load_weights(model, model_config)

        client.update_weights_from_ipc(model, device_config.gpu_id, post_hook)
        return model