Merged
Changes from 2 commits
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -93,6 +93,7 @@
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
@@ -761,6 +762,9 @@ def update_weights_from_tensor(
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader.load(self.model, named_tensors)
Contributor (severity: high):
The dynamic_import function returns the imported function itself. Call the custom loader function directly.

Suggested change:
-            custom_loader = dynamic_import(load_format)
-            custom_loader.load(self.model, named_tensors)
+            custom_loader_func = dynamic_import(load_format)
+            custom_loader_func(self.model, named_tensors)

        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
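For orientation, a minimal sketch of how this new branch would be exercised through the engine API, assuming the reviewer's fix above is applied. The model path, tensor shape, and the my_pkg.loaders.load_func import path are hypothetical placeholders:

import torch
import sglang as sgl

# Hypothetical sketch: the path passed as load_format must also be registered
# via custom_weight_loader so that the new elif branch matches it.
engine = sgl.Engine(
    model_path="some-org/small-model",                  # hypothetical model
    custom_weight_loader=["my_pkg.loaders.load_func"],  # hypothetical loader path
)
engine.update_weights_from_tensor(
    [("model.embed_tokens.weight", torch.zeros(1024, 1024))],  # illustrative shape
    load_format="my_pkg.loaders.load_func",
)
engine.shutdown()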
12 changes: 12 additions & 0 deletions python/sglang/srt/server_args.py
@@ -230,6 +230,9 @@ class ServerArgs:
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
    pdlb_url: Optional[str] = None

    # For model weight update
    custom_weight_loader: Optional[List[str]] = None

    def __post_init__(self):
        # Expert parallelism
        if self.enable_ep_moe:
Expand Down Expand Up @@ -519,6 +522,9 @@ def __post_init__(self):
"1" if self.disable_outlines_disk_cache else "0"
)

        if self.custom_weight_loader is None:
            self.custom_weight_loader = []

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args
@@ -1526,6 +1532,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
        parser.add_argument(
            "--custom-weight-loader",
            type=str,
            default=None,
            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
Contributor (severity: high):
The --custom-weight-loader argument should accept a list of import paths. Use nargs='*' and update the help string.

Suggested change:
-            type=str,
-            default=None,
-            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+            type=str,
+            nargs="*",
+            default=None,
+            help="A list of import paths for custom weight loader functions. Each path should be like `my_package.my_module.load_func`.",

        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
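To illustrate the suggested nargs="*" behavior, a standalone argparse sketch; the loader paths are hypothetical:

import argparse

# Minimal reproduction of the suggested flag definition: nargs="*" collects
# every following token into a list, so multiple loader paths can be registered.
parser = argparse.ArgumentParser()
parser.add_argument("--custom-weight-loader", type=str, nargs="*", default=None)
args = parser.parse_args(
    ["--custom-weight-loader", "my_pkg.loader_a", "my_pkg.loader_b"]
)
print(args.custom_weight_loader)  # ['my_pkg.loader_a', 'my_pkg.loader_b']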
13 changes: 13 additions & 0 deletions python/sglang/srt/utils.py
@@ -2340,3 +2340,16 @@ def value(self):
            self._value = self._creator()
            self._creator = None
        return self._value


def dynamic_import(func_path: str):
    parts = func_path.split(".")
    if len(parts) < 2:
        raise ValueError(
            "func_path should contain both module name and func name (such as 'module.func')"
        )
    module_path = ".".join(parts[:-1])
    func_name = parts[-1]
    module = importlib.import_module(module_path)
    func = getattr(module, func_name)
    return func
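For reference, a sketch of a loader function this helper could resolve. The module path and body are hypothetical; the (model, named_tensors) signature assumes the direct-call fix suggested in model_runner.py above:

import torch

# Hypothetical contents of my_pkg/loaders.py: dynamic_import("my_pkg.loaders.load_func")
# would return this function, which copies each incoming tensor over the
# matching model parameter in place.
def load_func(model, named_tensors):
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        with torch.no_grad():
            params[name].copy_(tensor)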
34 changes: 34 additions & 0 deletions test/srt/test_update_weights_from_tensor.py
@@ -78,6 +78,40 @@ def test_update_weights_from_tensor_load_format_direct(self):

        engine.shutdown()

    def test_update_weights_from_tensor_load_format_custom(self):
        custom_loader_name = (
            "sglang.srt.model_executor.model_runner._model_load_weights_direct"
        )
        engine = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            custom_weight_loader=custom_loader_name,
Contributor (severity: medium):
Pass custom_weight_loader as a list, even if it contains only one element, to align with the expected type.

Suggested change:
-            custom_weight_loader=custom_loader_name,
+            custom_weight_loader=[custom_loader_name],

        )

        write_param_names = [
            f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
        ]
        read_param_names = [
            f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
        ]

        _check_param(
            engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
        )

        new_tensor = torch.full((3072, 2048), 1.5)
        engine.update_weights_from_tensor(
            [
                (write_param_name, new_tensor.clone())
                for write_param_name in write_param_names
            ],
            load_format=custom_loader_name,
        )

        for read_param_name in read_param_names[:3]:
            _check_param(engine, read_param_name, [1.5] * 5)

        engine.shutdown()


def _check_param(engine, param_name, expect_values):
    actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]