51 changes: 51 additions & 0 deletions examples/runtime/engine/save_remote_state.py
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a remote persistent store,
which enables a fast load path for large tensor-parallel models where each
worker only needs to read its own shard rather than the entire checkpoint.

Example usage:

python save_remote_state.py \
--model-path /path/to/load \
--tensor-parallel-size 8 \
    --remote-model-save-url [protocol]://[host]:[port]/[model_name]

Then, the model can be loaded with

llm = Engine(
    model_path="[protocol]://[host]:[port]/[model_name]",
    load_format="remote",
    tensor_parallel_size=8,
)
"""
import dataclasses
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)

parser.add_argument(
"--remote-model-save-url",
required=True,
type=str,
help="remote address to store model weights",
)


def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
llm.save_remote_model(url=args.remote_model_save_url)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
74 changes: 74 additions & 0 deletions examples/runtime/engine/save_sharded_state.py
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
--model-path /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = Engine(
model_path="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
ServerArgs.add_cli_args(parser)

parser.add_argument(
"--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
"--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
"--max-file-size",
    type=int,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file",
)


def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
Path(args.output).mkdir(exist_ok=True)
llm.save_sharded_model(
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
)

# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(
os.path.join(model_path, file), os.path.join(args.output, file)
)
else:
shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
2 changes: 2 additions & 0 deletions python/sglang/__init__.py
@@ -32,6 +32,7 @@
)
from sglang.utils import LazyImport

ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
@@ -67,6 +68,7 @@
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
"ServerArgs",
"Anthropic",
"LiteLLM",
"OpenAI",
1 change: 1 addition & 0 deletions python/sglang/srt/configs/load_config.py
@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
MISTRAL = "mistral"
LAYERED = "layered"
JAX = "jax"
REMOTE = "remote"


@dataclass
26 changes: 25 additions & 1 deletion python/sglang/srt/configs/model_config.py
@@ -51,13 +51,14 @@ def __init__(
self.quantization = quantization

# Parse args
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()

self.hf_config = get_config(
model_path,
self.model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
@@ -318,6 +319,29 @@ def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
return eos_ids

def maybe_pull_model_tokenizer_from_remote(self) -> None:
"""
Pull the model config files to a temporary
directory in case of remote.

Args:
model: The model name or path.

"""
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # The config files must stay available for the lifetime of the
            # engine, so we deliberately do NOT use a `with` statement, which
            # would close the client and delete the directory.
            client = create_remote_connector(self.model_path)
            client.pull_files(allow_pattern=["*config.json"])
            self.model_weights = self.model_path
            self.model_path = client.get_local_dir()


def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
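For orientation, a minimal sketch of the intended pull-through behavior (the
redis URL and model name are hypothetical, the remaining ModelConfig arguments
are elided, and the path is assumed to be passed as model_path):

    config = ModelConfig(model_path="redis://10.0.0.1:6379/llama-8b")

    # After __init__ has run maybe_pull_model_tokenizer_from_remote():
    #   config.model_weights -> "redis://10.0.0.1:6379/llama-8b" (original URL)
    #   config.model_path    -> client.get_local_dir(), a local temp dir that
    #                           now holds the pulled *config.json files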
51 changes: 51 additions & 0 deletions python/sglang/srt/connector/__init__.py
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import logging

from sglang.srt.connector.base_connector import (
BaseConnector,
BaseFileConnector,
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type

logger = logging.getLogger(__name__)


class ConnectorType(str, enum.Enum):
FS = "filesystem"
KV = "KV"


def create_remote_connector(url, device="cpu") -> BaseConnector:
connector_type = parse_connector_type(url)
if connector_type == "redis":
return RedisConnector(url)
elif connector_type == "s3":
return S3Connector(url)
else:
raise ValueError(f"Invalid connector type: {url}")


def get_connector_type(client: BaseConnector) -> ConnectorType:
if isinstance(client, BaseKVConnector):
return ConnectorType.KV
if isinstance(client, BaseFileConnector):
return ConnectorType.FS

raise ValueError(f"Invalid connector type: {client}")


__all__ = [
"BaseConnector",
"BaseFileConnector",
"BaseKVConnector",
"RedisConnector",
"S3Connector",
"ConnectorType",
"create_remote_connector",
"get_connector_type",
]
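For context, a minimal usage sketch of this factory (the S3 bucket name is
hypothetical; pull_files and weight_iterator are defined on BaseConnector in
the next file):

    from sglang.srt.connector import create_remote_connector

    # create_remote_connector dispatches on the URL scheme:
    # "s3://..." -> S3Connector, "redis://..." -> RedisConnector.
    client = create_remote_connector("s3://my-bucket/llama-8b")

    # Pull the config files into the connector's local temp directory ...
    client.pull_files(allow_pattern=["*config.json"])

    # ... then stream this rank's weights tensor by tensor.
    for name, tensor in client.weight_iterator(rank=0):
        print(name, tuple(tensor.shape))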
112 changes: 112 additions & 0 deletions python/sglang/srt/connector/base_connector.py
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple

import torch


class BaseConnector(ABC):
"""
For fs connector such as s3:
<connector_type>://<path>/<filename>

For kv connector such as redis:
<connector_type>://<host>:<port>/<model_name>/keys/<key>
    <connector_type>://<host>:<port>/<model_name>/files/<filename>
"""

def __init__(self, url: str, device: torch.device = "cpu"):
self.url = url
self.device = device
self.closed = False
self.local_dir = tempfile.mkdtemp()
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self._close_by_signal(existing_handler))

def get_local_dir(self):
return self.local_dir

@abstractmethod
def weight_iterator(
self, rank: int = 0
) -> Generator[Tuple[str, torch.Tensor], None, None]:
raise NotImplementedError()

@abstractmethod
def pull_files(
self,
allow_pattern: Optional[List[str]] = None,
ignore_pattern: Optional[List[str]] = None,
) -> None:
raise NotImplementedError()

def close(self):
if self.closed:
return

self.closed = True
if os.path.exists(self.local_dir):
shutil.rmtree(self.local_dir)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def __del__(self):
self.close()

def _close_by_signal(self, existing_handler=None):

def new_handler(signum, frame):
self.close()
if existing_handler:
existing_handler(signum, frame)

return new_handler


class BaseKVConnector(BaseConnector):

@abstractmethod
def get(self, key: str) -> Optional[torch.Tensor]:
raise NotImplementedError()

@abstractmethod
def getstr(self, key: str) -> Optional[str]:
raise NotImplementedError()

@abstractmethod
def set(self, key: str, obj: torch.Tensor) -> None:
raise NotImplementedError()

@abstractmethod
def setstr(self, key: str, obj: str) -> None:
raise NotImplementedError()

@abstractmethod
def list(self, prefix: str) -> List[str]:
raise NotImplementedError()


class BaseFileConnector(BaseConnector):

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        """
        List full file names under the remote fs path, filtered by a pattern.

        Args:
            allow_pattern: A glob pattern of which files to list.

        Returns:
            list[str]: Full paths of the files matching the pattern.
        """
        raise NotImplementedError()
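To show how the interface fits together, a hypothetical minimal
local-filesystem subclass (not part of this PR; assumes the safetensors
package and treats self.url as a plain local directory):

    import fnmatch
    import glob as glob_module

    from safetensors.torch import load_file

    class LocalFSConnector(BaseFileConnector):

        def glob(self, allow_pattern: str) -> List[str]:
            # self.url is the source directory in this toy example.
            return glob_module.glob(os.path.join(self.url, allow_pattern))

        def pull_files(
            self,
            allow_pattern: Optional[List[str]] = None,
            ignore_pattern: Optional[List[str]] = None,  # ignored for brevity
        ) -> None:
            # Copy matching files into the temp dir created by BaseConnector.
            for name in os.listdir(self.url):
                if allow_pattern and not any(
                    fnmatch.fnmatch(name, p) for p in allow_pattern
                ):
                    continue
                shutil.copy(os.path.join(self.url, name), self.local_dir)

        def weight_iterator(
            self, rank: int = 0
        ) -> Generator[Tuple[str, torch.Tensor], None, None]:
            # A real connector would select the shard belonging to `rank`;
            # this toy version streams every tensor it finds.
            for path in sorted(self.glob("*.safetensors")):
                for name, tensor in load_file(path).items():
                    yield name, tensor.to(self.device)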