
[Feature] Support loading weights from ckpt engine worker#11755

Merged
ByronHsu merged 27 commits into sgl-project:main from
openanolis:kaiyong/checkpoint-engine
Oct 23, 2025

Conversation

@stmatengss
Collaborator

@stmatengss stmatengss commented Oct 17, 2025

Motivation

Building on #10667, #10464, and #10646, we kept the implementation minimal to ease review.
The checkpoint-engine maintainer approved this design (using a worker, not a connector) after discussion.

Usage:

  • Install the checkpoint engine:
pip install 'checkpoint-engine[p2p]'
  • Launch the server with the --wait-for-initial-weights option in one terminal (sglang side):
python -m sglang.launch_server --model-path Qwen/Qwen3-8B --tp 8 --load-format dummy --wait-for-initial-weights
  • Run the torchrun script in another terminal (ckpt engine side):
torchrun --nproc-per-node 8 examples/checkpoint_engine/update.py --update-method broadcast --checkpoint-path /path/to/Qwen/Qwen3-8B/ --inference-parallel-size 8

Authors: @stmatengss @BraveY @XucSh @zxpdemonio
Thanks to @weixiao-huang for the help.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

BraveY and others added 10 commits September 19, 2025 12:19
Implement IPC-based weight updates for checkpoint-engine compatibility:
- Add SGLangCheckpointEngineWorkerExtension worker class
- Implement update_weights_from_ipc across scheduler/tokenizer/worker layers
- Add collective_rpc endpoint for vLLM API compatibility
- Support ZMQ communication with device UUID management
- Include post-loading hooks and error handling

This allows efficient model weight updates via IPC without server restart.
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
refactor(checkpoint-engine): import code from checkpoint-engine instead of copying duplicated codes
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>

The health probe will only report a healthy status after the initial weights
have been successfully loaded when the `wait-for-initial-weights` option is enabled.

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
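The wait-for-initial-weights behavior described above means a client should poll readiness rather than assume the server is up right after launch. A minimal, framework-agnostic sketch of such a poll (the probe callable and the timing defaults are assumptions, not part of this PR):

```python
import time


def wait_until_ready(probe, timeout_s=300.0, interval_s=1.0,
                     clock=time.monotonic, sleep=time.sleep):
    """Poll `probe` until it reports ready or the timeout expires.

    `probe` is any zero-argument callable that returns True once the server
    has finished loading its initial weights (e.g. /health returning 200).
    `clock` and `sleep` are injectable to make the loop testable.
    """
    deadline = clock() + timeout_s
    while True:
        if probe():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)
```

In practice the probe would issue an HTTP GET against the server's health endpoint; the loop itself stays the same regardless of transport.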
@gemini-code-assist
Contributor

Summary of Changes

Hello @stmatengss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the checkpoint-engine library into SGLang, allowing for dynamic and distributed loading of model weights. This enhancement provides greater flexibility in managing model lifecycles, enabling weight updates without server restarts and supporting scenarios where initial weights are provisioned externally. The changes include new API endpoints, a dedicated worker extension, and an example to showcase the functionality.

Highlights

  • Checkpoint Engine Integration: Introduced support for loading model weights from a checkpoint-engine worker, enabling dynamic weight updates via Inter-Process Communication (IPC).
  • New Server Option: Added a --wait-for-initial-weights command-line argument to the SGLang server, allowing it to pause startup until initial model weights are loaded via the checkpoint engine.
  • IPC Weight Update Endpoint: Implemented a new HTTP endpoint /update_weights_from_ipc to facilitate receiving and processing weight updates from the checkpoint-engine.
  • Modular Worker Extension: Created SGLangCheckpointEngineWorkerExtension to abstract the integration logic, providing a clear interface for SGLang's model runner to interact with the checkpoint engine's IPC mechanism.
  • Example Usage: Provided a new example script (examples/checkpoint_engine/update.py) demonstrating how to launch an SGLang server and update its weights using the checkpoint-engine.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant feature to support loading model weights from a ckpt-engine worker, enabling dynamic weight updates for a running server. The changes are extensive, touching multiple parts of the system from server entrypoints to the model execution logic. The implementation is generally well-structured, and the inclusion of an example script is very helpful. I've identified a few areas for improvement, including a potential runtime error in the example script, some opportunities for code simplification and refactoring to improve maintainability, and reliance on private members which could be brittle.

uds: str | None = None,
weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
rank = int(os.getenv("RANK", None))
Contributor

high

The call int(os.getenv("RANK", None)) is not safe. If the "RANK" environment variable is not set, os.getenv will return None, and int(None) will raise a TypeError. While torchrun usually sets this variable, robust code should handle cases where it might be missing. You could provide a default value or add a check to ensure the environment variable is set.

    rank_str = os.getenv("RANK")
    if rank_str is None:
        raise RuntimeError("Environment variable RANK is not set.")
    rank = int(rank_str)

Collaborator

I agree with Gemini; None is not a good default value for the following logic.

def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    if rank != rank // inference_parallel_size * inference_parallel_size:
Contributor

medium

The function check_sglang_ready uses a global variable rank which is defined later in the script. While this works in this script, it's not a robust practice as it makes the function's behavior dependent on a non-local state that is not explicitly passed. This can lead to confusion and bugs if the code is refactored. It would be better to pass rank as an argument to check_sglang_ready and other functions that need it, like update_weights and join.

For example, you could change the function signature to:

def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, rank: int, uds: str | None = None
):

And then update the call sites in update_weights and join to pass rank.
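The leader check being discussed can also be isolated into a pure helper so that rank is always passed explicitly rather than read from global state; a sketch (the function name is mine, not from the PR):

```python
def is_group_leader(rank: int, inference_parallel_size: int) -> bool:
    # Only the first rank of each inference group (e.g. ranks 0, 8, 16 for
    # inference_parallel_size=8) should talk to its SGLang server. The
    # expression mirrors `rank // n * n` from the script under review.
    return rank == rank // inference_parallel_size * inference_parallel_size
```

Making the predicate pure keeps check_sglang_ready, update_weights, and join free of hidden globals and trivially unit-testable.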

Comment on lines +152 to +154
if update_method:
    # sleep 2s to wait destroy process group
    time.sleep(2)
Contributor

medium

The condition if update_method: is redundant because it is nested inside if update_method == "p2p" or update_method == "all":. In this block, update_method will always be a non-empty string, which evaluates to True. You can remove this inner if statement.

        # sleep 2s to wait destroy process group
        time.sleep(2)

weight_version=args.weight_version,
)
ps = ParameterServer(auto_pg=True)
ps._p2p_store = None
Contributor

medium

Accessing the private member _p2p_store of the ParameterServer instance is generally not recommended as it relies on internal implementation details that might change in future versions of checkpoint-engine. If this is a necessary workaround, it would be beneficial to add a comment explaining why this is needed. If there's a public API to achieve the same result, it should be preferred.

Collaborator

Oh, I see. rank and world_size are initialized here in main according to env vars. So they are parameters from torchrun? Should we make them global params? I am seeing many os.getenv calls, which doesn't feel good.

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))

Should they be given a prefix like SGLANG_CKPT_ENGINE_?

Comment on lines +2361 to +2388
def get_post_hook(self):
    def post_hook():
        # Perform post-processing after weight loading similar to DefaultModelLoader
        try:
            from sglang.srt.model_loader.loader import device_loading_context

            # Process quantization methods after loading weights
            for _, module in self.model_runner.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # Move parameters to device if needed for quantization processing
                    target_device = torch.device("cuda", torch.cuda.current_device())
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            # Call model-specific post-loading hook if available
            if hasattr(self.model_runner.model, "post_load_weights"):
                self.model_runner.model.post_load_weights()
        except Exception as e:
            logger.warning(f"Post-hook processing failed: {e}")

    return post_hook  # Create worker instance and perform IPC weight update

Contributor

medium

The implementation of get_post_hook duplicates logic for post-processing weights (e.g., for quantization) from sglang.srt.model_loader.loader.DefaultModelLoader. This code duplication can make maintenance harder. Consider refactoring the post-processing logic into a reusable function that can be called from both DefaultModelLoader and here.

Additionally, the comment on line 2387 is misplaced and inaccurate. It should be moved to where the worker is created and used (around line 2389).
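The duplication flagged here could be addressed by extracting the post-load pass into one shared, duck-typed helper that both DefaultModelLoader and the checkpoint-engine worker call; a sketch (the function name is mine, and the device-context handling is elided for brevity):

```python
def run_post_load_hooks(model, process_module=None):
    """Shared post-load pass over a loaded model.

    Lets each quantized submodule finalize its weights, then calls the
    model-level hook if one exists. `process_module` is injectable so the
    caller can wrap the call in a device-loading context when needed.
    """
    if process_module is None:
        def process_module(module, quant_method):
            quant_method.process_weights_after_loading(module)
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            process_module(module, quant_method)
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()
```

Both call sites would then differ only in how they wrap process_module (e.g. with device_loading_context), keeping the quantization walk in one place.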

XucSh and others added 2 commits October 17, 2025 17:22
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
feat: Add /ping endpoint for dummy server's liveness probe
@BraveY BraveY requested a review from kushanam as a code owner October 22, 2025 04:00
@BraveY BraveY force-pushed the kaiyong/checkpoint-engine branch from 16c5380 to cca33c4 Compare October 22, 2025 04:11
@XucSh
Collaborator

XucSh commented Oct 22, 2025

Testing was conducted on an 8*A10 GPU machine with the following results:

For the Qwen3-8B model with TP=4, the loading time for the checkpoint engine was 9 s (assuming model weights had been prefetched into DRAM by the checkpoint engine), compared to 85 s for loading from disk.



@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
Collaborator

nit: say this is used for ckpt engine

Collaborator

I have added the notes.

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    if rank != rank // inference_parallel_size * inference_parallel_size:
Collaborator

Where is this rank parameter initialized?

Collaborator

This variable is populated by torchrun, which this script currently depends on. This dependency will be removed later.

Collaborator Author

The checkpoint engine project will add a general update.py script for both vllm and sglang. This script is a temporary solution.

Collaborator Author

TODO: fix it when checkpoint engine releases a new version.

stmatengss and others added 4 commits October 23, 2025 13:37
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
@stmatengss
Collaborator Author

Resolved all comments. PTAL. @ShangmingCai

Collaborator

@ShangmingCai ShangmingCai left a comment

LGTM

@stmatengss stmatengss added the ready-to-merge The PR is ready to merge after the CI is green. label Oct 23, 2025
@ByronHsu ByronHsu merged commit 96a5e4d into sgl-project:main Oct 23, 2025
122 of 185 checks passed
@sglang-bot
Member

Can you wrap this `torchrun --nproc-per-node 8 examples/checkpoint_engine/update.py --update-method broadcast --checkpoint-path /path/to/Qwen/Qwen3-8B/ --inference-parallel-size 8` as a simple command like `python3 -m sglang.ckpt.update...`?

@stmatengss
Collaborator Author

Can you wrap this torchrun --nproc-per-node 8 examples/checkpoint_engine/update.py --update-method broadcast --checkpoint-path /path/to/Qwen/Qwen3-8B/ --inference-parallel-size 8 as a simple command like python3 -m sglang.ckpt.update...

Thanks for the reminder. Will add this feature ASAP!


Labels

high priority · ready-to-merge (The PR is ready to merge after the CI is green.) · run-ci
