[Feature] Support loading weights from ckpt engine worker #11755
ByronHsu merged 27 commits into sgl-project:main
Conversation
Implement IPC-based weight updates for checkpoint-engine compatibility:
- Add SGLangCheckpointEngineWorkerExtension worker class
- Implement update_weights_from_ipc across scheduler/tokenizer/worker layers
- Add collective_rpc endpoint for vLLM API compatibility
- Support ZMQ communication with device UUID management
- Include post-loading hooks and error handling

This allows efficient model weight updates via IPC without a server restart.
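The collective_rpc dispatch mentioned above can be sketched as a broadcast of one method call to every worker, with each worker selecting its IPC socket by device UUID. All names below (`DummyWorker`, `collective_rpc`, the handle paths) are illustrative stand-ins, not the actual sglang API:

```python
# Rough sketch of a vLLM-style collective_rpc dispatch. A real server would
# fan the call out across processes; here it is a plain in-process loop.
class DummyWorker:
    def __init__(self, device_uuid: str):
        self.device_uuid = device_uuid

    def update_weights_from_ipc(self, zmq_handles: dict) -> str:
        # Each worker looks up its own IPC socket path by device UUID.
        return zmq_handles[self.device_uuid]


def collective_rpc(workers, method: str, *args, **kwargs):
    # Invoke `method` on every worker and gather the results.
    return [getattr(w, method)(*args, **kwargs) for w in workers]


workers = [DummyWorker("GPU-0"), DummyWorker("GPU-1")]
handles = {"GPU-0": "ipc:///tmp/ckpt_0", "GPU-1": "ipc:///tmp/ckpt_1"}
print(collective_rpc(workers, "update_weights_from_ipc", handles))
```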
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
refactor(checkpoint-engine): import code from checkpoint-engine instead of copying duplicated codes
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
The health probe will only report a healthy status after the initial weights have been successfully loaded when the `wait-for-initial-weights` option is enabled. Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
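The gating that commit describes can be sketched with a tiny state object: the probe returns an unhealthy status until weight loading signals completion. The class and method names here are illustrative, not the actual sglang health-probe code:

```python
import threading


class HealthState:
    """Report healthy only once initial weights are loaded, mirroring the
    described wait-for-initial-weights behavior (illustrative sketch)."""

    def __init__(self):
        self._weights_loaded = threading.Event()

    def mark_weights_loaded(self):
        self._weights_loaded.set()

    def health(self):
        # 503 keeps load balancers away until the model is actually usable.
        if self._weights_loaded.is_set():
            return 200, "ok"
        return 503, "waiting for initial weights"


state = HealthState()
print(state.health())        # unhealthy before the initial load
state.mark_weights_loaded()
print(state.health())        # healthy afterwards
```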
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
Summary of Changes

Hello @stmatengss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Highlights
Code Review
This pull request introduces a significant feature to support loading model weights from a ckpt-engine worker, enabling dynamic weight updates for a running server. The changes are extensive, touching multiple parts of the system from server entrypoints to the model execution logic. The implementation is generally well-structured, and the inclusion of an example script is very helpful. I've identified a few areas for improvement, including a potential runtime error in the example script, some opportunities for code simplification and refactoring to improve maintainability, and reliance on private members which could be brittle.
examples/checkpoint_engine/update.py (outdated)

```python
    uds: str | None = None,
    weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
    rank = int(os.getenv("RANK", None))
```
The call int(os.getenv("RANK", None)) is not safe. If the "RANK" environment variable is not set, os.getenv will return None, and int(None) will raise a TypeError. While torchrun usually sets this variable, robust code should handle cases where it might be missing. You could provide a default value or add a check to ensure the environment variable is set.
```python
rank_str = os.getenv("RANK")
if rank_str is None:
    raise RuntimeError("Environment variable RANK is not set.")
rank = int(rank_str)
```
I agree with Gemini; None is not a good default value for the following logic.
```python
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    if rank != rank // inference_parallel_size * inference_parallel_size:
```
The function check_sglang_ready uses a global variable rank which is defined later in the script. While this works in this script, it's not a robust practice as it makes the function's behavior dependent on a non-local state that is not explicitly passed. This can lead to confusion and bugs if the code is refactored. It would be better to pass rank as an argument to check_sglang_ready and other functions that need it, like update_weights and join.
For example, you could change the function signature to:
```python
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, rank: int, uds: str | None = None
):
```

And then update the call sites in update_weights and join to pass rank.
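A fuller sketch of that refactor might look as follows; the polling body is elided, and `is_group_leader` is a hypothetical helper name introduced here for clarity:

```python
from __future__ import annotations


def is_group_leader(rank: int, inference_parallel_size: int) -> bool:
    # Equivalent to rank == rank // size * size: true only for the first
    # rank of each inference group.
    return rank % inference_parallel_size == 0


def check_sglang_ready(
    endpoint: str,
    inference_parallel_size: int,
    rank: int,
    uds: str | None = None,
) -> bool:
    if not is_group_leader(rank, inference_parallel_size):
        return False  # non-leader ranks skip the readiness probe
    # ... poll `endpoint` (or the unix socket `uds`) until it reports healthy ...
    return True


print(is_group_leader(0, 4), is_group_leader(3, 4), is_group_leader(4, 4))
```

Passing `rank` explicitly makes the function's behavior independent of module-level state, so it survives refactoring and is trivially unit-testable.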
```python
if update_method:
    # sleep 2s to wait destroy process group
    time.sleep(2)
```
The condition if update_method: is redundant because it is nested inside if update_method == "p2p" or update_method == "all":. In this block, update_method will always be a non-empty string, which evaluates to True. You can remove this inner if statement.
```python
# sleep 2s to wait destroy process group
time.sleep(2)
```

```python
    weight_version=args.weight_version,
)
ps = ParameterServer(auto_pg=True)
ps._p2p_store = None
```
Accessing the private member _p2p_store of the ParameterServer instance is generally not recommended as it relies on internal implementation details that might change in future versions of checkpoint-engine. If this is a necessary workaround, it would be beneficial to add a comment explaining why this is needed. If there's a public API to achieve the same result, it should be preferred.
Oh, I see. rank and world_size are initialized here in main according to env vars, so they are parameters of torchrun. Should we make them global params? I am seeing many os.getenv calls, which don't feel good.

```python
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
```

Should they be given a prefix like SGLANG_CKPT_ENGINE_?
```python
def get_post_hook(self):
    def post_hook():
        # Perform post-processing after weight loading similar to DefaultModelLoader
        try:
            from sglang.srt.model_loader.loader import (
                device_loading_context,
            )

            # Process quantization methods after loading weights
            for _, module in self.model_runner.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # Move parameters to device if needed for quantization processing
                    target_device = torch.device(
                        "cuda", torch.cuda.current_device()
                    )
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(
                            module
                        )
            # Call model-specific post-loading hook if available
            if hasattr(self.model_runner.model, "post_load_weights"):
                self.model_runner.model.post_load_weights()
        except Exception as e:
            logger.warning(f"Post-hook processing failed: {e}")

    return post_hook  # Create worker instance and perform IPC weight update
```
The implementation of get_post_hook duplicates logic for post-processing weights (e.g., for quantization) from sglang.srt.model_loader.loader.DefaultModelLoader. This code duplication can make maintenance harder. Consider refactoring the post-processing logic into a reusable function that can be called from both DefaultModelLoader and here.
Additionally, the comment on line 2387 is misplaced and inaccurate. It should be moved to where the worker is created and used (around line 2389).
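The suggested refactor could look roughly like this: one shared helper called from both `DefaultModelLoader` and the IPC worker's post hook. The helper name and the fakes below are hypothetical; only the loop structure comes from the diff above:

```python
from contextlib import contextmanager


def process_weights_after_loading(model, device_loading_context, target_device):
    """Shared post-load pass (hypothetical refactor): run each module's
    quantization hook, then the model-level post_load_weights hook."""
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()


# Minimal fakes to exercise the helper without torch:
class FakeQuant:
    def __init__(self):
        self.processed = []

    def process_weights_after_loading(self, module):
        self.processed.append(module)


class FakeModule:
    def __init__(self, quant=None):
        self.quant_method = quant


class FakeModel:
    def __init__(self):
        self.quant = FakeQuant()
        self._mods = {"layer0": FakeModule(self.quant), "layer1": FakeModule()}
        self.post_loaded = False

    def named_modules(self):
        return self._mods.items()

    def post_load_weights(self):
        self.post_loaded = True


@contextmanager
def noop_ctx(module, device):
    yield  # real code would migrate parameters to `device` here


m = FakeModel()
process_weights_after_loading(m, noop_ctx, "cpu")
print(len(m.quant.processed), m.post_loaded)
```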
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
feat: Add /ping endpoint for dummy server's liveness probe
16c5380 to cca33c4
Testing was conducted on an 8×A10 GPU machine with the following results: for the Qwen3-8B model with TP=4, the loading time with the checkpoint engine was 9 s (assuming model weights had been prefetched into DRAM by the checkpoint engine), compared to 85 s for loading from disk.
```python
@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
```
nit: say this is used for ckpt engine
```python
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    if rank != rank // inference_parallel_size * inference_parallel_size:
```
Where is this `rank` parameter initialized?
This variable is populated by torchrun, which this script currently depends on. This dependency will be removed later.
The checkpoint engine project will add a general update.py script for both vllm and sglang. This script is a temporary solution.
TODO: fix it when checkpoint engine releases a new version.
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Resolved all comments. PTAL. @ShangmingCai
Can you wrap this

Thx for the reminder. Will add this feature ASAP!
Motivation
Based on #10667/#10464 and #10646, we minimized the implementation to ease review.
The ckpt engine maintainer approved this design (using a worker, not a connector) after discussion.
Usage:
Authors: @stmatengss @BraveY @XucSh @zxpdemonio
Thanks to @weixiao-huang for the help.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist