[Feature] Support loading weights from ckpt engine connector #10667
stmatengss wants to merge 46 commits into sgl-project:main
Conversation
Code Review
This pull request introduces support for loading model weights from ckpt-engine to accelerate model loading and weight synchronization. The changes are extensive, adding new components for checkpoint engine interaction and modifying existing parts of the model loading and execution pipeline, including configurations, connectors, schedulers, and model runners.
My review focuses on ensuring the correctness, robustness, and maintainability of these new features. I have identified several critical issues that could lead to runtime errors or incorrect behavior, such as improper handling of environment variables, NameError exceptions due to undefined variables in error-handling paths, and potential data loss in the weight iteration logic. Additionally, I've provided suggestions to improve code clarity and security by addressing hardcoded values, removing dead code, and recommending safer serialization alternatives to pickle.
```python
for key, tensor in self.final_state_dict.items():
    yield key, tensor
```
After the main while loop in weight_iterator finishes, any weights remaining in self.pending_weights (e.g., a gate_proj weight without a corresponding up_proj weight in the processed payloads) are not yielded. This will result in missing weights and likely cause model loading to fail. You should process any remaining items in self.pending_weights after the loop.
Suggested change:

```python
for key, tensor in self.final_state_dict.items():
    yield key, tensor
for key, tensor in self.pending_weights.items():
    yield key, tensor
```
```python
    return iter


def model_load_weights(model, iter):
    DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
```
The variable target_device is not defined within the scope of update_weights_from_ckpt_engine. This will cause a NameError when model_load_weights is called. It should probably be device_config.device, which is available in the outer scope.
Suggested change:

```python
DefaultModelLoader.load_weights_and_postprocess(model, iter, device_config.device)
```
```python
message = (
    f"Failed to update weights: {e}.\nRolling back to original weights."
)
del iter
```
```python
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
```
The script will crash with a TypeError if the RANK or WORLD_SIZE environment variables are not set, because os.getenv will return None and int(None) is invalid. A similar issue exists on line 151. To make the script more robust, you should handle the case where these environment variables might not be set, for example by using os.environ which raises a KeyError if the variable is not found, providing a more explicit error.
Suggested change:

```python
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
```
```python
def get_zmq_handle(self, tp_rank: int):
    # FIXME: There needs a local rank
```
The FIXME comment indicates that tp_rank might not be the correct rank to use for getting the physical GPU ID, especially in a multi-node environment. tp_rank is a global rank, but _get_physical_gpu_id seems to expect a local rank on the node. Using the wrong rank could lead to incorrect GPU selection and failures. This should be resolved to ensure correctness in distributed setups. A similar issue is present on line 98.
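One hedged way to resolve the FIXME, assuming ranks are assigned to nodes contiguously (both the helper name and the `LOCAL_WORLD_SIZE` fallback below are illustrative, not part of this PR):

```python
import os
from typing import Optional


def get_local_rank(tp_rank: int, gpus_per_node: Optional[int] = None) -> int:
    """Hypothetical helper: map a global TP rank to a node-local rank.

    Assumes ranks are assigned contiguously per node, so the local rank
    is simply the global rank modulo the number of GPUs per node.
    """
    if gpus_per_node is None:
        # torchrun exports LOCAL_WORLD_SIZE; fall back to 8 GPUs per node.
        gpus_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", "8"))
    return tp_rank % gpus_per_node
```

With 8-GPU nodes, global rank 10 maps to local rank 2 on the second node, which is what `_get_physical_gpu_id` would need.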
| @@ -339,6 +341,12 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): | |||
| ) | |||
| return success, message | |||
|
|
|||
| def update_weights_from_ckpt_engine(self, recv_req: UpdateWeightFromCkptEngineReqInput): | |||
There was a problem hiding this comment.
There is a typo in the type hint for recv_req. It should be UpdateWeightsFromCkptEngineReqInput to match the imported class name.
Suggested change:

```python
def update_weights_from_ckpt_engine(self, recv_req: UpdateWeightsFromCkptEngineReqInput):
```
```python
with open(self.server_args.ckpt_save_meta_file_name, "wb") as f:
    pickle.dump(self.ps.get_metas(), f)
```
Using pickle for serialization can introduce security vulnerabilities, as unpickling data from an untrusted source can lead to arbitrary code execution. While you are writing the file here, it will be read elsewhere. As noted in the TODO on line 299, using a safer serialization format like JSON is recommended to avoid potential security risks.
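A minimal sketch of the safer alternative, assuming the metadata is (or can be converted to) plain JSON-serializable values — real tensor metadata may need dtypes and shapes converted to strings/lists first, and the function names here are illustrative:

```python
import json


def save_metas_json(metas: dict, path: str) -> None:
    # Unlike pickle, JSON cannot trigger code execution when loaded, so the
    # file stays safe to read even if it was tampered with on disk.
    with open(path, "w") as f:
        json.dump(metas, f)


def load_metas_json(path: str) -> dict:
    with open(path) as f:
        return json.load(f)
```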
```python
logger.info(f"{msg} duration: {end - start:.2f} seconds")


def check_vllm_ready(endpoint: str, inference_parallel_size: int):
```
The function check_vllm_ready uses a global variable rank which is defined at the bottom of the script (line 146). This is not a good practice as it makes the code harder to understand and maintain. The rank should be passed as an argument to the function.
Suggested change:

```python
def check_vllm_ready(endpoint: str, inference_parallel_size: int, rank: int):
```
```python
def load_model_from_ckpt_engine(
    self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
    socket = client.get_socket_handle(device_config.gpu_id)
```
```python
# FIXME: use more elegant method
if key == "model.embed_tokens.weight":
    key = "lm_head.weight"
```
Hardcoding the remapping of model.embed_tokens.weight to lm_head.weight is brittle and may not work for all models. As the FIXME suggests, a more elegant and configurable method for handling weight name discrepancies should be implemented. This could involve a mapping file or a more general remapping logic.
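One possible shape for such a mechanism — a per-model remap table consulted in a single place instead of an inline special case (the table contents and helper name are illustrative, not part of this PR):

```python
# Hypothetical per-model remap table; models with tied embeddings expose
# only embed_tokens in the checkpoint, so it doubles as lm_head.
WEIGHT_NAME_REMAP = {
    "model.embed_tokens.weight": "lm_head.weight",
}


def remap_weight_name(key: str, remap=WEIGHT_NAME_REMAP) -> str:
    # Keys without a remap entry fall through unchanged.
    return remap.get(key, key)
```

The table could be populated per model architecture (or loaded from a config file), which keeps the connector code model-agnostic.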
```python
    return weights


# Implemented as a no-op to make BaseConnector interface consistent.
def weight_iterator(
```
Could we reuse the method `checkpoint_engine.worker.update_weights_from_ipc` to simplify the code, like below?

```python
def _process_gate_up_proj(self, named_tensors: list[Tuple[str, torch.Tensor]]):
    for name, tensor in named_tensors:
        if "mlp.gate_proj.weight" in name:
            up_key = name.replace("gate_proj", "up_proj")
            if up_key in self.pending_weights:
                up_tensor = self.pending_weights.pop(up_key)
                self._merge_and_store(name, tensor, up_key, up_tensor)
            else:
                self.pending_weights[name] = tensor
        elif "mlp.up_proj.weight" in name:
            gate_key = name.replace("up_proj", "gate_proj")
            if gate_key in self.pending_weights:
                gate_tensor = self.pending_weights.pop(gate_key)
                self._merge_and_store(gate_key, gate_tensor, name, tensor)
            else:
                self.pending_weights[name] = tensor
        else:
            yield name, tensor
    for key, tensor in self.final_state_dict.items():
        yield key, tensor


def weight_iterator(self, rank: int = 0) -> Generator[Tuple[str, torch.Tensor], None, None]:
    from checkpoint_engine.worker import update_weights_from_ipc

    if self.socket is None:
        self.get_socket_handle(rank)
    update_weights_from_ipc(
        self.zmq_ctx,
        self.zmq_handle,
        rank,
        run=self._process_gate_up_proj,
    )
```
That's an excellent suggestion. I will address this in the next commit.
```python
    up_tensor = self.pending_weights.pop(up_key)
    self._merge_and_store(item["name"], tensor, up_key, up_tensor)
else:
    self.pending_weights[item["name"]] = tensor
```
Weights are received using a fixed bucket size, so the full set of weights may be split across multiple receive turns. Tensor data read from the buffer may be overwritten in the next update turn. A workaround is tensor.clone(), but that occupies more GPU memory. Is it necessary for this tensor to be saved in pending_weights?
In practice, this pending buffer is unused. The associated code will be removed in a subsequent commit.
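The aliasing hazard the reviewer describes can be illustrated with a small torch sketch (the helper name is illustrative); cloning before stashing is the workaround, at the cost of an extra copy:

```python
import torch


def stash_pending(pending: dict, name: str, tensor: torch.Tensor) -> None:
    # `tensor` is a view into a reusable receive buffer; clone() copies it
    # out so the next bucketed transfer cannot overwrite the stashed data.
    pending[name] = tensor.clone()


buf = torch.zeros(4)      # stands in for the shared receive buffer
pending = {}
stash_pending(pending, "mlp.gate_proj.weight", buf)
buf.fill_(1.0)            # next turn reuses the same buffer
assert pending["mlp.gate_proj.weight"].sum().item() == 0.0
```

Without the clone, `pending["mlp.gate_proj.weight"]` would alias `buf` and silently pick up the next turn's data.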
```python
    self, rank: int = 0
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    if self.socket is None:
        self.get_socket_handle(rank)
```
self.zmq_handle will change on each ps.update call in the upcoming checkpoint-engine==0.1.2 release; see https://github.com/MoonshotAI/checkpoint-engine/blob/03ff7e7268d614b5c5d3af7388e541fc181bd892/checkpoint_engine/ps.py#L812-L818 — self._zmq_addr_counter += 1 is triggered on every ps.update. So self.zmq_handle should be refreshed via self.get_zmq_handle() on each request.
Thanks, that's a helpful reminder. We are refactoring the static port to be dynamically negotiated, and we'll be sure to include this point.
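The stale-handle issue can be sketched with a stand-in for the parameter server whose address rotates on every update (all class and method names here are illustrative, not the checkpoint-engine API):

```python
class FakePS:
    """Stand-in parameter server: the ZMQ address changes after every
    update(), mimicking checkpoint-engine's _zmq_addr_counter behavior."""

    def __init__(self):
        self._counter = 0

    def update(self):
        self._counter += 1

    def zmq_addr(self, rank: int) -> str:
        return f"ipc:///tmp/ckpt_{rank}_{self._counter}"


class Connector:
    def __init__(self, ps):
        self.ps = ps

    def get_zmq_handle(self, rank: int) -> str:
        # Fetch a fresh handle per request instead of caching at startup.
        return self.ps.zmq_addr(rank)


ps = FakePS()
conn = Connector(ps)
first = conn.get_zmq_handle(0)
ps.update()
assert conn.get_zmq_handle(0) != first  # a cached handle would be stale
```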
| @@ -324,6 +325,13 @@ def __init__( | |||
| self.enable_overlap = False | |||
| logger.info("Overlap scheduler is disabled for embedding models.") | |||
|
|
|||
| # TODO: May change it to somewhere | |||
| os.environ["RANK"] = str(self.tp_rank) | |||
There was a problem hiding this comment.
Maybe I'll add rank and world_size args to ParameterServer, see https://github.com/MoonshotAI/checkpoint-engine/pull/20/files
Got it. We're planning to make this a startup argument for SGLang and are working on it.
Tested on Qwen3-0.6B (TP1 & TP2) and Qwen3-8B (TP4) models. New inference instances function correctly after the weight update.
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: CruzZhao <CruzZhao@linux.alibaba.com>
This is not a good place to put this file
Is it ok to move it to /scripts?
This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this Python file and execute it? Like:

```shell
wget https://raw.githubusercontent.com/MoonshotAI/checkpoint-engine/refs/heads/main/examples/update.py
python3 update.py --help
```
> This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this Python file and execute it?

In fact, we need to modify the logic of this script: executing it directly allocates most of the GPU memory to the communication buffer, causing SGLang to run out of memory and fail to start. Therefore, the original script logic cannot be used directly.
Got it, will fix.
> This script seems to be the same as https://github.com/MoonshotAI/checkpoint-engine/blob/main/examples/update.py. Could we write code to fetch this Python file and execute it?

We can temporarily maintain this file within sglang, then merge it into the main checkpoint-engine repository for easier maintenance.
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Will reimplement the checkpoint engine connector after merging PR #11755.
ByronHsu left a comment
High-level questions:
- The checkpoint engine needs to load from disk to CPU and use the pipeline to send weights to inference engines, so disk-to-CPU is still on the critical path. How much time do we save here?
- Can you provide examples of how to use this in online serving? For example, if I add 10 new engines, how can I use the checkpoint engine to make them load faster?
- Can you provide examples of how to use this in the RL case? How do we do an efficient broadcast to all inference engines?

Happy to chat online. You can find me at ByronHsu in the sglang Slack.
The checkpoint service is a persistent process that holds weights in memory (GPU/CPU). Each new instance (specifically, each TP rank) has its own ParameterServer (checkpoint engine) from which it fetches weights at startup.
Motivation
Motivated by #8215, we aim to integrate ckpt-engine into SGLang to accelerate model loading and weight synchronization.
A proposal is in #10464; this PR supports both co-located/disaggregated deployment and TP.
Usage:
- Fake sglang server (only occupying model weights)
- New sglang instance

Running methods:
- sglang
- checkpoint engine
Co-author: @XucSh @zxpdemonio @BraveY
Thanks to @weixiao-huang for help.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist