
[Feature] Support loading weights from ckpt engine connector #10667

Closed
stmatengss wants to merge 46 commits into sgl-project:main from openanolis:mateng/dev_ckpt_engine

Conversation

@stmatengss
Collaborator

@stmatengss stmatengss commented Sep 19, 2025

Motivation

Motivated by #8215, we aim to integrate ckpt-engine into SGLang to accelerate model loading and weight synchronization.

A proposal is in #10464; this PR supports both co-located/disaggregated deployment and TP.

Usage:

pip install 'checkpoint-engine[p2p]'  # install checkpoint engine

Fake sglang server (which only occupies memory for model weights):

python3 -m sglang.launch_server --model [model-name] --tp 2 --page-size 1 --mem-fraction-static 0 --enable-ckpt-engine 

New sglang instance

python3 -m sglang.launch_server --model [model-name]  --tp 2 --mem-fraction-static 0.4 --load-format ckpt_engine --port [port_num]

Running Methods:

sglang

python3 -m sglang.launch_server --model /opt/models/Qwen/Qwen3-8b --tp 8 --load-format ckpt_engine --port 30001

checkpoint engine

torchrun --nproc-per-node 8 ckptengine_update.py --update-method all --checkpoint-path /opt/models/Qwen/Qwen3-8b/

Co-authors: @XucSh @zxpdemonio @BraveY
Thanks to @weixiao-huang for help.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @stmatengss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the checkpoint-engine library into SGLang, providing a new mechanism for loading and updating large language model weights. The primary goal is to enhance efficiency in model loading and weight synchronization across distributed environments, supporting both co-located/disaggregated deployments and Tensor Parallelism. This involves adding a new load format, a dedicated connector, and an API endpoint for dynamic weight updates, streamlining the management of model parameters.

Highlights

  • Checkpoint Engine Integration: Introduced support for loading model weights using the checkpoint-engine library, enabling accelerated model loading and weight synchronization, particularly beneficial for co-located/disaggregated deployments and Tensor Parallelism (TP).
  • New Load Format: Added a new ckpt_engine load format option, allowing users to specify this method for loading models.
  • API Endpoint for Weight Updates: Implemented a new FastAPI endpoint /update_weights_from_ckpt_engine to facilitate dynamic, in-place updates of model weights via the checkpoint engine without requiring a server restart.
  • Custom Checkpoint Engine Connector: Developed a dedicated CkptEngineConnector to manage the communication and weight extraction process from the checkpoint engine, including handling ZMQ communication and IPC buffers.
  • Parameter Server Initialization: Integrated the ParameterServer from checkpoint-engine into the model runner, allowing it to register checkpoints, gather metadata, and update weights using either broadcast or P2P methods.
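As a rough illustration of the new endpoint listed above, a client could trigger an in-place weight update along these lines. This is a hedged sketch: only the endpoint path `/update_weights_from_ckpt_engine` comes from the PR; the request body fields, port, and response shape are assumptions for illustration.

```python
import json
import urllib.request

def update_weights_from_ckpt_engine(base_url: str, load_format: str = "ckpt_engine") -> dict:
    """POST to the (assumed) update endpoint on a running SGLang server.

    The payload schema here is illustrative; check the server's actual
    request model before relying on it.
    """
    payload = json.dumps({"load_format": load_format}).encode()
    req = urllib.request.Request(
        f"{base_url}/update_weights_from_ckpt_engine",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

# Example (assumes a server is listening on port 30001):
# update_weights_from_ckpt_engine("http://localhost:30001")
```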

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for loading model weights from ckpt-engine to accelerate model loading and weight synchronization. The changes are extensive, adding new components for checkpoint engine interaction and modifying existing parts of the model loading and execution pipeline, including configurations, connectors, schedulers, and model runners.

My review focuses on ensuring the correctness, robustness, and maintainability of these new features. I have identified several critical issues that could lead to runtime errors or incorrect behavior, such as improper handling of environment variables, NameError exceptions due to undefined variables in error-handling paths, and potential data loss in the weight iteration logic. Additionally, I've provided suggestions to improve code clarity and security by addressing hardcoded values, removing dead code, and recommending safer serialization alternatives to pickle.

Comment on lines +190 to +191

    for key, tensor in self.final_state_dict.items():
        yield key, tensor
Contributor

critical

After the main while loop in weight_iterator finishes, any weights remaining in self.pending_weights (e.g., a gate_proj weight without a corresponding up_proj weight in the processed payloads) are not yielded. This will result in missing weights and likely cause model loading to fail. You should process any remaining items in self.pending_weights after the loop.

Suggested change

    # before
    for key, tensor in self.final_state_dict.items():
        yield key, tensor

    # after
    for key, tensor in self.final_state_dict.items():
        yield key, tensor
    for key, tensor in self.pending_weights.items():
        yield key, tensor
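The suggested fix amounts to flushing the pending map once the stream ends. A minimal, self-contained sketch of the pairing-plus-flush logic (plain Python values stand in for tensors; the names and the merged-key convention are illustrative, not the connector's actual code):

```python
from typing import Dict, Iterable, Iterator, Tuple

def iterate_weights(stream: Iterable[Tuple[str, object]]) -> Iterator[Tuple[str, object]]:
    """Pair gate_proj/up_proj weights as they arrive; yield everything else."""
    pending: Dict[str, object] = {}
    for name, tensor in stream:
        if "gate_proj" in name:
            partner = name.replace("gate_proj", "up_proj")
        elif "up_proj" in name:
            partner = name.replace("up_proj", "gate_proj")
        else:
            yield name, tensor
            continue
        if partner in pending:
            # Both halves seen: emit a merged entry under a combined key.
            merged = name.replace("up_proj", "gate_up_proj").replace("gate_proj", "gate_up_proj")
            yield merged, (pending.pop(partner), tensor)
        else:
            pending[name] = tensor
    # Flush leftovers so an unpaired weight is not silently dropped --
    # this is the step the review comment says is missing.
    yield from pending.items()
```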

        return iter

    def model_load_weights(model, iter):
        DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
Contributor

critical

The variable target_device is not defined within the scope of update_weights_from_ckpt_engine. This will cause a NameError when model_load_weights is called. It should probably be device_config.device, which is available in the outer scope.

Suggested change

    # before
    DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
    # after
    DefaultModelLoader.load_weights_and_postprocess(model, iter, device_config.device)

    message = (
        f"Failed to update weights: {e}.\nRolling back to original weights."
    )
    del iter
Contributor

critical

The variable iter is not defined in this scope before being deleted, which will cause a NameError. This line should be removed.

Comment on lines +229 to +230

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
Contributor

high

The script will crash with a TypeError if the RANK or WORLD_SIZE environment variables are not set, because os.getenv will return None and int(None) is invalid. A similar issue exists on line 151. To make the script more robust, you should handle the case where these environment variables might not be set, for example by using os.environ which raises a KeyError if the variable is not found, providing a more explicit error.

Suggested change

    # before
    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    # after
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
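A small sketch of the failure mode being described, using only the standard library:

```python
import os

def read_rank_getenv() -> int:
    # os.getenv returns None when RANK is unset, so int(None) raises a
    # confusing TypeError far from the real cause.
    return int(os.getenv("RANK"))

def read_rank_environ() -> int:
    # os.environ[...] raises KeyError("RANK"), naming the missing
    # variable explicitly.
    return int(os.environ["RANK"])

os.environ.pop("RANK", None)
try:
    read_rank_getenv()
except TypeError as e:
    print(f"getenv path: TypeError ({e})")
try:
    read_rank_environ()
except KeyError as e:
    print(f"environ path: KeyError ({e})")
```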

Comment on lines +73 to +74

    def get_zmq_handle(self, tp_rank: int):
        # FIXME: There needs a local rank
Contributor

high

The FIXME comment indicates that tp_rank might not be the correct rank to use for getting the physical GPU ID, especially in a multi-node environment. tp_rank is a global rank, but _get_physical_gpu_id seems to expect a local rank on the node. Using the wrong rank could lead to incorrect GPU selection and failures. This should be resolved to ensure correctness in distributed setups. A similar issue is present on line 98.
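One common convention for recovering a node-local rank (an assumption for illustration, not necessarily how this PR should resolve the FIXME) is to prefer the launcher-provided LOCAL_RANK and otherwise fold the global rank by the per-node device count:

```python
import os

def node_local_rank(global_rank: int, gpus_per_node: int) -> int:
    """Best-effort node-local rank under contiguous rank-to-node placement."""
    env = os.environ.get("LOCAL_RANK")  # torchrun exports this directly
    if env is not None:
        return int(env)
    # Fallback assumes ranks are laid out node by node.
    return global_rank % gpus_per_node

# e.g. on a 2-node, 8-GPU-per-node job, global rank 11 is local rank 3
```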

@@ -339,6 +341,12 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
)
return success, message

def update_weights_from_ckpt_engine(self, recv_req: UpdateWeightFromCkptEngineReqInput):
Contributor

medium

There is a typo in the type hint for recv_req. It should be UpdateWeightsFromCkptEngineReqInput to match the imported class name.

Suggested change

    # before
    def update_weights_from_ckpt_engine(self, recv_req: UpdateWeightFromCkptEngineReqInput):
    # after
    def update_weights_from_ckpt_engine(self, recv_req: UpdateWeightsFromCkptEngineReqInput):

Comment on lines +300 to +301

    with open(self.server_args.ckpt_save_meta_file_name, "wb") as f:
        pickle.dump(self.ps.get_metas(), f)
Contributor

medium

Using pickle for serialization can introduce security vulnerabilities, as unpickling data from an untrusted source can lead to arbitrary code execution. While you are writing the file here, it will be read elsewhere. As noted in the TODO on line 299, using a safer serialization format like JSON is recommended to avoid potential security risks.
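A sketch of the safer direction the comment suggests, with an illustrative meta structure (the real shape of ps.get_metas() is an assumption here): restrict the on-disk format to plain JSON types, so reading the file back cannot execute arbitrary code the way pickle.load on untrusted input can.

```python
import json
from dataclasses import asdict, dataclass
from typing import List

@dataclass
class TensorMeta:
    name: str
    dtype: str            # store the dtype as a string, e.g. "bfloat16"
    shape: List[int]

def save_metas(metas: List[TensorMeta], path: str) -> None:
    # Only str/int lists reach the file, so the format stays JSON-safe.
    with open(path, "w") as f:
        json.dump([asdict(m) for m in metas], f)

def load_metas(path: str) -> List[TensorMeta]:
    with open(path) as f:
        return [TensorMeta(**d) for d in json.load(f)]
```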

    logger.info(f"{msg} duration: {end - start:.2f} seconds")


    def check_vllm_ready(endpoint: str, inference_parallel_size: int):
Contributor

medium

The function check_vllm_ready uses a global variable rank which is defined at the bottom of the script (line 146). This is not a good practice as it makes the code harder to understand and maintain. The rank should be passed as an argument to the function.

Suggested change

    # before
    def check_vllm_ready(endpoint: str, inference_parallel_size: int):
    # after
    def check_vllm_ready(endpoint: str, inference_parallel_size: int, rank: int):

    def load_model_from_ckpt_engine(
        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        socket = client.get_socket_handle(device_config.gpu_id)
Contributor

medium

The socket variable is assigned but never used. This appears to be dead code and should be removed to improve clarity.

Comment on lines +1554 to +1556

    # FIXME: use more elegant method
    if key == "model.embed_tokens.weight":
        key = "lm_head.weight"
Contributor

medium

Hardcoding the remapping of model.embed_tokens.weight to lm_head.weight is brittle and may not work for all models. As the FIXME suggests, a more elegant and configurable method for handling weight name discrepancies should be implemented. This could involve a mapping file or a more general remapping logic.
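One table-driven alternative looks like this (a sketch: the single entry below is the case from the diff, and whether it applies at all depends on the model's tied-embedding configuration):

```python
from typing import Dict

# Per-model (or per-config) rename table instead of a hardcoded branch.
# The entry below mirrors the tied-embedding case the FIXME works around.
WEIGHT_NAME_REMAP: Dict[str, str] = {
    "model.embed_tokens.weight": "lm_head.weight",
}

def remap_weight_name(key: str, remap: Dict[str, str] = WEIGHT_NAME_REMAP) -> str:
    """Return the renamed key if a mapping exists, else the key unchanged."""
    return remap.get(key, key)
```

The table can then be populated per model family rather than edited in the loading loop.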

        return weights

    # Implemented as a no-op to make BaseConnector interface consistent.
    def weight_iterator(

Could we reuse the method checkpoint_engine.worker.update_weights_from_ipc to simplify the code, like below?

    def _process_gate_up_proj(self, named_tensors: list[Tuple[str, torch.Tensor]]):
        for name, tensor in named_tensors:
            if "mlp.gate_proj.weight" in name:
                up_key = name.replace("gate_proj", "up_proj")
                if up_key in self.pending_weights:
                    up_tensor = self.pending_weights.pop(up_key)
                    self._merge_and_store(name, tensor, up_key, up_tensor)
                else:
                    self.pending_weights[name] = tensor

            elif "mlp.up_proj.weight" in name:
                gate_key = name.replace("up_proj", "gate_proj")
                if gate_key in self.pending_weights:
                    gate_tensor = self.pending_weights.pop(gate_key)
                    self._merge_and_store(gate_key, gate_tensor, name, tensor)
                else:
                    self.pending_weights[name] = tensor

            else:
                yield name, tensor
        for key, tensor in self.final_state_dict.items():
            yield key, tensor


    def weight_iterator(self, rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]:
        from checkpoint_engine.worker import update_weights_from_ipc
        if self.socket is None:
            self.get_socket_handle(rank)

        update_weights_from_ipc(
            self.zmq_ctx,
            self.zmq_handle,
            rank,
            run=self._process_gate_up_proj,
        )

Collaborator

That's an excellent suggestion. I will address this in the next commit.

        up_tensor = self.pending_weights.pop(up_key)
        self._merge_and_store(item["name"], tensor, up_key, up_tensor)
    else:
        self.pending_weights[item["name"]] = tensor

Weights are received using a fixed bucket size, so a full set of weights may be split across multiple receive turns. Tensor data backed by the buffer may be changed in the next update turn. A workaround is tensor.clone(), but that occupies more GPU memory. Is it necessary for this tensor to be saved in pending_weights?
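A pure-Python sketch of the hazard being described (a memoryview stands in for a tensor view into the fixed receive bucket; bytes(view) plays the role of tensor.clone()):

```python
# A "weight" delivered as a view into a fixed, reused receive buffer is
# silently overwritten by the next receive turn unless it is copied.
buffer = bytearray(b"\x01\x02\x03\x04")   # fixed-size receive bucket
view = memoryview(buffer)[:2]             # weight delivered as a view
saved_alias = view                        # storing the view aliases the bucket
saved_copy = bytes(view)                  # "clone": pays memory for safety

buffer[:] = b"\x09\x09\x09\x09"           # next turn reuses the bucket

print(bytes(saved_alias))  # b'\t\t' -- the pending weight was corrupted
print(saved_copy)          # b'\x01\x02' -- the copy is intact
```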

Collaborator

In practice, this pending buffer is unused. The associated code will be removed in a subsequent commit.

        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        if self.socket is None:
            self.get_socket_handle(rank)

self.zmq_handle will change on each ps.update call in the upcoming checkpoint-engine==0.1.2 release (see https://github.com/MoonshotAI/checkpoint-engine/blob/03ff7e7268d614b5c5d3af7388e541fc181bd892/checkpoint_engine/ps.py#L812-L818), since self._zmq_addr_counter += 1 is triggered by every ps.update. So self.zmq_handle should be refreshed via self.get_zmq_handle() on each request.

Collaborator

Thanks, that's a helpful reminder. We are refactoring the static port to be dynamically negotiated, and we'll be sure to include this point.

@@ -324,6 +325,13 @@ def __init__(
        self.enable_overlap = False
        logger.info("Overlap scheduler is disabled for embedding models.")

    # TODO: May change it to somewhere
    os.environ["RANK"] = str(self.tp_rank)

Maybe I'll add rank and world_size args in ParameterServer, see https://github.com/MoonshotAI/checkpoint-engine/pull/20/files

Collaborator

Got it. We're planning to make this a startup argument for SGLang and are working on it.

@XucSh
Collaborator

XucSh commented Sep 20, 2025

Tested on Qwen3-0.6B (TP1 & TP2) and Qwen3-8B (TP4) models. New inference instances function correctly after the weight update.

stmatengss and others added 16 commits September 20, 2025 12:46
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: CruzZhao <CruzZhao@linux.alibaba.com>
Signed-off-by: CruzZhao <CruzZhao@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Collaborator

This is not a good place to put this file

Collaborator Author

Is it ok to move it to /scripts?


This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this python file and exec it? Just like:

    wget https://raw.githubusercontent.com/MoonshotAI/checkpoint-engine/refs/heads/main/examples/update.py
    python3 update.py --help

Contributor

This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this python file and exec it? Just like:

    wget https://raw.githubusercontent.com/MoonshotAI/checkpoint-engine/refs/heads/main/examples/update.py
    python3 update.py --help

In fact, we need to modify the logic of this script, because executing it directly allocates most of the GPU memory to the communication buffer, causing SGLang to run out of memory and fail to start. Therefore, the original script logic cannot be used directly.

Collaborator Author

Got it. will fix

Collaborator Author

This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this python file and exec it? Just like:

    wget https://raw.githubusercontent.com/MoonshotAI/checkpoint-engine/refs/heads/main/examples/update.py
    python3 update.py --help

We can temporarily maintain this file within sglang, then merge it into the main checkpoint engine repository for easier maintenance.

stmatengss and others added 2 commits September 29, 2025 18:19
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
@stmatengss stmatengss removed the run-ci label Sep 29, 2025
stmatengss and others added 5 commits October 7, 2025 13:10
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
@stmatengss stmatengss requested a review from whybeyoung October 9, 2025 09:36
XucSh and others added 5 commits October 9, 2025 20:20
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
@stmatengss
Collaborator Author

Will reimplement the checkpoint engine connector after merging PR #11755.

Collaborator

@ByronHsu ByronHsu left a comment

High level questions:

  1. Checkpoint engine needs to load from disk to cpu, and use the pipeline to send to inference engines. disk-to-cpu is still on the critical path. What is the time we save here?
  2. Can you provide examples of how to use in online serving cases? For example, I add 10 new engines, how can i use checkpoint engine to make them load faster.
  3. Can you provide examples of how to use in RL case? How to do efficient broadcast to all inference engines.

Happy to chat online. You can find me at ByronHsu in sglang slack.

@XucSh
Collaborator

XucSh commented Oct 20, 2025

High level questions:

  1. Checkpoint engine needs to load from disk to cpu, and use the pipeline to send to inference engines. disk-to-cpu is still on the critical path. What is the time we save here?
  2. Can you provide examples of how to use in online serving cases? For example, I add 10 new engines, how can i use checkpoint engine to make them load faster.
  3. Can you provide examples of how to use in RL case? How to do efficient broadcast to all inference engines.

Happy to chat online. You can find me at ByronHsu in sglang slack.

The checkpoint service is a persistent process that holds weights in memory (GPU/CPU). Each new instance (specifically, each TP rank) has its own ParameterServer (checkpoint engine) from which it fetches weights at startup.

@XucSh XucSh closed this Dec 18, 2025
