-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[Feature] Support loading weights from ckpt engine worker #11755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
635fac7
e4bd670
ffe6bd5
43f0d92
41e7736
4b027d7
20ec0cf
bfc2f69
35cc659
42e1ad6
46275ae
036d525
506567e
e03496f
0025436
ae4b128
9c25784
3edd5fb
cca33c4
7419c41
cf2d56d
d5db877
98bcd8f
5f4107e
f4c9e14
077f92b
3afebdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,241 @@ | ||
| """ | ||
| Usage: | ||
| 1) Launch the server with wait-for-initial-weights option in one terminal: | ||
| python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7 | ||
|
|
||
| 2) Torchrun this script in another terminal: | ||
| torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2 | ||
| """ | ||
|
|
||
| import argparse | ||
| import json | ||
| import os | ||
| import pickle | ||
| import time | ||
| from collections import defaultdict | ||
| from collections.abc import Callable | ||
| from contextlib import contextmanager | ||
| from typing import Literal | ||
|
|
||
| import httpx | ||
| import torch | ||
| import torch.distributed as dist | ||
| from checkpoint_engine.ps import ParameterServer | ||
| from loguru import logger | ||
| from safetensors import safe_open | ||
|
|
||
|
|
||
| @contextmanager | ||
| def timer(msg: str): | ||
| start = time.perf_counter() | ||
| yield | ||
| end = time.perf_counter() | ||
| logger.info(f"{msg} duration: {end - start:.2f} seconds") | ||
|
|
||
|
|
||
| def check_sglang_ready( | ||
| endpoint: str, inference_parallel_size: int, uds: str | None = None | ||
| ): | ||
| if rank != rank // inference_parallel_size * inference_parallel_size: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function For example, you could change the function signature to: def check_sglang_ready(
endpoint: str, inference_parallel_size: int, rank: int, uds: str | None = None
):And then update the call sites in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This variable is populated by torchrun, which this script currently depends on. This dependency will be removed later.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The checkpoint engine project will add a general update.py script for both vllm and sglang. This script is a temporary solution.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: fix it when checkpoint engine releases a new version. |
||
| return | ||
| retry_num = 0 | ||
| transport = None | ||
| if uds is not None: | ||
| transport = httpx.HTTPTransport(uds=uds) | ||
| with httpx.Client(transport=transport) as client: | ||
| while True: | ||
| try: | ||
| response = client.get(f"{endpoint}/ping", timeout=10) | ||
| response.raise_for_status() | ||
| break | ||
| except (httpx.ConnectError, httpx.HTTPStatusError) as e: | ||
| if retry_num % 10 == 0: | ||
| logger.warning( | ||
| f"fail to check sglang ready, retry {retry_num} times, error: {e}" | ||
| ) | ||
| retry_num += 1 | ||
| time.sleep(0.1) | ||
|
|
||
|
|
||
| def split_checkpoint_files( | ||
| checkpoint_path: str, rank: int, world_size: int | ||
| ) -> list[str]: | ||
| checkpoint_files = [ | ||
| os.path.join(checkpoint_path, f) | ||
| for f in filter( | ||
| lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path) | ||
| ) | ||
| ] | ||
| files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size | ||
| return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank] | ||
|
|
||
|
|
||
| def split_tensors( | ||
| checkpoint_path: str, rank: int, world_size: int | ||
| ) -> dict[str, torch.Tensor]: | ||
| index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json") | ||
| with open(index_fn) as f: | ||
| weight_map: dict[str, str] = json.load(f)["weight_map"] | ||
| weights_per_rank = (len(weight_map) + world_size - 1) // world_size | ||
| fn_tensors: dict[str, list[str]] = defaultdict(list) | ||
| weight_keys = list(weight_map.items()) | ||
| for name, file in weight_keys[ | ||
| rank * weights_per_rank : (rank + 1) * weights_per_rank | ||
| ]: | ||
| fn_tensors[file].append(name) | ||
| named_tensors = {} | ||
| for file, names in fn_tensors.items(): | ||
| with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f: | ||
| for name in names: | ||
| named_tensors[name] = f.get_tensor(name) | ||
| return named_tensors | ||
|
|
||
|
|
||
| def req_inference( | ||
| endpoint: str, | ||
| inference_parallel_size: int, | ||
| timeout: float = 300.0, | ||
| uds: str | None = None, | ||
| weight_version: str | None = None, | ||
| ) -> Callable[[list[tuple[str, str]]], None]: | ||
| rank = int(os.getenv("RANK", 0)) | ||
| src = rank // inference_parallel_size * inference_parallel_size | ||
|
|
||
| def req_func(socket_paths: list[tuple[str, str]]): | ||
| if rank == src: | ||
| with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client: | ||
| resp = client.post( | ||
| f"{endpoint}/update_weights_from_ipc", | ||
| json={ | ||
| "zmq_handles": dict( | ||
| socket_paths[src : src + inference_parallel_size] | ||
| ), | ||
| "flush_cache": True, | ||
| "weight_version": weight_version, | ||
| }, | ||
| timeout=timeout, | ||
| ) | ||
| resp.raise_for_status() | ||
|
|
||
| return req_func | ||
|
|
||
|
|
||
| def update_weights( | ||
| ps: ParameterServer, | ||
| checkpoint_name: str, | ||
| checkpoint_files: list[str], | ||
| named_tensors: dict[str, torch.Tensor], | ||
| req_func: Callable[[list[tuple[str, str]]], None], | ||
| inference_parallel_size: int, | ||
| endpoint: str, | ||
| save_metas_file: str | None = None, | ||
| update_method: Literal["broadcast", "p2p", "all"] = "broadcast", | ||
| uds: str | None = None, | ||
| ): | ||
| ps.register_checkpoint( | ||
| checkpoint_name, files=checkpoint_files, named_tensors=named_tensors | ||
| ) | ||
| ps.init_process_group() | ||
| check_sglang_ready(endpoint, inference_parallel_size, uds) | ||
| dist.barrier() | ||
| with timer("Gather metas"): | ||
| ps.gather_metas(checkpoint_name) | ||
| if save_metas_file and int(os.getenv("RANK")) == 0: | ||
| with open(save_metas_file, "wb") as f: | ||
| pickle.dump(ps.get_metas(), f) | ||
|
|
||
| if update_method == "broadcast" or update_method == "all": | ||
| with timer("Update weights without setting ranks"): | ||
| ps.update(checkpoint_name, req_func) | ||
|
|
||
| if update_method == "p2p" or update_method == "all": | ||
| if update_method: | ||
| # sleep 2s to wait destroy process group | ||
| time.sleep(2) | ||
|
Comment on lines
+152
to
+154
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition # sleep 2s to wait destroy process group
time.sleep(2) |
||
| with timer("Update weights with setting ranks"): | ||
| ps.update( | ||
| checkpoint_name, req_func, ranks=list(range(inference_parallel_size)) | ||
| ) | ||
|
|
||
|
|
||
| def join( | ||
| ps: ParameterServer, | ||
| checkpoint_name: str, | ||
| load_metas_file: str, | ||
| req_func: Callable[[list[tuple[str, str]]], None], | ||
| inference_parallel_size: int, | ||
| endpoint: str, | ||
| uds: str | None = None, | ||
| ): | ||
| assert load_metas_file, "load_metas_file is required" | ||
| with open(load_metas_file, "rb") as f: | ||
| metas = pickle.load(f) | ||
| ps.init_process_group() | ||
| check_sglang_ready(endpoint, inference_parallel_size, uds) | ||
| dist.barrier() | ||
| with timer("Gather metas before join"): | ||
| ps.gather_metas(checkpoint_name) | ||
| ps.load_metas(metas) | ||
| with timer( | ||
| f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p" | ||
| ): | ||
| ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size))) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Update weights example") | ||
| parser.add_argument("--checkpoint-path", type=str, default=None) | ||
| parser.add_argument("--save-metas-file", type=str, default=None) | ||
| parser.add_argument("--load-metas-file", type=str, default=None) | ||
| parser.add_argument("--sleep-time", type=int, default=0) | ||
| parser.add_argument("--endpoint", type=str, default="http://localhost:19730") | ||
| parser.add_argument("--inference-parallel-size", type=int, default=8) | ||
| parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0") | ||
| parser.add_argument("--update-method", type=str, default="broadcast") | ||
| parser.add_argument("--uds", type=str, default=None) | ||
| parser.add_argument("--weight-version", type=str, default=None) | ||
| args = parser.parse_args() | ||
| rank = int(os.getenv("RANK")) | ||
| world_size = int(os.getenv("WORLD_SIZE")) | ||
| req_func = req_inference( | ||
| args.endpoint, | ||
| args.inference_parallel_size, | ||
| uds=args.uds, | ||
| weight_version=args.weight_version, | ||
| ) | ||
| ps = ParameterServer(auto_pg=True) | ||
| ps._p2p_store = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing the private member
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I see. rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))Should they be given a prefix like |
||
| if args.load_metas_file: | ||
| join( | ||
| ps, | ||
| args.checkpoint_name, | ||
| args.load_metas_file, | ||
| req_func, | ||
| args.inference_parallel_size, | ||
| args.endpoint, | ||
| args.uds, | ||
| ) | ||
| else: | ||
| if os.path.exists( | ||
| os.path.join(args.checkpoint_path, "model.safetensors.index.json") | ||
| ): | ||
| named_tensors = split_tensors(args.checkpoint_path, rank, world_size) | ||
| checkpoint_files = [] | ||
| else: | ||
| checkpoint_files = split_checkpoint_files( | ||
| args.checkpoint_path, rank, world_size | ||
| ) | ||
| named_tensors = {} | ||
| update_weights( | ||
| ps, | ||
| args.checkpoint_name, | ||
| checkpoint_files, | ||
| named_tensors, | ||
| req_func, | ||
| args.inference_parallel_size, | ||
| args.endpoint, | ||
| args.save_metas_file, | ||
| args.update_method, | ||
| args.uds, | ||
| ) | ||
| time.sleep(args.sleep_time) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| # Copyright 2023-2024 SGLang Team | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """ | ||
| Checkpoint-engine integration for SGLang. | ||
| This module provides weight update functionality via IPC for checkpoint-engine compatibility. | ||
| """ | ||
| import logging | ||
| from typing import Callable, Dict, Optional | ||
|
|
||
| import torch | ||
| import zmq | ||
|
|
||
| try: | ||
| from checkpoint_engine.worker import update_weights_from_ipc | ||
| except ImportError: | ||
| raise ImportError( | ||
| "checkpoint-engine is not installed. " | ||
| "Please install it with: pip install sglang[checkpoint-engine]" | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class SGLangCheckpointEngineWorkerExtension: | ||
| """ | ||
| Worker extension for SGLang to support checkpoint-engine IPC weight updates. | ||
| This class provides the interface needed for checkpoint-engine integration. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self._zmq_ctx: Optional[zmq.Context] = None | ||
|
|
||
| def get_device_uuid(self) -> str: | ||
| """Get the UUID of current device.""" | ||
| # We need to implement this to get the device UUID | ||
| # This will be overridden when integrated into SGLang's worker | ||
| raise NotImplementedError( | ||
| "This method should be overridden by SGLang integration" | ||
| ) | ||
|
|
||
| def get_device_id(self) -> int: | ||
| """Get the device ID.""" | ||
| raise NotImplementedError( | ||
| "This method should be overridden by SGLang integration" | ||
| ) | ||
|
|
||
| def get_model_loader(self) -> Callable: | ||
| """Get the model weight loader function.""" | ||
| raise NotImplementedError( | ||
| "This method should be overridden by SGLang integration" | ||
| ) | ||
|
|
||
| def get_post_hook(self) -> Optional[Callable]: | ||
| """Get the post-processing hook after weight loading.""" | ||
| return None | ||
|
|
||
| def update_weights_from_ipc(self, zmq_handles: Dict[str, str]): | ||
| """ | ||
| Update weights from IPC communication. | ||
| Args: | ||
| zmq_handles: Dict mapping device UUID to ZMQ socket path | ||
| """ | ||
| if self._zmq_ctx is None: | ||
| self._zmq_ctx = zmq.Context() | ||
| device_uuid = self.get_device_uuid() | ||
| device_id = self.get_device_id() | ||
| if device_uuid not in zmq_handles: | ||
| raise ValueError( | ||
| f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}" | ||
| ) | ||
| update_weights_from_ipc( | ||
| self._zmq_ctx, | ||
| zmq_handles[device_uuid], | ||
| device_id=device_id, | ||
| run=self.get_model_loader(), | ||
| post_hook=self.get_post_hook(), | ||
| ) | ||
|
|
||
|
|
||
| class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension): | ||
| """ | ||
| Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner. | ||
| This class provides the concrete implementation for checkpoint-engine IPC weight updates. | ||
| """ | ||
|
|
||
| def __init__(self, model_runner): | ||
| super().__init__() | ||
| self.model_runner = model_runner | ||
|
|
||
| def get_device_uuid(self) -> str: | ||
| """Get the UUID of current device.""" | ||
| # Get device UUID for current device | ||
| device_id = torch.cuda.current_device() | ||
| try: | ||
| return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}" | ||
| except AssertionError as e: | ||
| raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e | ||
|
|
||
| def get_device_id(self) -> int: | ||
| """Get the device ID.""" | ||
| return torch.cuda.current_device() | ||
|
|
||
| def get_model_loader(self) -> Callable: | ||
| """Get the model weight loader function.""" | ||
| return self.model_runner.model.load_weights | ||
|
|
||
| def get_post_hook(self) -> Optional[Callable]: | ||
| """Get the post-processing hook after weight loading.""" | ||
|
|
||
| def post_hook(): | ||
| # Perform post-processing after weight loading similar to DefaultModelLoader | ||
| try: | ||
| from sglang.srt.model_loader.loader import device_loading_context | ||
|
|
||
| # Process quantization methods after loading weights | ||
| for _, module in self.model_runner.model.named_modules(): | ||
| quant_method = getattr(module, "quant_method", None) | ||
| if quant_method is not None: | ||
| # Move parameters to device if needed for quantization processing | ||
| target_device = torch.device( | ||
| "cuda", torch.cuda.current_device() | ||
| ) | ||
| with device_loading_context(module, target_device): | ||
| quant_method.process_weights_after_loading(module) | ||
| # Call model-specific post-loading hook if available | ||
| if hasattr(self.model_runner.model, "post_load_weights"): | ||
| self.model_runner.model.post_load_weights() | ||
| except Exception as e: | ||
| logger.warning(f"Post-hook processing failed: {e}") | ||
|
|
||
| return post_hook |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the future, have something like
python -m sglang.launch_checkpoint_engineinside sglang python package, so whoever installs sglang can use itThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point! we can move this file to python/sglang
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will fix it