-
Notifications
You must be signed in to change notification settings - Fork 5.4k
[diffusion] Add update_weights_from_tensor checker #21106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
MikukuOvO
wants to merge
8
commits into
sgl-project:main
Choose a base branch
from
MikukuOvO:fenglin/update-weight-from-tensor-checker
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f2e3d9f
multimodal_gen: add update_weights_from_tensor pipeline
gxlvera bd62205
[diffusion] add readme for update weight from tensor
gxlvera 8e77a2a
[diffusion] update weight from tensor reuse FlattenTensorBucket class
gxlvera f34a894
[diffusion] feat: add update_weights_from_tensor checker
MikukuOvO a0eaa41
[diffusion] test: add 2gpu update_weights_from_tensor checker e2e
MikukuOvO 47d9600
[diffusion] test: add multiprocessing serializer e2e coverage
MikukuOvO d928bf6
Merge remote-tracking branch 'upstream/main' into fenglin/update-weig…
MikukuOvO c1016f5
[diffusion] fix: verify full TP tensor checksums on root
MikukuOvO File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -42,6 +42,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import gc | ||||||||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||
| from torch.distributed.tensor import DTensor, distribute_tensor | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -57,6 +58,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model | ||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin | ||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger | ||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.weight_sync.tensor_bucket import ( | ||||||||||||||||||||||||||||||||||||||||||||||
| FlattenedTensorBucket, | ||||||||||||||||||||||||||||||||||||||||||||||
| FlattenedTensorMetadata, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -291,3 +296,158 @@ def _rollback(self, updated_modules: list[str]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
| weights_iter = _get_weights_iter(str(weights_dir)) | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_weights_into_module(module, weights_iter) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def update_weights_from_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors: Any, | ||||||||||||||||||||||||||||||||||||||||||||||
| load_format: str | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| target_modules: list[str] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[bool, str]: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Update module weights from in-memory tensor payloads.""" | ||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| modules_to_update = self._collect_modules(target_modules) | ||||||||||||||||||||||||||||||||||||||||||||||
| except ValueError as e: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.error(str(e)) | ||||||||||||||||||||||||||||||||||||||||||||||
| return False, str(e) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if not modules_to_update: | ||||||||||||||||||||||||||||||||||||||||||||||
| error_msg = ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"No matching modules found for update. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Requested: {target_modules}. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.error(error_msg) | ||||||||||||||||||||||||||||||||||||||||||||||
| return False, error_msg | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| module_payloads = self._resolve_module_payloads( | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors=named_tensors, | ||||||||||||||||||||||||||||||||||||||||||||||
| modules_to_update=modules_to_update, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| except ValueError as e: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.error(str(e)) | ||||||||||||||||||||||||||||||||||||||||||||||
| return False, str(e) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| updated_modules: list[str] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
| for module_name, module in modules_to_update: | ||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| payload = module_payloads[module_name] | ||||||||||||||||||||||||||||||||||||||||||||||
| weights_iter = self._materialize_weights_iter(payload, load_format) | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_weights_into_module(module, weights_iter) | ||||||||||||||||||||||||||||||||||||||||||||||
| updated_modules.append(module_name) | ||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||
| rollback_list = updated_modules + [module_name] | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Tensor weight update failed for module '{module_name}': {e}. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Rolling back {len(rollback_list)} module(s) " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"(including partially-loaded '{module_name}'): " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"{rollback_list}.", | ||||||||||||||||||||||||||||||||||||||||||||||
| exc_info=True, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._rollback(rollback_list) | ||||||||||||||||||||||||||||||||||||||||||||||
| return False, ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Failed to update module '{module_name}' from tensor: {e}. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"All modules rolled back to original weights." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| gc.collect() | ||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||||||
| names = ", ".join(updated_modules) | ||||||||||||||||||||||||||||||||||||||||||||||
| message = f"Updated {len(updated_modules)} modules from tensor ({names})." | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(message) | ||||||||||||||||||||||||||||||||||||||||||||||
| return True, message | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_module_payloads( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors: Any, | ||||||||||||||||||||||||||||||||||||||||||||||
| modules_to_update: list[tuple[str, torch.nn.Module]], | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Resolve a generic tensor payload to per-module payload mapping.""" | ||||||||||||||||||||||||||||||||||||||||||||||
| module_names = [name for name, _ in modules_to_update] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Preferred format for multi-module update: | ||||||||||||||||||||||||||||||||||||||||||||||
| # { | ||||||||||||||||||||||||||||||||||||||||||||||
| # "transformer": <module_payload>, | ||||||||||||||||||||||||||||||||||||||||||||||
| # "vae": <module_payload>, | ||||||||||||||||||||||||||||||||||||||||||||||
| # } | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(named_tensors, dict): | ||||||||||||||||||||||||||||||||||||||||||||||
| missing = [name for name in module_names if name not in named_tensors] | ||||||||||||||||||||||||||||||||||||||||||||||
| if missing: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Missing tensor payload for module(s): {missing}. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Provided modules: {list(named_tensors.keys())}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return {name: named_tensors[name] for name in module_names} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Single-module shortcut: allow direct payload when exactly one target module exists. | ||||||||||||||||||||||||||||||||||||||||||||||
| if len(module_names) == 1: | ||||||||||||||||||||||||||||||||||||||||||||||
| return {module_names[0]: named_tensors} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| "Ambiguous tensor payload for multi-module update. " | ||||||||||||||||||||||||||||||||||||||||||||||
| "Provide a dict mapping module_name -> module payload, " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"requested modules: {module_names}." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _materialize_weights_iter(self, module_payload: Any, load_format: str | None): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Convert one module payload to an iterator of (param_name, tensor).""" | ||||||||||||||||||||||||||||||||||||||||||||||
| if load_format == "flattened_bucket": | ||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(module_payload, dict): | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| "flattened_bucket payload must be a dict with " | ||||||||||||||||||||||||||||||||||||||||||||||
| "'flattened_tensor' and 'metadata'." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| flattened_tensor = module_payload.get("flattened_tensor") | ||||||||||||||||||||||||||||||||||||||||||||||
| metadata = module_payload.get("metadata") | ||||||||||||||||||||||||||||||||||||||||||||||
| if flattened_tensor is None or metadata is None: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| "flattened_bucket payload missing 'flattened_tensor' or 'metadata'." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return self._reconstruct_from_flattened_bucket(flattened_tensor, metadata) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Default/direct format: list/tuple of (name, tensor) | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(module_payload, (list, tuple)): | ||||||||||||||||||||||||||||||||||||||||||||||
| return iter(module_payload) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Unsupported module payload type for load_format={load_format}: " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"{type(module_payload).__name__}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _reconstruct_from_flattened_bucket(self, flattened_tensor: Any, metadata: Any): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Reconstruct [(name, tensor), ...] from flattened-bucket payload.""" | ||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(flattened_tensor, torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| "flattened_bucket 'flattened_tensor' must be a torch.Tensor." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(metadata, list): | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("flattened_bucket 'metadata' must be a list.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| converted_metadata: list[FlattenedTensorMetadata] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
| for meta in metadata: | ||||||||||||||||||||||||||||||||||||||||||||||
| converted_meta = FlattenedTensorMetadata( | ||||||||||||||||||||||||||||||||||||||||||||||
| name=meta.name, | ||||||||||||||||||||||||||||||||||||||||||||||
| shape=torch.Size(meta.shape), | ||||||||||||||||||||||||||||||||||||||||||||||
| dtype=self._normalize_torch_dtype(meta.dtype), | ||||||||||||||||||||||||||||||||||||||||||||||
| start_idx=int(meta.start_idx), | ||||||||||||||||||||||||||||||||||||||||||||||
| end_idx=int(meta.end_idx), | ||||||||||||||||||||||||||||||||||||||||||||||
| numel=int(meta.numel), | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| converted_metadata.append(converted_meta) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+426
to
+436
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For improved readability and conciseness, this loop for building
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| bucket = FlattenedTensorBucket( | ||||||||||||||||||||||||||||||||||||||||||||||
| flattened_tensor=flattened_tensor, | ||||||||||||||||||||||||||||||||||||||||||||||
| metadata=converted_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return bucket.reconstruct_tensors() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _normalize_torch_dtype(self, dtype: Any) -> torch.dtype: | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(dtype, torch.dtype): | ||||||||||||||||||||||||||||||||||||||||||||||
| return dtype | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(dtype, str): | ||||||||||||||||||||||||||||||||||||||||||||||
| # Supports "torch.float16" and "float16". | ||||||||||||||||||||||||||||||||||||||||||||||
| name = dtype.split(".")[-1] | ||||||||||||||||||||||||||||||||||||||||||||||
| normalized = getattr(torch, name, None) | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(normalized, torch.dtype): | ||||||||||||||||||||||||||||||||||||||||||||||
| return normalized | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unsupported dtype in flattened_bucket metadata: {dtype!r}") | ||||||||||||||||||||||||||||||||||||||||||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is significant code duplication in handling scheduler requests and responses. The logic within this
try...exceptblock and the subsequent response parsing is nearly identical inupdate_weights_from_disk,update_weights_from_tensor, andupdate_weights_from_tensor_checker. To improve maintainability and reduce redundancy, consider extracting this common logic into a helper function.For example, you could create a helper like this:
Each endpoint could then call this helper after creating the specific request object, simplifying the endpoint logic.