Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
635fac7
feat: Add checkpoint-engine IPC integration
BraveY Sep 19, 2025
e4bd670
style: apply pre-commit formatting changes
BraveY Sep 22, 2025
ffe6bd5
Merge remote-tracking branch 'origin' into kaiyong/checkpoint-engine
stmatengss Oct 13, 2025
43f0d92
refactor(checkpoint-engine): import code from checkpoint-engine inste…
BraveY Oct 14, 2025
41e7736
Merge pull request #1 from BraveY/checkpoint-engine-dev
stmatengss Oct 14, 2025
4b027d7
feat: add weight_version for update_weights_from_ipc
zxpdemonio Oct 15, 2025
20ec0cf
fix(checkpoint-engine): defer health check success until initial weig…
BraveY Oct 15, 2025
bfc2f69
add weight_version
zxpdemonio Oct 16, 2025
35cc659
precommit
stmatengss Oct 17, 2025
42e1ad6
Merge branch 'main' into kaiyong/checkpoint-engine
stmatengss Oct 17, 2025
46275ae
Fix lint
XucSh Oct 17, 2025
036d525
feat: Add /ping endpoint for dummy server's liveness probe
BraveY Oct 17, 2025
506567e
Merge pull request #3 from BraveY/checkpoint-engine-dev
stmatengss Oct 19, 2025
e03496f
Merge branch 'main' into kaiyong/checkpoint-engine
stmatengss Oct 20, 2025
0025436
rename args
stmatengss Oct 21, 2025
ae4b128
move worker to ckpt engine dir
stmatengss Oct 21, 2025
9c25784
delete useless variable
stmatengss Oct 21, 2025
3edd5fb
fix
stmatengss Oct 21, 2025
cca33c4
fix(checkpoint-engine): fix attribute error and lazy import checkpoin…
BraveY Oct 22, 2025
7419c41
fix issue
stmatengss Oct 22, 2025
cf2d56d
fix issue
stmatengss Oct 22, 2025
d5db877
pre-commit
stmatengss Oct 22, 2025
98bcd8f
update
XucSh Oct 23, 2025
5f4107e
Update examples/checkpoint_engine/update.py
stmatengss Oct 23, 2025
f4c9e14
fix
stmatengss Oct 23, 2025
077f92b
Update examples/checkpoint_engine/update.py
stmatengss Oct 23, 2025
3afebdf
Update examples/checkpoint_engine/update.py
stmatengss Oct 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,5 @@ SGLang supports various environment variables that can be used to configure its

| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_WAIT_WEIGHTS_READY_TIMEOUT` | Timeout period for waiting on weights | `120` |
| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `true` |
241 changes: 241 additions & 0 deletions examples/checkpoint_engine/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""
Usage:
1) Launch the server with wait-for-initial-weights option in one terminal:
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7

2) Torchrun this script in another terminal:
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the future, have something like python -m sglang.launch_checkpoint_engine inside sglang python package, so whoever installs sglang can use it

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! we can move this file to python/sglang

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix it

"""

import argparse
import json
import os
import pickle
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from typing import Literal

import httpx
import torch
import torch.distributed as dist
from checkpoint_engine.ps import ParameterServer
from loguru import logger
from safetensors import safe_open


@contextmanager
def timer(msg: str):
    """Context manager that logs the wall-clock duration of its body.

    Args:
        msg: Prefix for the emitted log line.
    """
    begin = time.perf_counter()
    yield
    elapsed = time.perf_counter() - begin
    logger.info(f"{msg} duration: {elapsed:.2f} seconds")


def check_sglang_ready(
endpoint: str, inference_parallel_size: int, uds: str | None = None
):
if rank != rank // inference_parallel_size * inference_parallel_size:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function check_sglang_ready uses a global variable rank which is defined later in the script. While this works in this script, it's not a robust practice as it makes the function's behavior dependent on a non-local state that is not explicitly passed. This can lead to confusion and bugs if the code is refactored. It would be better to pass rank as an argument to check_sglang_ready and other functions that need it, like update_weights and join.

For example, you could change the function signature to:

def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, rank: int, uds: str | None = None
):

And then update the call sites in update_weights and join to pass rank.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is rank this parameter initialized?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is populated by torchrun, which this script currently depends on. This dependency will be removed later.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checkpoint engine project will add a general update.py script for both vllm and sglang. This script is a temporary solution.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: fix it when checkpoint engine releases a new version.

return
retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
with httpx.Client(transport=transport) as client:
while True:
try:
response = client.get(f"{endpoint}/ping", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
if retry_num % 10 == 0:
logger.warning(
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
)
retry_num += 1
time.sleep(0.1)


def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Partition the ``.safetensors`` files under *checkpoint_path* across ranks.

    The file list is sorted so every rank computes the same global ordering:
    ``os.listdir`` order is filesystem-dependent and could otherwise differ
    between processes, causing ranks to load overlapping or missing files.

    Args:
        checkpoint_path: Directory containing the checkpoint shards.
        rank: This process's global rank.
        world_size: Total number of ranks.

    Returns:
        The contiguous slice of checkpoint file paths assigned to *rank*.
    """
    checkpoint_files = sorted(
        os.path.join(checkpoint_path, f)
        for f in os.listdir(checkpoint_path)
        if f.endswith(".safetensors")
    )
    # Ceiling division so every file is covered even when the count is not
    # evenly divisible by world_size (trailing ranks may get fewer files).
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]


def split_tensors(
    checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    """Load this rank's share of tensors via the safetensors index file.

    The ``weight_map`` from ``model.safetensors.index.json`` is split into
    contiguous near-equal slices (ceiling division) and only the slice owned
    by *rank* is read from disk.

    Returns:
        Mapping from tensor name to tensor for the entries assigned to *rank*.
    """
    index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
    with open(index_fn) as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    per_rank = (len(weight_map) + world_size - 1) // world_size
    assigned = list(weight_map.items())[rank * per_rank : (rank + 1) * per_rank]

    # Group the assigned tensor names by the shard file that stores them so
    # each file is opened exactly once.
    file_to_names: dict[str, list[str]] = defaultdict(list)
    for tensor_name, file_name in assigned:
        file_to_names[file_name].append(tensor_name)

    loaded: dict[str, torch.Tensor] = {}
    for file_name, tensor_names in file_to_names.items():
        with safe_open(os.path.join(checkpoint_path, file_name), framework="pt") as fp:
            for tensor_name in tensor_names:
                loaded[tensor_name] = fp.get_tensor(tensor_name)
    return loaded


def req_inference(
endpoint: str,
inference_parallel_size: int,
timeout: float = 300.0,
uds: str | None = None,
weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
rank = int(os.getenv("RANK", 0))
src = rank // inference_parallel_size * inference_parallel_size

def req_func(socket_paths: list[tuple[str, str]]):
if rank == src:
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
resp = client.post(
f"{endpoint}/update_weights_from_ipc",
json={
"zmq_handles": dict(
socket_paths[src : src + inference_parallel_size]
),
"flush_cache": True,
"weight_version": weight_version,
},
timeout=timeout,
)
resp.raise_for_status()

return req_func


def update_weights(
    ps: ParameterServer,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    """Register a checkpoint with the parameter server and push it to SGLang.

    Args:
        ps: The checkpoint-engine parameter server for this rank.
        checkpoint_name: Name to register the checkpoint under.
        checkpoint_files: Safetensors files owned by this rank (file-granular split).
        named_tensors: Tensors owned by this rank (tensor-granular split).
        req_func: Callback that triggers ``/update_weights_from_ipc`` on the server.
        inference_parallel_size: Number of inference ranks per group.
        endpoint: SGLang server base URL (used for the readiness probe).
        save_metas_file: If set, rank 0 pickles the gathered metas here so a
            later ``join`` run can reuse them.
        update_method: "broadcast", "p2p", or "all" (run both).
        uds: Optional unix-domain-socket path for the readiness probe.
    """
    ps.register_checkpoint(
        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
    )
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)
    # Only rank 0 persists the metas; joiners load them via `join`.
    if save_metas_file and int(os.getenv("RANK")) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)

    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)

    if update_method in ("p2p", "all"):
        # NOTE: a redundant inner `if update_method:` guard was removed here —
        # inside this branch update_method is always a non-empty string.
        # sleep 2s to wait destroy process group
        time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(
                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
            )


def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    """Late-join: load previously saved checkpoint metas and push weights
    to the inference ranks via p2p.

    Args:
        ps: The checkpoint-engine parameter server for this rank.
        checkpoint_name: Name the checkpoint was registered under.
        load_metas_file: Pickle file produced by ``update_weights(save_metas_file=...)``.
        req_func: Callback that triggers ``/update_weights_from_ipc`` on the server.
        inference_parallel_size: Number of inference ranks to target.
        endpoint: SGLang server base URL (used for the readiness probe).
        uds: Optional unix-domain-socket path for the readiness probe.

    Raises:
        ValueError: If *load_metas_file* is empty/None. (Previously an
            ``assert``, which would be silently stripped under ``python -O``.)
    """
    if not load_metas_file:
        raise ValueError("load_metas_file is required")
    with open(load_metas_file, "rb") as f:
        metas = pickle.load(f)
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(metas)
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--checkpoint-path", type=str, default=None)
    parser.add_argument("--save-metas-file", type=str, default=None)
    parser.add_argument("--load-metas-file", type=str, default=None)
    parser.add_argument("--sleep-time", type=int, default=0)
    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
    parser.add_argument("--inference-parallel-size", type=int, default=8)
    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
    parser.add_argument("--update-method", type=str, default="broadcast")
    parser.add_argument("--uds", type=str, default=None)
    parser.add_argument("--weight-version", type=str, default=None)
    args = parser.parse_args()

    # RANK / WORLD_SIZE are injected by torchrun, which this script currently
    # requires. Fail fast with an actionable message instead of the opaque
    # TypeError that int(None) would raise when launched without torchrun.
    rank_env = os.getenv("RANK")
    world_size_env = os.getenv("WORLD_SIZE")
    if rank_env is None or world_size_env is None:
        raise RuntimeError(
            "RANK and WORLD_SIZE environment variables must be set; "
            "launch this script with torchrun."
        )
    rank = int(rank_env)
    world_size = int(world_size_env)

    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )
    ps = ParameterServer(auto_pg=True)
    # HACK: checkpoint-engine exposes no public way to disable the p2p store,
    # so we reach into a private attribute. Revisit once upstream provides a
    # public API for this.
    ps._p2p_store = None
    if args.load_metas_file:
        # Late join: reuse metas saved by a previous update run.
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        index_path = os.path.join(
            args.checkpoint_path, "model.safetensors.index.json"
        )
        if os.path.exists(index_path):
            # Indexed checkpoint: split work at tensor granularity.
            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            checkpoint_files = []
        else:
            # No index file: split work at file granularity.
            checkpoint_files = split_checkpoint_files(
                args.checkpoint_path, rank, world_size
            )
            named_tensors = {}
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ test = [
"sentence_transformers",
"tabulate",
]
checkpoint-engine = ["checkpoint-engine==0.1.2"]
all = []
dev = ["sglang[test]"]

Expand Down
142 changes: 142 additions & 0 deletions python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Checkpoint-engine integration for SGLang.
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
"""
import logging
from typing import Callable, Dict, Optional

import torch
import zmq

try:
    from checkpoint_engine.worker import update_weights_from_ipc
except ImportError as e:
    # Chain the original exception so the real cause (e.g. a broken
    # transitive dependency, not just a missing package) stays visible.
    raise ImportError(
        "checkpoint-engine is not installed. "
        "Please install it with: pip install sglang[checkpoint-engine]"
    ) from e

logger = logging.getLogger(__name__)


class SGLangCheckpointEngineWorkerExtension:
    """
    Worker extension for SGLang to support checkpoint-engine IPC weight updates.
    This class provides the interface needed for checkpoint-engine integration.
    """

    def __init__(self):
        # ZMQ context is created lazily on the first update call and then
        # reused for subsequent updates.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Return the UUID of the current device (overridden by integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Return the device ID (overridden by integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Return the model weight loader function (overridden by integration)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Return an optional post-processing hook run after weight loading."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """Receive updated weights from checkpoint-engine over IPC.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path.

        Raises:
            ValueError: If this device's UUID has no entry in *zmq_handles*.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        device_uuid = self.get_device_uuid()
        device_id = self.get_device_id()
        if device_uuid not in zmq_handles:
            raise ValueError(
                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            )
        # Delegate the actual tensor transfer to checkpoint-engine's worker.
        update_weights_from_ipc(
            self._zmq_ctx,
            zmq_handles[device_uuid],
            device_id=device_id,
            run=self.get_model_loader(),
            post_hook=self.get_post_hook(),
        )


class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
    """Concrete checkpoint-engine extension bound to an SGLang model runner.

    Provides real device identification and weight-loading callbacks on top
    of the abstract :class:`SGLangCheckpointEngineWorkerExtension` interface.
    """

    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner

    def get_device_uuid(self) -> str:
        """Return the ``GPU-<uuid>`` identifier of the current CUDA device."""
        device_id = torch.cuda.current_device()
        try:
            props = torch.cuda.get_device_properties(device_id)
            return f"GPU-{props.uuid!s}"
        except AssertionError as e:
            raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e

    def get_device_id(self) -> int:
        """Return the index of the current CUDA device."""
        return torch.cuda.current_device()

    def get_model_loader(self) -> Callable:
        """Return the model runner's weight-loading entry point."""
        return self.model_runner.model.load_weights

    def get_post_hook(self) -> Optional[Callable]:
        """Return a hook that finalizes quantized weights after loading."""

        def post_hook():
            # Mirror DefaultModelLoader's post-load processing: let each
            # quantization method finalize its weights, then give the model
            # a chance to run its own post-load step.
            try:
                from sglang.srt.model_loader.loader import device_loading_context

                for _, module in self.model_runner.model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is None:
                        continue
                    # Some quant methods need their params on-device while
                    # they process weights.
                    target_device = torch.device("cuda", torch.cuda.current_device())
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if hasattr(self.model_runner.model, "post_load_weights"):
                    self.model_runner.model.post_load_weights()
            except Exception as e:
                # Best-effort: a failed post-hook is logged, not fatal.
                logger.warning(f"Post-hook processing failed: {e}")

        return post_hook
Loading
Loading