support custom weight loader for model runner #7122
Merged: merrymercy merged 6 commits into sgl-project:main from yukavio:support_custom_weight_updator on Jun 16, 2025
Changes from 2 commits
Commits (6 in total):
6c068f7 support custom weight loader for model runner
86f8886 Merge branch 'main' into support_custom_weight_updator
953da94 refine the code
bda4710 refine test
3888353 update to main and fix conflict
0409724 Merge branch 'main' into support_custom_weight_updator
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,6 +230,9 @@ class ServerArgs: | |
| num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD | ||
| pdlb_url: Optional[str] = None | ||
|
|
||
| # For model weight update | ||
| custom_weight_loader: Optional[List[str]] = None | ||
|
|
||
| def __post_init__(self): | ||
| # Expert parallelism | ||
| if self.enable_ep_moe: | ||
|
|
@@ -519,6 +522,9 @@ def __post_init__(self): | |
| "1" if self.disable_outlines_disk_cache else "0" | ||
| ) | ||
|
|
||
| if self.custom_weight_loader is None: | ||
| self.custom_weight_loader = [] | ||
|
|
||
| @staticmethod | ||
| def add_cli_args(parser: argparse.ArgumentParser): | ||
| # Model and port args | ||
|
|
@@ -1526,6 +1532,12 @@ def add_cli_args(parser: argparse.ArgumentParser): | |
| default=None, | ||
| help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", | ||
| ) | ||
| parser.add_argument( | ||
| "--custom-weight-loader", | ||
| type=str, | ||
| default=None, | ||
| help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func", | ||
|
||
| ) | ||
|
|
||
| @classmethod | ||
| def from_cli_args(cls, args: argparse.Namespace): | ||
|
|
||
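The hunks above add a `custom_weight_loader` field that defaults to `None`, normalize it to an empty list in `__post_init__`, and expose it as `--custom-weight-loader`. A minimal, self-contained sketch of that pattern follows. `DemoArgs`, `build_parser`, and `parse_demo_args` are hypothetical names used only for illustration, and `nargs="*"` is an assumption added here so that the parsed value matches the `Optional[List[str]]` annotation; the diff as shown only sets `type=str`.

```python
import argparse
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DemoArgs:
    """Hypothetical stand-in for ServerArgs, reduced to the one new field."""

    custom_weight_loader: Optional[List[str]] = None

    def __post_init__(self):
        # Normalize "not provided" to an empty list so later membership
        # checks do not need to special-case None.
        if self.custom_weight_loader is None:
            self.custom_weight_loader = []


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--custom-weight-loader",
        type=str,
        nargs="*",  # assumption: accept zero or more import paths
        default=None,
        help="Import path(s) of custom weight loader functions, "
        "such as my_package.weight_load_func",
    )
    return parser


def parse_demo_args(argv: List[str]) -> DemoArgs:
    namespace = build_parser().parse_args(argv)
    return DemoArgs(custom_weight_loader=namespace.custom_weight_loader)
```

With this shape, omitting the flag yields an empty list rather than `None`, and passing several paths yields them in order.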
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,40 @@ def test_update_weights_from_tensor_load_format_direct(self): | |
|
|
||
| engine.shutdown() | ||
|
|
||
| def test_update_weights_from_tensor_load_format_custom(self): | ||
| custom_loader_name = ( | ||
| "sglang.srt.model_executor.model_runner._model_load_weights_direct" | ||
| ) | ||
| engine = sgl.Engine( | ||
| model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, | ||
| custom_weight_loader=custom_loader_name, | ||
|
||
| ) | ||
|
|
||
| write_param_names = [ | ||
| f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16) | ||
| ] | ||
| read_param_names = [ | ||
| f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16) | ||
| ] | ||
|
|
||
| _check_param( | ||
| engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178] | ||
| ) | ||
|
|
||
| new_tensor = torch.full((3072, 2048), 1.5) | ||
| engine.update_weights_from_tensor( | ||
| [ | ||
| (write_param_name, new_tensor.clone()) | ||
| for write_param_name in write_param_names | ||
| ], | ||
| load_format=custom_loader_name, | ||
| ) | ||
|
|
||
| for read_param_name in read_param_names[:3]: | ||
| _check_param(engine, read_param_name, [1.5] * 5) | ||
|
|
||
| engine.shutdown() | ||
|
|
||
|
|
||
| def _check_param(engine, param_name, expect_values): | ||
| actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5] | ||
|
|
||
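In the test above, the same import path is passed twice: once as `custom_weight_loader` when the engine is created, and again as `load_format` when weights are updated. One plausible reading is that the update path only accepts a `load_format` that was registered at startup. The sketch below illustrates that idea without torch or SGLang. Every name in it (`demo_direct_loader`, `apply_update`, the `resolve` callback) is hypothetical, and the allowlist check is an assumption that the diff shown here does not confirm.

```python
from typing import Callable, Dict, Iterable, List, Tuple

NamedValues = Iterable[Tuple[str, float]]
Loader = Callable[[Dict[str, float], NamedValues], None]


def demo_direct_loader(model: Dict[str, float], named_values: NamedValues) -> None:
    """Hypothetical loader: overwrite existing parameters by name."""
    for name, value in named_values:
        if name not in model:
            raise KeyError(f"unknown parameter: {name}")
        model[name] = value


def apply_update(
    model: Dict[str, float],
    named_values: NamedValues,
    load_format: str,
    allowed_loaders: List[str],
    resolve: Callable[[str], Loader],
) -> None:
    # Assumption: only loaders registered at startup may be used.
    if load_format not in allowed_loaders:
        raise ValueError(f"loader not registered: {load_format}")
    # resolve() maps the import path to a callable; the result is
    # called directly with the model and the new values.
    loader = resolve(load_format)
    loader(model, named_values)
```

A plain dict stands in for the model here so the example runs anywhere; a real loader would receive the model object and tensors instead.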
The `dynamic_import` function returns the imported function itself. Call the custom loader function directly.
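The reviewer's point can be illustrated with a small stand-in built on the standard library. This is a sketch of the general technique, not the project's own `dynamic_import` helper, whose exact behavior is not shown in this diff: it splits a dotted path into module and attribute, imports the module, and returns the attribute. Because the return value is already the function, the caller invokes it directly.

```python
import importlib
from typing import Any, Callable


def dynamic_import(func_path: str) -> Callable[..., Any]:
    """Sketch: resolve 'package.module.func' to the function object."""
    module_path, _, func_name = func_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted import path, got: {func_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, func_name)


# The returned object is the function itself, so call it directly.
sqrt = dynamic_import("math.sqrt")
result = sqrt(9.0)
```

Applied to the flag's example value, `dynamic_import("my_package.weight_load_func")` would return the loader, which is then called with whatever arguments the loader expects.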