From f2e3d9fcfa5425b411126d257b30c07a128d71aa Mon Sep 17 00:00:00 2001 From: Xiaole Guo Date: Thu, 12 Mar 2026 17:03:49 +0000 Subject: [PATCH 1/7] multimodal_gen: add update_weights_from_tensor pipeline --- .../entrypoints/post_training/io_struct.py | 19 ++ .../entrypoints/post_training/weights_api.py | 36 ++++ .../runtime/loader/weights_updater.py | 167 ++++++++++++++++++ .../runtime/managers/gpu_worker.py | 40 +++++ .../runtime/managers/scheduler.py | 17 ++ 5 files changed, 279 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py index bda72df12a8f..5b152cf42d8d 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py @@ -1,6 +1,7 @@ """Request/response data structures for post-training APIs.""" from dataclasses import dataclass +from typing import Optional, Union @dataclass @@ -12,6 +13,24 @@ class UpdateWeightFromDiskReqInput: target_modules: list[str] | None = None +@dataclass +class UpdateWeightFromTensorReqInput: + """Request to update model weights from tensor payloads for diffusion models.""" + + serialized_named_tensors: list[Union[str, bytes]] + load_format: Optional[str] = None + target_modules: list[str] | None = None + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightFromTensorReqOutput: + """Response for update_weights_from_tensor request.""" + + success: bool + message: str + + @dataclass class GetWeightsChecksumReqInput: """Compute SHA-256 checksum of loaded module weights for verification.""" diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py index 1b9312d8ea0f..ecdade87e0e9 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -6,6 +6,7 @@ from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightFromTensorReqInput, ) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client @@ -46,6 +47,41 @@ async def update_weights_from_disk(request: Request): ) +@router.post("/update_weights_from_tensor") +async def update_weights_from_tensor(request: Request): + """Update model weights from serialized tensor payloads.""" + body = await request.json() + serialized_named_tensors = body.get("serialized_named_tensors") + if not serialized_named_tensors: + return ORJSONResponse( + {"success": False, "message": "serialized_named_tensors is required"}, + status_code=400, + ) + + req = UpdateWeightFromTensorReqInput( + serialized_named_tensors=serialized_named_tensors, + load_format=body.get("load_format"), + target_modules=body.get("target_modules"), + weight_version=body.get("weight_version"), + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse( + {"success": False, "message": str(e)}, + status_code=500, + ) + + result = response.output + success = result.get("success", False) + message = result.get("message", "Unknown status") + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else 400, + ) + + @router.post("/get_weights_checksum") async def get_weights_checksum(request: Request): """Return SHA-256 checksum of each requested module's weights.""" diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index f170809a738e..4bd2cf516ff7 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -42,6 +42,7 @@ import gc from pathlib import Path +from typing import Any import torch from torch.distributed.tensor import DTensor, distribute_tensor @@ -291,3 +292,169 @@ def _rollback(self, updated_modules: list[str]) -> None: continue weights_iter = _get_weights_iter(str(weights_dir)) _load_weights_into_module(module, weights_iter) + + def update_weights_from_tensor( + self, + named_tensors: Any, + load_format: str | None = None, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update module weights from in-memory tensor payloads.""" + try: + modules_to_update = self._collect_modules(target_modules) + except ValueError as e: + logger.error(str(e)) + return False, str(e) + + if not modules_to_update: + error_msg = ( + f"No matching modules found for update. " + f"Requested: {target_modules}. " + f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}" + ) + logger.error(error_msg) + return False, error_msg + + try: + module_payloads = self._resolve_module_payloads( + named_tensors=named_tensors, + modules_to_update=modules_to_update, + ) + except ValueError as e: + logger.error(str(e)) + return False, str(e) + + updated_modules: list[str] = [] + for module_name, module in modules_to_update: + try: + payload = module_payloads[module_name] + weights_iter = self._materialize_weights_iter(payload, load_format) + _load_weights_into_module(module, weights_iter) + updated_modules.append(module_name) + except Exception as e: + rollback_list = updated_modules + [module_name] + logger.error( + f"Tensor weight update failed for module '{module_name}': {e}. " + f"Rolling back {len(rollback_list)} module(s) " + f"(including partially-loaded '{module_name}'): " + f"{rollback_list}.", + exc_info=True, + ) + self._rollback(rollback_list) + return False, ( + f"Failed to update module '{module_name}' from tensor: {e}. " + f"All modules rolled back to original weights." + ) + + gc.collect() + torch.cuda.empty_cache() + names = ", ".join(updated_modules) + message = f"Updated {len(updated_modules)} modules from tensor ({names})." + logger.info(message) + return True, message + + def _resolve_module_payloads( + self, + named_tensors: Any, + modules_to_update: list[tuple[str, torch.nn.Module]], + ) -> dict[str, Any]: + """Resolve a generic tensor payload to per-module payload mapping.""" + module_names = [name for name, _ in modules_to_update] + + # Preferred format for multi-module update: + # { + # "transformer": , + # "vae": , + # } + if isinstance(named_tensors, dict): + missing = [name for name in module_names if name not in named_tensors] + if missing: + raise ValueError( + f"Missing tensor payload for module(s): {missing}. " + f"Provided modules: {list(named_tensors.keys())}" + ) + return {name: named_tensors[name] for name in module_names} + + # Single-module shortcut: allow direct payload when exactly one target module exists. + if len(module_names) == 1: + return {module_names[0]: named_tensors} + + raise ValueError( + "Ambiguous tensor payload for multi-module update. " + "Provide a dict mapping module_name -> module payload, " + f"requested modules: {module_names}." + ) + + def _materialize_weights_iter(self, module_payload: Any, load_format: str | None): + """Convert one module payload to an iterator of (param_name, tensor).""" + if load_format == "flattened_bucket": + if not isinstance(module_payload, dict): + raise ValueError( + "flattened_bucket payload must be a dict with " + "'flattened_tensor' and 'metadata'." + ) + flattened_tensor = module_payload.get("flattened_tensor") + metadata = module_payload.get("metadata") + if flattened_tensor is None or metadata is None: + raise ValueError( + "flattened_bucket payload missing 'flattened_tensor' or 'metadata'." + ) + return self._reconstruct_from_flattened_bucket(flattened_tensor, metadata) + + # Default/direct format: list/tuple of (name, tensor) + if isinstance(module_payload, (list, tuple)): + return iter(module_payload) + + raise ValueError( + f"Unsupported module payload type for load_format={load_format}: " + f"{type(module_payload).__name__}" + ) + + def _reconstruct_from_flattened_bucket(self, flattened_tensor: Any, metadata: Any): + """Reconstruct [(name, tensor), ...] from flattened-bucket payload.""" + if not isinstance(flattened_tensor, torch.Tensor): + raise ValueError( + "flattened_bucket 'flattened_tensor' must be a torch.Tensor." + ) + if not isinstance(metadata, list): + raise ValueError("flattened_bucket 'metadata' must be a list.") + + reconstructed: list[tuple[str, torch.Tensor]] = [] + for item in metadata: + if hasattr(item, "name"): + name = item.name + shape = item.shape + dtype = item.dtype + start_idx = item.start_idx + end_idx = item.end_idx + elif isinstance(item, dict): + name = item["name"] + shape = item["shape"] + dtype = item["dtype"] + start_idx = item["start_idx"] + end_idx = item["end_idx"] + else: + raise ValueError( + "Each flattened_bucket metadata item must be an object " + "with fields (name, shape, dtype, start_idx, end_idx) " + "or a dict containing those keys." + ) + + dtype = self._normalize_torch_dtype(dtype) + tensor = ( + flattened_tensor[start_idx:end_idx].view(dtype).reshape(torch.Size(shape)) + ) + reconstructed.append((name, tensor)) + + return iter(reconstructed) + + def _normalize_torch_dtype(self, dtype: Any) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + # Supports "torch.float16" and "float16". + name = dtype.split(".")[-1] + normalized = getattr(torch, name, None) + if isinstance(normalized, torch.dtype): + return normalized + raise ValueError(f"Unsupported dtype in flattened_bucket metadata: {dtype!r}") diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index eb225a43cd5a..b5018fe68d33 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -29,6 +29,9 @@ get_ulysses_parallel_world_size, ) from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs +from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + UpdateWeightFromTensorReqInput, +) from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum from sglang.multimodal_gen.runtime.loader.weights_updater import ( WeightsUpdater, @@ -57,6 +60,7 @@ PerformanceLogger, capture_memory_snapshot, ) +from sglang.srt.utils import MultiprocessingSerializer logger = init_logger(__name__) @@ -417,6 +421,42 @@ def update_weights_from_disk( self.pipeline.model_path = model_path return success, message + def update_weights_from_tensor( + self, + req: UpdateWeightFromTensorReqInput, + ) -> tuple[bool, str]: + """Update model weights from serialized tensor payloads.""" + if not self.pipeline: + return False, "Pipeline is not initialized" + + payloads = req.serialized_named_tensors + if not payloads: + return False, "serialized_named_tensors is required" + + tp_world_size = get_tp_world_size() + if len(payloads) not in (1, tp_world_size): + return ( + False, + "serialized_named_tensors size must be 1 or tp_size " + f"({tp_world_size}), got {len(payloads)}", + ) + + payload_idx = get_tp_rank() if len(payloads) == tp_world_size else 0 + try: + named_tensors = MultiprocessingSerializer.deserialize(payloads[payload_idx]) + except Exception as e: + return False, f"Failed to deserialize serialized_named_tensors: {e}" + + updater = WeightsUpdater(self.pipeline) + if not hasattr(updater, "update_weights_from_tensor"): + return False, "update_weights_from_tensor is not implemented in WeightsUpdater" + + return updater.update_weights_from_tensor( + named_tensors=named_tensors, + load_format=req.load_format, + target_modules=req.target_modules, + ) + def get_weights_checksum( self, module_names: list[str] | None = None ) -> dict[str, str]: diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index c11c2c850224..4279370e8709 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -17,6 +17,7 @@ from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightFromTensorReqInput, ) from sglang.multimodal_gen.runtime.entrypoints.utils import ( ListLorasReq, @@ -95,6 +96,7 @@ def __init__( ListLorasReq: self._handle_list_loras, ShutdownReq: self._handle_shutdown, UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk, + UpdateWeightFromTensorReqInput: self._handle_update_weights_from_tensor, GetWeightsChecksumReqInput: self._handle_get_weights_checksum, } @@ -149,6 +151,21 @@ def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: error=None if success else message, ) + def _handle_update_weights_from_tensor(self, reqs: List[Any]) -> OutputBatch: + """Handle update_weights_from_tensor request for RL workflows.""" + req = reqs[0] + success, message = self.worker.update_weights_from_tensor(req) + + if self.server_args.tp_size > 1: + import torch + + torch.distributed.barrier(group=self.worker.tp_cpu_group) + + return OutputBatch( + output={"success": success, "message": message}, + error=None if success else message, + ) + def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch: """Handle get_weights_checksum request.""" req = reqs[0] From bd62205a788f9692585da03f1499377f751979b9 Mon Sep 17 00:00:00 2001 From: Xiaole Guo Date: Thu, 12 Mar 2026 17:16:40 +0000 Subject: [PATCH 2/7] [diffusion] add readme for update weight from tensor --- .../README_update_weights_from_tensor.md | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md diff --git a/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md b/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md new file mode 100644 index 000000000000..51c4c2706a45 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md @@ -0,0 +1,55 @@ +# Diffusion `update_weights_from_tensor` README + +This document describes the tensor-based in-place weight update flow for diffusion models in `sglang.multimodal_gen`. + +## Endpoint + +- `POST /update_weights_from_tensor` + +## Request Schema + +- `serialized_named_tensors: List[Union[str, bytes]]` (required) +- `load_format: Optional[str]` (`None`/`direct` or `"flattened_bucket"`) +- `target_modules: Optional[List[str]]` (for example: `["transformer"]`, `["transformer", "vae"]`) +- `weight_version: Optional[str]` + +Notes: +- `flush_cache` is intentionally not part of the tensor request. +- Response shape follows: + - `{"success": bool, "message": str}` + +## TP Payload Rules + +- `len(serialized_named_tensors)` must be either: + - `1`, or + - `tp_size` +- If length is `tp_size`, each TP rank consumes the payload at its own index. +- If length is `1`, all TP ranks consume index `0`. + +## Module Payload Rules + +- Single-module update (`target_modules` has one module): + - Payload can be passed directly. +- Multi-module update: + - Payload must be a dict keyed by module name: + +```python +{ + "transformer": , + "vae": , +} +``` + +## Supported Module Payload Formats + +- `load_format=None` (or `direct`-style payload): + - `[(param_name, tensor), ...]` +- `load_format="flattened_bucket"`: + - `{"flattened_tensor": tensor, "metadata": [...]}` + +## Safety Semantics + +- Only selected modules are updated (`target_modules` aware). +- Update is all-or-nothing across requested modules: + - on failure, already-updated modules are rolled back to previous disk weights. +- TP synchronization barrier is used in scheduler path to avoid mixed-rank model state. From 8e77a2a4f8f50a6ea100feed7be6878a274c2d8e Mon Sep 17 00:00:00 2001 From: Xiaole Guo Date: Sun, 15 Mar 2026 06:50:39 +0000 Subject: [PATCH 3/7] [diffusion] update weight from tensor reuse FlattenTensorBucket class --- .../runtime/loader/weights_updater.py | 45 ++++++++----------- .../runtime/managers/gpu_worker.py | 2 - 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 4bd2cf516ff7..29335350547a 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -58,6 +58,10 @@ from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.weight_sync.tensor_bucket import ( + FlattenedTensorBucket, + FlattenedTensorMetadata, +) logger = init_logger(__name__) @@ -419,34 +423,23 @@ def _reconstruct_from_flattened_bucket(self, flattened_tensor: Any, metadata: An if not isinstance(metadata, list): raise ValueError("flattened_bucket 'metadata' must be a list.") - reconstructed: list[tuple[str, torch.Tensor]] = [] - for item in metadata: - if hasattr(item, "name"): - name = item.name - shape = item.shape - dtype = item.dtype - start_idx = item.start_idx - end_idx = item.end_idx - elif isinstance(item, dict): - name = item["name"] - shape = item["shape"] - dtype = item["dtype"] - start_idx = item["start_idx"] - end_idx = item["end_idx"] - else: - raise ValueError( - "Each flattened_bucket metadata item must be an object " - "with fields (name, shape, dtype, start_idx, end_idx) " - "or a dict containing those keys." - ) - - dtype = self._normalize_torch_dtype(dtype) - tensor = ( - flattened_tensor[start_idx:end_idx].view(dtype).reshape(torch.Size(shape)) + converted_metadata: list[FlattenedTensorMetadata] = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=torch.Size(meta.shape), + dtype=self._normalize_torch_dtype(meta.dtype), + start_idx=int(meta.start_idx), + end_idx=int(meta.end_idx), + numel=int(meta.numel), ) - reconstructed.append((name, tensor)) + converted_metadata.append(converted_meta) - return iter(reconstructed) + bucket = FlattenedTensorBucket( + flattened_tensor=flattened_tensor, + metadata=converted_metadata, + ) + return bucket.reconstruct_tensors() def _normalize_torch_dtype(self, dtype: Any) -> torch.dtype: if isinstance(dtype, torch.dtype): diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index b5018fe68d33..7401d99e079d 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -448,8 +448,6 @@ def update_weights_from_tensor( return False, f"Failed to deserialize serialized_named_tensors: {e}" updater = WeightsUpdater(self.pipeline) - if not hasattr(updater, "update_weights_from_tensor"): - return False, "update_weights_from_tensor is not implemented in WeightsUpdater" return updater.update_weights_from_tensor( named_tensors=named_tensors, From f34a8942c9ab610e220869b5644c20ed4e5914b3 Mon Sep 17 00:00:00 2001 From: "Fenglin Yu (MikukuOvO)" Date: Sat, 21 Mar 2026 19:28:53 +0000 Subject: [PATCH 4/7] [diffusion] feat: add update_weights_from_tensor checker --- .../entrypoints/post_training/io_struct.py | 15 + .../entrypoints/post_training/weights_api.py | 36 +++ .../runtime/managers/gpu_worker.py | 70 ++++- .../runtime/managers/scheduler.py | 11 + .../update_weight_from_tensor_checker.py | 194 ++++++++++++ .../test_update_weight_from_tensor_checker.py | 169 ++++++++++ ..._update_weights_from_tensor_checker_e2e.py | 296 ++++++++++++++++++ 7 files changed, 777 insertions(+), 14 deletions(-) create mode 100644 python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py create mode 100644 test/test_update_weight_from_tensor_checker.py create mode 100644 test/test_update_weights_from_tensor_checker_e2e.py diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py index 5b152cf42d8d..3ea8884e1a48 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py @@ -31,6 +31,21 @@ class UpdateWeightFromTensorReqOutput: message: str +@dataclass +class UpdateWeightFromTensorCheckerReqInput: + """Request to verify live transformer weights against expected SHA-256 values.""" + + expected_transformer_sha256: list[dict[str, str]] + + +@dataclass +class UpdateWeightFromTensorCheckerReqOutput: + """Response for update_weights_from_tensor_checker request.""" + + success: bool + message: str + + @dataclass class GetWeightsChecksumReqInput: """Compute SHA-256 checksum of loaded module weights for verification.""" diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py index ecdade87e0e9..6c711bdf5d63 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -6,6 +6,7 @@ from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightFromTensorCheckerReqInput, UpdateWeightFromTensorReqInput, ) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client @@ -82,6 +83,41 @@ async def update_weights_from_tensor(request: Request): ) +@router.post("/update_weights_from_tensor_checker") +async def update_weights_from_tensor_checker(request: Request): + """Verify live transformer weights against expected SHA-256 values.""" + body = await request.json() + expected_transformer_sha256 = body.get("expected_transformer_sha256") + if ( + not isinstance(expected_transformer_sha256, list) + or not expected_transformer_sha256 + ): + return ORJSONResponse( + {"success": False, "message": "expected_transformer_sha256 is required"}, + status_code=400, + ) + + req = UpdateWeightFromTensorCheckerReqInput( + expected_transformer_sha256=expected_transformer_sha256, + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse( + {"success": False, "message": str(e)}, + status_code=500, + ) + + result = response.output + success = result.get("success", False) + message = result.get("message", "Unknown status") + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else 400, + ) + + @router.post("/get_weights_checksum") async def get_weights_checksum(request: Request): """Return SHA-256 checksum of each requested module's weights.""" diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 7401d99e079d..97528fe30122 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -28,10 +28,11 @@ get_ulysses_parallel_rank, get_ulysses_parallel_world_size, ) -from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + UpdateWeightFromTensorCheckerReqInput, UpdateWeightFromTensorReqInput, ) +from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum from sglang.multimodal_gen.runtime.loader.weights_updater import ( WeightsUpdater, @@ -60,6 +61,9 @@ PerformanceLogger, capture_memory_snapshot, ) +from sglang.multimodal_gen.runtime.utils.update_weight_from_tensor_checker import ( + UpdateWeightFromTensorChecker, +) from sglang.srt.utils import MultiprocessingSerializer logger = init_logger(__name__) @@ -429,21 +433,15 @@ def update_weights_from_tensor( if not self.pipeline: return False, "Pipeline is not initialized" - payloads = req.serialized_named_tensors - if not payloads: - return False, "serialized_named_tensors is required" - - tp_world_size = get_tp_world_size() - if len(payloads) not in (1, tp_world_size): - return ( - False, - "serialized_named_tensors size must be 1 or tp_size " - f"({tp_world_size}), got {len(payloads)}", - ) + payloads, error = self._select_rank_scoped_payload( + payloads=req.serialized_named_tensors, + field_name="serialized_named_tensors", + ) + if error is not None: + return False, error - payload_idx = get_tp_rank() if len(payloads) == tp_world_size else 0 try: - named_tensors = MultiprocessingSerializer.deserialize(payloads[payload_idx]) + named_tensors = MultiprocessingSerializer.deserialize(payloads) except Exception as e: return False, f"Failed to deserialize serialized_named_tensors: {e}" @@ -455,6 +453,29 @@ def update_weights_from_tensor( target_modules=req.target_modules, ) + def update_weight_from_tensor_checker( + self, + req: UpdateWeightFromTensorCheckerReqInput, + ) -> tuple[bool, str]: + """Verify the live transformer weights against expected SHA-256 values.""" + if not self.pipeline: + return False, "Pipeline is not initialized" + + expected_transformer_sha256, error = self._select_rank_scoped_payload( + payloads=req.expected_transformer_sha256, + field_name="expected_transformer_sha256", + ) + if error is not None: + return False, error + + checker = UpdateWeightFromTensorChecker(self.pipeline) + return checker.verify_across_tp( + expected_transformer_sha256, + tp_rank=get_tp_rank(), + tp_world_size=get_tp_world_size(), + tp_cpu_group=self.tp_cpu_group, + ) + def get_weights_checksum( self, module_names: list[str] | None = None ) -> dict[str, str]: @@ -476,6 +497,27 @@ def get_weights_checksum( ) return checksums + def _select_rank_scoped_payload( + self, + payloads: list, + field_name: str, + ) -> tuple[object | None, str | None]: + if not isinstance(payloads, list): + return None, f"{field_name} must be a list" + if not payloads: + return None, f"{field_name} is required" + + tp_world_size = get_tp_world_size() + if len(payloads) not in (1, tp_world_size): + return ( + None, + f"{field_name} size must be 1 or tp_size ({tp_world_size}), " + f"got {len(payloads)}", + ) + + payload_idx = get_tp_rank() if len(payloads) == tp_world_size else 0 + return payloads[payload_idx], None + OOM_MSG = f""" OOM detected. Possible solutions: diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 4279370e8709..e520a441a094 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -17,6 +17,7 @@ from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, + UpdateWeightFromTensorCheckerReqInput, UpdateWeightFromTensorReqInput, ) from sglang.multimodal_gen.runtime.entrypoints.utils import ( @@ -97,6 +98,7 @@ def __init__( ShutdownReq: self._handle_shutdown, UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk, UpdateWeightFromTensorReqInput: self._handle_update_weights_from_tensor, + UpdateWeightFromTensorCheckerReqInput: self._handle_update_weight_checker, GetWeightsChecksumReqInput: self._handle_get_weights_checksum, } @@ -172,6 +174,15 @@ def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch: checksums = self.worker.get_weights_checksum(module_names=req.module_names) return OutputBatch(output=checksums) + def _handle_update_weight_checker(self, reqs: List[Any]) -> OutputBatch: + """Handle update_weights_from_tensor_checker request.""" + req = reqs[0] + success, message = self.worker.update_weight_from_tensor_checker(req) + return OutputBatch( + output={"success": success, "message": message}, + error=None if success else message, + ) + def _handle_generation(self, reqs: List[Req]): warmup_reqs = [req for req in reqs if req.is_warmup] if warmup_reqs: diff --git a/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py new file mode 100644 index 000000000000..01b6ba82c5f1 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py @@ -0,0 +1,194 @@ +"""Verification helpers for diffusion update_weights_from_tensor workflows.""" + +from __future__ import annotations + +import hashlib +from collections.abc import Iterable + +import torch +from torch.distributed.tensor import DTensor + +from sglang.multimodal_gen.runtime.utils.layerwise_offload import ( + iter_materialized_weights, +) + +_TRANSFORMER_MODULE_NAME = "transformer" +_MAX_DISPLAY_TENSORS = 5 + + +def _materialize_tensor_for_sha256(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + return tensor.detach().cpu().contiguous() + + +def compute_tensor_sha256(tensor: torch.Tensor) -> str: + materialized = _materialize_tensor_for_sha256(tensor) + + hasher = hashlib.sha256() + hasher.update(str(materialized.dtype).encode("utf-8")) + hasher.update(repr(tuple(materialized.shape)).encode("utf-8")) + hasher.update(materialized.view(torch.uint8).numpy().tobytes()) + return hasher.hexdigest() + + +def build_named_tensor_sha256( + named_tensors: Iterable[tuple[str, torch.Tensor]], +) -> dict[str, str]: + sha256_by_name: dict[str, str] = {} + for name, tensor in named_tensors: + if name in sha256_by_name: + raise ValueError(f"Duplicate tensor name in SHA256 manifest: {name}") + sha256_by_name[name] = compute_tensor_sha256(tensor) + return sha256_by_name + + +class UpdateWeightFromTensorChecker: + def __init__(self, pipeline): + self.pipeline = pipeline + + def verify_across_tp( + self, + expected_transformer_sha256: dict[str, str], + tp_rank: int, + tp_world_size: int, + tp_cpu_group, + ) -> tuple[bool, str]: + try: + local_success, local_message = self.verify(expected_transformer_sha256) + except Exception as e: + local_success = False + local_message = ( + "Exception while verifying transformer update from tensor: " + f"{type(e).__name__}: {e}" + ) + + if tp_world_size == 1: + return local_success, local_message + + gathered_results: list[tuple[int, bool, str] | None] = [None] * tp_world_size + torch.distributed.all_gather_object( + gathered_results, + (tp_rank, local_success, local_message), + group=tp_cpu_group, + ) + return self._summarize_tp_results(gathered_results) + + def verify( + self, + expected_transformer_sha256: dict[str, str], + ) -> tuple[bool, str]: + if not expected_transformer_sha256: + return False, "expected_transformer_sha256 is required" + if not isinstance(expected_transformer_sha256, dict): + return False, "expected_transformer_sha256 must be a dict[str, str]" + + transformer = self.pipeline.get_module(_TRANSFORMER_MODULE_NAME) + if transformer is None: + return False, "Transformer module is not initialized" + if not isinstance(transformer, torch.nn.Module): + return False, "Transformer module is not a torch.nn.Module" + + actual_transformer_sha256 = build_named_tensor_sha256( + self._iter_transformer_named_tensors( + transformer, expected_transformer_sha256.keys() + ) + ) + return self._compare_manifests( + expected_transformer_sha256, + actual_transformer_sha256, + ) + + def _iter_transformer_named_tensors( + self, + transformer: torch.nn.Module, + expected_names: Iterable[str], + ): + expected_name_set = set(expected_names) + seen_names: set[str] = set() + + for name, tensor in iter_materialized_weights(transformer): + if name not in expected_name_set: + continue + seen_names.add(name) + yield name, tensor + + for name, tensor in transformer.named_buffers(): + if name in seen_names or name not in expected_name_set: + continue + seen_names.add(name) + yield name, tensor + + def _compare_manifests( + self, + expected_transformer_sha256: dict[str, str], + actual_transformer_sha256: dict[str, str], + ) -> tuple[bool, str]: + missing_names = sorted( + name + for name in expected_transformer_sha256 + if name not in actual_transformer_sha256 + ) + mismatched_names = sorted( + name + for name, expected_sha256 in expected_transformer_sha256.items() + if name in actual_transformer_sha256 + and actual_transformer_sha256[name] != expected_sha256 + ) + + if missing_names or mismatched_names: + parts: list[str] = [] + if missing_names: + parts.append( + "missing " + f"{len(missing_names)} tensor(s): " + f"{self._format_tensor_names(missing_names)}" + ) + if mismatched_names: + parts.append( + "checksum mismatch for " + f"{len(mismatched_names)} tensor(s): " + f"{self._format_tensor_names(mismatched_names)}" + ) + return ( + False, + "Transformer update weight check failed: " + "; ".join(parts), + ) + + return ( + True, + f"Verified transformer update for {len(expected_transformer_sha256)} tensor(s).", + ) + + def _summarize_tp_results( + self, + gathered_results: list[tuple[int, bool, str] | None], + ) -> tuple[bool, str]: + failures = [ + (rank, message) + for result in gathered_results + if result is not None + for rank, success, message in [result] + if not success + ] + if failures: + rank, message = failures[0] + if len(failures) == 1: + return False, f"TP rank {rank}: {message}" + return ( + False, + f"{len(failures)} TP ranks failed update_weight_from_tensor_checker; " + f"first failure on rank {rank}: {message}", + ) + + return ( + True, + f"Verified transformer update across {len(gathered_results)} TP ranks.", + ) + + def _format_tensor_names(self, names: list[str]) -> str: + displayed = names[:_MAX_DISPLAY_TENSORS] + formatted = ", ".join(displayed) + if len(names) > _MAX_DISPLAY_TENSORS: + formatted += f", ... (+{len(names) - _MAX_DISPLAY_TENSORS} more)" + return formatted diff --git a/test/test_update_weight_from_tensor_checker.py b/test/test_update_weight_from_tensor_checker.py new file mode 100644 index 000000000000..8e776bfa184d --- /dev/null +++ b/test/test_update_weight_from_tensor_checker.py @@ -0,0 +1,169 @@ +import importlib.util +import sys +import tempfile +import types +import unittest +from pathlib import Path + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import Replicate, distribute_tensor + +UTIL_PATH = ( + Path(__file__).resolve().parents[1] + / "python" + / "sglang" + / "multimodal_gen" + / "runtime" + / "utils" + / "update_weight_from_tensor_checker.py" +) + + +def _install_layerwise_offload_stub(): + package_names = [ + "sglang", + "sglang.multimodal_gen", + "sglang.multimodal_gen.runtime", + "sglang.multimodal_gen.runtime.utils", + ] + original_modules: dict[str, types.ModuleType | None] = { + name: sys.modules.get(name) for name in package_names + } + for name in package_names: + sys.modules.setdefault(name, types.ModuleType(name)) + + layerwise_offload = types.ModuleType( + "sglang.multimodal_gen.runtime.utils.layerwise_offload" + ) + + def iter_materialized_weights(module: nn.Module): + yield from module.named_parameters() + + layerwise_offload.iter_materialized_weights = iter_materialized_weights + original_modules[layerwise_offload.__name__] = sys.modules.get( + layerwise_offload.__name__ + ) + sys.modules[layerwise_offload.__name__] = layerwise_offload + return original_modules + + +def _load_checker_module(): + original_modules = _install_layerwise_offload_stub() + try: + spec = importlib.util.spec_from_file_location( + "update_weight_from_tensor_checker_test_module", + UTIL_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + finally: + for name, original_module in original_modules.items(): + if original_module is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original_module + + +_CHECKER_MODULE = _load_checker_module() +UpdateWeightFromTensorChecker = _CHECKER_MODULE.UpdateWeightFromTensorChecker +build_named_tensor_sha256 = _CHECKER_MODULE.build_named_tensor_sha256 + + +class _ToyTransformer(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 4, bias=False) + self.register_buffer("running_scale", torch.ones(4)) + + +class _FakePipeline: + def __init__(self, transformer: nn.Module): + self._transformer = transformer + + def get_module(self, name: str): + if name == "transformer": + return self._transformer + return None + + +def _iter_named_tensors(module: nn.Module): + yield from module.named_parameters() + yield from module.named_buffers() + + +class UpdateWeightFromTensorCheckerTest(unittest.TestCase): + def test_matches_live_transformer(self): + transformer = _ToyTransformer() + checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer)) + expected_transformer_sha256 = build_named_tensor_sha256( + _iter_named_tensors(transformer) + ) + + success, message = checker.verify(expected_transformer_sha256) + + self.assertTrue(success) + self.assertEqual(message, "Verified transformer update for 2 tensor(s).") + + def test_detects_modified_tensor(self): + transformer = _ToyTransformer() + checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer)) + expected_transformer_sha256 = build_named_tensor_sha256( + _iter_named_tensors(transformer) + ) + + with torch.no_grad(): + transformer.linear.weight.add_(1) + + success, message = checker.verify(expected_transformer_sha256) + + self.assertFalse(success) + self.assertIn("checksum mismatch for 1 tensor(s): linear.weight", message) + + def test_detects_missing_tensor(self): + transformer = _ToyTransformer() + checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer)) + + success, message = checker.verify({"missing.weight": "deadbeef"}) + + self.assertFalse(success) + self.assertEqual( + message, + "Transformer update weight check failed: " + "missing 1 tensor(s): missing.weight", + ) + + def test_supports_dtensor(self): + if dist.is_initialized(): + self.skipTest("process group already initialized") + + with tempfile.NamedTemporaryFile() as rendezvous_file: + dist.init_process_group( + backend="gloo", + init_method=f"file://{rendezvous_file.name}", + rank=0, + world_size=1, + ) + try: + local_tensor = torch.arange(4, dtype=torch.float32) + device_mesh = init_device_mesh("cpu", (1,)) + distributed_tensor = distribute_tensor( + local_tensor, device_mesh, [Replicate()] + ) + + expected_sha256 = build_named_tensor_sha256([("weight", local_tensor)]) + actual_sha256 = build_named_tensor_sha256( + [("weight", distributed_tensor)] + ) + + self.assertEqual(actual_sha256, expected_sha256) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_update_weights_from_tensor_checker_e2e.py b/test/test_update_weights_from_tensor_checker_e2e.py new file mode 100644 index 000000000000..b7e47a5dd1ec --- /dev/null +++ b/test/test_update_weights_from_tensor_checker_e2e.py @@ -0,0 +1,296 @@ +import base64 +import hashlib +import os +import pickle +import signal +import socket +import subprocess +import tempfile +import time +import unittest +from pathlib import Path + +import requests +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import safe_open + +BASE_MODEL = "Tongyi-MAI/Z-Image-Turbo" +TRANSFORMER_MODULE = "transformer" +NUM_TENSORS_TO_UPDATE = 4 +SERVER_READY_MESSAGE = "Application startup complete." + + +def _get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _list_safetensor_files(model_dir: str) -> list[Path]: + return sorted(Path(model_dir).glob("*.safetensors")) + + +def _iter_transformer_weights_from_disk(model_path: str): + local_model_path = snapshot_download( + repo_id=model_path, + allow_patterns=[f"{TRANSFORMER_MODULE}/*"], + ) + weights_dir = Path(local_model_path) / TRANSFORMER_MODULE + safetensor_files = _list_safetensor_files(str(weights_dir)) + if not safetensor_files: + raise FileNotFoundError( + f"No safetensors files found for module '{TRANSFORMER_MODULE}' in {weights_dir}" + ) + + for safetensor_file in safetensor_files: + with safe_open(str(safetensor_file), framework="pt", device="cpu") as handle: + for name in handle.keys(): + yield name, handle.get_tensor(name) + + +def _compute_tensor_sha256(tensor: torch.Tensor) -> str: + materialized = tensor.detach().cpu().contiguous() + hasher = hashlib.sha256() + hasher.update(str(materialized.dtype).encode("utf-8")) + hasher.update(repr(tuple(materialized.shape)).encode("utf-8")) + hasher.update(materialized.view(torch.uint8).numpy().tobytes()) + return hasher.hexdigest() + + +def _build_named_tensor_sha256( + named_tensors: list[tuple[str, torch.Tensor]], +) -> dict[str, str]: + return {name: _compute_tensor_sha256(tensor) for name, tensor in named_tensors} + + +def _select_transformer_tensors( + model_path: str, + max_tensors: int = NUM_TENSORS_TO_UPDATE, +) -> list[tuple[str, torch.Tensor]]: + selected_named_tensors: list[tuple[str, torch.Tensor]] = [] + for name, tensor in _iter_transformer_weights_from_disk(model_path): + if not tensor.is_floating_point(): + continue + selected_named_tensors.append((name, tensor.to(torch.bfloat16).clone())) + if len(selected_named_tensors) == max_tensors: + break + + if not selected_named_tensors: + raise AssertionError("Expected at least one floating-point transformer tensor") + + return selected_named_tensors + + +def _build_shifted_named_tensors( + named_tensors: list[tuple[str, torch.Tensor]], + delta: float, +) -> list[tuple[str, torch.Tensor]]: + shifted_named_tensors: list[tuple[str, torch.Tensor]] = [] + for index, (name, tensor) in enumerate(named_tensors): + shifted_tensor = tensor.clone() + if index == 0: + shifted_tensor.add_(delta) + shifted_named_tensors.append((name, shifted_tensor)) + return shifted_named_tensors + + +def _serialize_named_tensors(named_tensors: list[tuple[str, torch.Tensor]]) -> str: + return base64.b64encode(pickle.dumps(named_tensors)).decode("utf-8") + + +class _ServerRunner: + def __init__(self, model_path: str): + self.model_path = model_path + self.port = _get_free_port() + self.process: subprocess.Popen | None = None + self.log_file = None + self.log_path = Path(tempfile.gettempdir()) / ( + f"sglang_update_weight_checker_{self.port}.log" + ) + self.log_path.unlink(missing_ok=True) + + @property + def base_url(self) -> str: + return f"http://127.0.0.1:{self.port}" + + def start(self) -> None: + command = [ + "sglang", + "serve", + "--model-path", + self.model_path, + "--port", + str(self.port), + "--log-level=debug", + "--num-gpus", + "1", + ] + env = os.environ.copy() + env["SGLANG_DIFFUSION_STAGE_LOGGING"] = "1" + + self.log_file = self.log_path.open("w", encoding="utf-8") + self.process = subprocess.Popen( + command, + stdout=self.log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + env=env, + ) + try: + self._wait_for_ready() + except Exception: + self.stop() + raise + + def stop(self) -> None: + if self.process is None: + if self.log_file is not None: + self.log_file.close() + self.log_file = None + return + if self.process.poll() is None: + try: + os.killpg(self.process.pid, signal.SIGTERM) + self.process.wait(timeout=30) + except Exception: + os.killpg(self.process.pid, signal.SIGKILL) + self.process.wait(timeout=30) + if self.log_file is not None: + self.log_file.close() + self.log_file = None + self.process = None + + def _wait_for_ready(self) -> None: + deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "1200")) + start_time = time.time() + + while time.time() - start_time < deadline: + if self.process is None: + raise RuntimeError("Server process was not started") + if self.process.poll() is not None: + raise RuntimeError( + f"Server exited early with code {self.process.returncode}.\n" + f"{self._get_log_tail()}" + ) + if self.log_path.exists(): + log_content = self.log_path.read_text(encoding="utf-8", errors="ignore") + if SERVER_READY_MESSAGE in log_content: + return + time.sleep(1) + + raise TimeoutError( + f"Server did not become ready within {deadline}s.\n{self._get_log_tail()}" + ) + + def _get_log_tail(self, lines: int = 200) -> str: + if not self.log_path.exists(): + return "" + log_content = self.log_path.read_text(encoding="utf-8", errors="ignore") + return "\n".join(log_content.splitlines()[-lines:]) + + +class UpdateWeightFromTensorCheckerE2ETest(unittest.TestCase): + @classmethod + def setUpClass(cls): + snapshot_download(repo_id=BASE_MODEL) + cls.base_transformer_tensors = _select_transformer_tensors(BASE_MODEL) + cls.expected_updated_tensors = _build_shifted_named_tensors( + cls.base_transformer_tensors, + delta=1.0, + ) + cls.server = _ServerRunner(BASE_MODEL) + cls.server.start() + + @classmethod + def tearDownClass(cls): + cls.server.stop() + + def _update_weights_from_disk(self, model_path: str) -> tuple[dict, int]: + response = requests.post( + f"{self.server.base_url}/update_weights_from_disk", + json={"model_path": model_path, "target_modules": [TRANSFORMER_MODULE]}, + timeout=300, + ) + return response.json(), response.status_code + + def _update_weights_from_tensor( + self, + named_tensors: list[tuple[str, torch.Tensor]], + ) -> tuple[dict, int]: + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [_serialize_named_tensors(named_tensors)], + "target_modules": [TRANSFORMER_MODULE], + }, + timeout=300, + ) + return response.json(), response.status_code + + def _check_updated_weights_from_tensor( + self, + expected_transformer_sha256: dict[str, str], + ) -> tuple[dict, int]: + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor_checker", + json={"expected_transformer_sha256": [expected_transformer_sha256]}, + timeout=300, + ) + return response.json(), response.status_code + + def test_update_weights_from_tensor_checker_success(self): + reset_result, reset_status = self._update_weights_from_disk(BASE_MODEL) + self.assertEqual(reset_status, 200, reset_result) + self.assertTrue(reset_result.get("success"), reset_result) + + update_result, update_status = self._update_weights_from_tensor( + self.expected_updated_tensors + ) + self.assertEqual(update_status, 200, update_result) + self.assertTrue(update_result.get("success"), update_result) + + expected_transformer_sha256 = _build_named_tensor_sha256( + self.expected_updated_tensors + ) + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 200, check_result) + self.assertTrue(check_result.get("success"), check_result) + self.assertIn("Verified transformer update", check_result.get("message", "")) + + def test_update_weights_from_tensor_checker_detects_corrupted_payload(self): + reset_result, reset_status = self._update_weights_from_disk(BASE_MODEL) + self.assertEqual(reset_status, 200, reset_result) + self.assertTrue(reset_result.get("success"), reset_result) + + expected_transformer_sha256 = _build_named_tensor_sha256( + self.expected_updated_tensors + ) + corrupted_named_tensors = _build_shifted_named_tensors( + self.base_transformer_tensors, + delta=2.0, + ) + + update_result, update_status = self._update_weights_from_tensor( + corrupted_named_tensors + ) + self.assertEqual(update_status, 200, update_result) + self.assertTrue(update_result.get("success"), update_result) + + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 400, check_result) + self.assertFalse(check_result.get("success", True), check_result) + self.assertIn("checksum mismatch", check_result.get("message", "")) + self.assertIn( + self.expected_updated_tensors[0][0], + check_result.get("message", ""), + ) + + +if __name__ == "__main__": + unittest.main() From a0eaa41fec0d8e6ebb5124a40ca17aab4321f693 Mon Sep 17 00:00:00 2001 From: "Fenglin Yu (MikukuOvO)" Date: Sun, 22 Mar 2026 20:24:01 +0000 Subject: [PATCH 5/7] [diffusion] test: add 2gpu update_weights_from_tensor checker e2e --- ..._update_weights_from_tensor_checker_e2e.py | 187 +++++++++++++++++- 1 file changed, 185 insertions(+), 2 deletions(-) diff --git a/test/test_update_weights_from_tensor_checker_e2e.py b/test/test_update_weights_from_tensor_checker_e2e.py index b7e47a5dd1ec..b8a4129a209c 100644 --- a/test/test_update_weights_from_tensor_checker_e2e.py +++ b/test/test_update_weights_from_tensor_checker_e2e.py @@ -100,8 +100,16 @@ def _serialize_named_tensors(named_tensors: list[tuple[str, torch.Tensor]]) -> s class _ServerRunner: - def __init__(self, model_path: str): + def __init__( + self, + model_path: str, + *, + num_gpus: int = 1, + tp_size: int | None = None, + ): self.model_path = model_path + self.num_gpus = num_gpus + self.tp_size = tp_size self.port = _get_free_port() self.process: subprocess.Popen | None = None self.log_file = None @@ -124,8 +132,10 @@ def start(self) -> None: str(self.port), "--log-level=debug", "--num-gpus", - "1", + str(self.num_gpus), ] + if self.tp_size is not None: + command.extend(["--tp-size", str(self.tp_size)]) env = os.environ.copy() env["SGLANG_DIFFUSION_STAGE_LOGGING"] = "1" @@ -292,5 +302,178 @@ def test_update_weights_from_tensor_checker_detects_corrupted_payload(self): ) +def _select_tp_candidate_tensors( + model_path: str, + max_tensors: int = 24, +) -> list[tuple[str, torch.Tensor]]: + norm_candidates: list[tuple[str, torch.Tensor]] = [] + other_candidates: list[tuple[str, torch.Tensor]] = [] + for name, tensor in _iter_transformer_weights_from_disk(model_path): + if not tensor.is_floating_point() or tensor.ndim != 1: + continue + candidate = (name, tensor.to(torch.bfloat16).clone()) + if "norm" in name: + norm_candidates.append(candidate) + elif name.endswith(".bias"): + other_candidates.append(candidate) + + if len(norm_candidates) + len(other_candidates) >= max_tensors: + break + + candidates = norm_candidates + other_candidates + if not candidates: + raise AssertionError("Expected at least one 1D transformer tensor candidate") + return candidates[:max_tensors] + + +@unittest.skipUnless(torch.cuda.device_count() >= 2, "requires at least 2 GPUs") +class UpdateWeightFromTensorChecker2GPUTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + snapshot_download(repo_id=BASE_MODEL) + cls.tp_candidates = _select_tp_candidate_tensors(BASE_MODEL) + cls.selected_updated_tensors = None + + def setUp(self): + self.server = _ServerRunner(BASE_MODEL, num_gpus=2, tp_size=2) + self.server.start() + + def tearDown(self): + self.server.stop() + self.server = None + + def _update_weights_from_tensor( + self, + named_tensors: list[tuple[str, torch.Tensor]], + ) -> tuple[dict, int]: + serialized_payload = _serialize_named_tensors(named_tensors) + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [serialized_payload, serialized_payload], + "target_modules": [TRANSFORMER_MODULE], + }, + timeout=300, + ) + return response.json(), response.status_code + + def _check_updated_weights_from_tensor( + self, + expected_transformer_sha256: dict[str, str], + ) -> tuple[dict, int]: + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor_checker", + json={ + "expected_transformer_sha256": [ + expected_transformer_sha256, + expected_transformer_sha256, + ] + }, + timeout=300, + ) + return response.json(), response.status_code + + def _build_selected_updated_tensors(self): + cls = type(self) + if cls.selected_updated_tensors is not None: + return cls.selected_updated_tensors + + errors: list[str] = [] + for name, tensor in cls.tp_candidates: + updated_tensors = [(name, tensor.clone().add_(1.0))] + + serialized_payload = _serialize_named_tensors(updated_tensors) + update_response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [serialized_payload, serialized_payload], + "target_modules": [TRANSFORMER_MODULE], + }, + timeout=300, + ) + if ( + update_response.status_code != 200 + or not update_response.json().get("success") + ): + errors.append( + f"{name}: update failed with " + f"{update_response.status_code} {update_response.text}" + ) + continue + + expected_transformer_sha256 = _build_named_tensor_sha256(updated_tensors) + check_response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor_checker", + json={ + "expected_transformer_sha256": [ + expected_transformer_sha256, + expected_transformer_sha256, + ] + }, + timeout=300, + ) + if ( + check_response.status_code == 200 + and check_response.json().get("success") + ): + cls.selected_updated_tensors = updated_tensors + return cls.selected_updated_tensors + + errors.append( + f"{name}: checker failed with " + f"{check_response.status_code} {check_response.text}" + ) + + displayed_errors = "; ".join(errors[:5]) + raise AssertionError( + "Could not find a TP-compatible transformer tensor candidate for 2 GPU " + f"update_weights_from_tensor checker test. First errors: {displayed_errors}" + ) + + def test_update_weights_from_tensor_checker_success(self): + selected_updated_tensors = self._build_selected_updated_tensors() + update_result, update_status = self._update_weights_from_tensor( + selected_updated_tensors + ) + self.assertEqual(update_status, 200, update_result) + self.assertTrue(update_result.get("success"), update_result) + + expected_transformer_sha256 = _build_named_tensor_sha256( + selected_updated_tensors + ) + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 200, check_result) + self.assertTrue(check_result.get("success"), check_result) + self.assertIn("across 2 TP ranks", check_result.get("message", "")) + + def test_update_weights_from_tensor_checker_detects_corrupted_payload(self): + selected_updated_tensors = self._build_selected_updated_tensors() + expected_transformer_sha256 = _build_named_tensor_sha256( + selected_updated_tensors + ) + corrupted_named_tensors = [ + (name, tensor.clone().add_(1.0)) for name, tensor in selected_updated_tensors + ] + + update_result, update_status = self._update_weights_from_tensor( + corrupted_named_tensors + ) + self.assertEqual(update_status, 200, update_result) + self.assertTrue(update_result.get("success"), update_result) + + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 400, check_result) + self.assertFalse(check_result.get("success", True), check_result) + self.assertIn( + "failed update_weight_from_tensor_checker", + check_result.get("message", ""), + ) + self.assertIn(selected_updated_tensors[0][0], check_result.get("message", "")) + + if __name__ == "__main__": unittest.main() From 47d960000bf90b42b6cafda88de40352824672db Mon Sep 17 00:00:00 2001 From: "Fenglin Yu (MikukuOvO)" Date: Sun, 22 Mar 2026 22:20:08 +0000 Subject: [PATCH 6/7] [diffusion] test: add multiprocessing serializer e2e coverage --- ..._update_weights_from_tensor_checker_e2e.py | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/test/test_update_weights_from_tensor_checker_e2e.py b/test/test_update_weights_from_tensor_checker_e2e.py index b8a4129a209c..1bd508a8f360 100644 --- a/test/test_update_weights_from_tensor_checker_e2e.py +++ b/test/test_update_weights_from_tensor_checker_e2e.py @@ -15,6 +15,8 @@ from huggingface_hub import snapshot_download from safetensors.torch import safe_open +from sglang.srt.utils import MultiprocessingSerializer + BASE_MODEL = "Tongyi-MAI/Z-Image-Turbo" TRANSFORMER_MODULE = "transformer" NUM_TENSORS_TO_UPDATE = 4 @@ -99,6 +101,12 @@ def _serialize_named_tensors(named_tensors: list[tuple[str, torch.Tensor]]) -> s return base64.b64encode(pickle.dumps(named_tensors)).decode("utf-8") +def _serialize_named_tensors_multiprocessing( + named_tensors: list[tuple[str, torch.Tensor]], +) -> str: + return MultiprocessingSerializer.serialize(named_tensors, output_str=True) + + class _ServerRunner: def __init__( self, @@ -228,11 +236,13 @@ def _update_weights_from_disk(self, model_path: str) -> tuple[dict, int]: def _update_weights_from_tensor( self, named_tensors: list[tuple[str, torch.Tensor]], + *, + serializer=_serialize_named_tensors, ) -> tuple[dict, int]: response = requests.post( f"{self.server.base_url}/update_weights_from_tensor", json={ - "serialized_named_tensors": [_serialize_named_tensors(named_tensors)], + "serialized_named_tensors": [serializer(named_tensors)], "target_modules": [TRANSFORMER_MODULE], }, timeout=300, @@ -271,6 +281,36 @@ def test_update_weights_from_tensor_checker_success(self): self.assertTrue(check_result.get("success"), check_result) self.assertIn("Verified transformer update", check_result.get("message", "")) + @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") + def test_update_weights_from_tensor_checker_with_multiprocessing_serializer(self): + cuda_named_tensors = [ + (name, tensor.to(device="cuda", non_blocking=True)) + for name, tensor in self.expected_updated_tensors + ] + try: + update_result, update_status = self._update_weights_from_tensor( + cuda_named_tensors, + serializer=_serialize_named_tensors_multiprocessing, + ) + self.assertEqual(update_status, 200, update_result) + self.assertTrue(update_result.get("success"), update_result) + + expected_transformer_sha256 = _build_named_tensor_sha256( + self.expected_updated_tensors + ) + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 200, check_result) + self.assertTrue(check_result.get("success"), check_result) + self.assertIn( + "Verified transformer update", + check_result.get("message", ""), + ) + finally: + del cuda_named_tensors + torch.cuda.empty_cache() + def test_update_weights_from_tensor_checker_detects_corrupted_payload(self): reset_result, reset_status = self._update_weights_from_disk(BASE_MODEL) self.assertEqual(reset_status, 200, reset_result) @@ -345,8 +385,10 @@ def tearDown(self): def _update_weights_from_tensor( self, named_tensors: list[tuple[str, torch.Tensor]], + *, + serializer=_serialize_named_tensors, ) -> tuple[dict, int]: - serialized_payload = _serialize_named_tensors(named_tensors) + serialized_payload = serializer(named_tensors) response = requests.post( f"{self.server.base_url}/update_weights_from_tensor", json={ From c1016f5c5ffaef082fd387ac8f6c0089a19b4ba5 Mon Sep 17 00:00:00 2001 From: "Fenglin Yu (MikukuOvO)" Date: Sun, 29 Mar 2026 22:10:07 +0000 Subject: [PATCH 7/7] [diffusion] fix: verify full TP tensor checksums on root --- .../update_weight_from_tensor_checker.py | 327 +++++++++++++++--- .../test_update_weight_from_tensor_checker.py | 85 +++++ ..._update_weights_from_tensor_checker_e2e.py | 5 +- 3 files changed, 365 insertions(+), 52 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py index 01b6ba82c5f1..91677d3dc4c0 100644 --- a/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py +++ b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py @@ -14,6 +14,7 @@ _TRANSFORMER_MODULE_NAME = "transformer" _MAX_DISPLAY_TENSORS = 5 +_ROOT_TP_RANK = 0 def _materialize_tensor_for_sha256(tensor: torch.Tensor) -> torch.Tensor: @@ -22,8 +23,8 @@ def _materialize_tensor_for_sha256(tensor: torch.Tensor) -> torch.Tensor: return tensor.detach().cpu().contiguous() -def compute_tensor_sha256(tensor: torch.Tensor) -> str: - materialized = _materialize_tensor_for_sha256(tensor) +def _compute_sha256_from_materialized(materialized: torch.Tensor) -> str: + materialized = materialized.contiguous() hasher = hashlib.sha256() hasher.update(str(materialized.dtype).encode("utf-8")) @@ -32,6 +33,10 @@ def compute_tensor_sha256(tensor: torch.Tensor) -> str: return hasher.hexdigest() +def compute_tensor_sha256(tensor: torch.Tensor) -> str: + return _compute_sha256_from_materialized(_materialize_tensor_for_sha256(tensor)) + + def build_named_tensor_sha256( named_tensors: Iterable[tuple[str, torch.Tensor]], ) -> dict[str, str]: @@ -54,25 +59,22 @@ def verify_across_tp( tp_world_size: int, tp_cpu_group, ) -> tuple[bool, str]: - try: - local_success, local_message = self.verify(expected_transformer_sha256) - except Exception as e: - local_success = False - local_message = ( - "Exception while verifying transformer update from tensor: " - f"{type(e).__name__}: {e}" - ) - if tp_world_size == 1: - return local_success, local_message + try: + return self.verify(expected_transformer_sha256) + except Exception as e: + return ( + False, + "Exception while verifying transformer update from tensor: " + f"{type(e).__name__}: {e}", + ) - gathered_results: list[tuple[int, bool, str] | None] = [None] * tp_world_size - torch.distributed.all_gather_object( - gathered_results, - (tp_rank, local_success, local_message), - group=tp_cpu_group, + return self._verify_on_tp_root( + expected_transformer_sha256=expected_transformer_sha256, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + tp_cpu_group=tp_cpu_group, ) - return self._summarize_tp_results(gathered_results) def verify( self, @@ -89,16 +91,271 @@ def verify( if not isinstance(transformer, torch.nn.Module): return False, "Transformer module is not a torch.nn.Module" - actual_transformer_sha256 = build_named_tensor_sha256( - self._iter_transformer_named_tensors( - transformer, expected_transformer_sha256.keys() - ) + actual_transformer_sha256 = self._build_local_transformer_sha256( + transformer=transformer, + expected_transformer_sha256=expected_transformer_sha256, ) return self._compare_manifests( expected_transformer_sha256, actual_transformer_sha256, ) + def _verify_on_tp_root( + self, + *, + expected_transformer_sha256: dict[str, str], + tp_rank: int, + tp_world_size: int, + tp_cpu_group, + ) -> tuple[bool, str]: + transformer, error_message = self._get_transformer_for_verification( + expected_transformer_sha256 + ) + gathered_errors: list[str | None] | None = ( + [None] * tp_world_size if tp_rank == _ROOT_TP_RANK else None + ) + torch.distributed.gather_object( + error_message, + gathered_errors, + dst=_ROOT_TP_RANK, + group=tp_cpu_group, + ) + + result: tuple[bool, str] | None = None + should_verify_holder = [False] + if tp_rank == _ROOT_TP_RANK: + assert gathered_errors is not None + failures = [ + (rank, message) + for rank, message in enumerate(gathered_errors) + if message is not None + ] + if failures: + rank, message = failures[0] + if len(failures) == 1: + result = (False, f"TP rank {rank}: {message}") + else: + result = ( + False, + f"{len(failures)} TP ranks failed update_weight_from_tensor_checker; " + f"first failure on rank {rank}: {message}", + ) + else: + should_verify_holder[0] = True + + torch.distributed.broadcast_object_list( + should_verify_holder, + src=_ROOT_TP_RANK, + group=tp_cpu_group, + ) + if should_verify_holder[0]: + assert transformer is not None + actual_transformer_sha256 = self._build_root_transformer_sha256( + transformer=transformer, + expected_transformer_sha256=expected_transformer_sha256, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + tp_cpu_group=tp_cpu_group, + ) + if tp_rank == _ROOT_TP_RANK: + assert actual_transformer_sha256 is not None + success, message = self._compare_manifests( + expected_transformer_sha256, + actual_transformer_sha256, + ) + if success: + message = ( + f"Verified transformer update across {tp_world_size} TP ranks." + ) + result = (success, message) + + result_holder = [result] + torch.distributed.broadcast_object_list( + result_holder, + src=_ROOT_TP_RANK, + group=tp_cpu_group, + ) + final_result = result_holder[0] + assert final_result is not None + return final_result + + def _get_transformer_for_verification( + self, + expected_transformer_sha256: dict[str, str], + ) -> tuple[torch.nn.Module | None, str | None]: + if not expected_transformer_sha256: + return None, "expected_transformer_sha256 is required" + if not isinstance(expected_transformer_sha256, dict): + return None, "expected_transformer_sha256 must be a dict[str, str]" + + transformer = self.pipeline.get_module(_TRANSFORMER_MODULE_NAME) + if transformer is None: + return None, "Transformer module is not initialized" + if not isinstance(transformer, torch.nn.Module): + return None, "Transformer module is not a torch.nn.Module" + return transformer, None + + def _build_local_transformer_sha256( + self, + *, + transformer: torch.nn.Module, + expected_transformer_sha256: dict[str, str], + ) -> dict[str, str]: + return build_named_tensor_sha256( + self._iter_transformer_named_tensors( + transformer, expected_transformer_sha256.keys() + ) + ) + + def _build_root_transformer_sha256( + self, + *, + transformer: torch.nn.Module, + expected_transformer_sha256: dict[str, str], + tp_rank: int, + tp_world_size: int, + tp_cpu_group, + ) -> dict[str, str] | None: + local_named_tensors = dict( + self._iter_transformer_named_tensors( + transformer, expected_transformer_sha256.keys() + ) + ) + reference_tensors = dict(transformer.named_parameters()) + reference_tensors.update(dict(transformer.named_buffers())) + actual_transformer_sha256: dict[str, str] | None = ( + {} if tp_rank == _ROOT_TP_RANK else None + ) + for name, expected_sha256 in expected_transformer_sha256.items(): + gathered_tensors = self._gather_materialized_tensors_to_root( + materialized=( + _materialize_tensor_for_sha256(local_named_tensors[name]) + if name in local_named_tensors + else None + ), + tp_rank=tp_rank, + tp_world_size=tp_world_size, + tp_cpu_group=tp_cpu_group, + ) + if tp_rank != _ROOT_TP_RANK: + continue + + reconstructed_sha256 = self._compute_transformer_tensor_sha256_from_gathered( + gathered_tensors=gathered_tensors, + expected_sha256=expected_sha256, + reference_tensor=reference_tensors.get(name), + ) + if reconstructed_sha256 is not None: + assert actual_transformer_sha256 is not None + actual_transformer_sha256[name] = reconstructed_sha256 + + return actual_transformer_sha256 + + def _compute_transformer_tensor_sha256_from_gathered( + self, + *, + gathered_tensors: list[torch.Tensor | None] | None, + expected_sha256: str, + reference_tensor: torch.Tensor | None, + ) -> str | None: + if gathered_tensors is None: + return None + + valid_tensors = [tensor for tensor in gathered_tensors if tensor is not None] + if len(valid_tensors) != len(gathered_tensors): + return None + + local_sha256s = [ + _compute_sha256_from_materialized(materialized) + for materialized in valid_tensors + ] + if all(local_sha256 == expected_sha256 for local_sha256 in local_sha256s): + return expected_sha256 + + candidate_dims = self._get_tp_candidate_dims(reference_tensor) + for shard_dim in candidate_dims: + reconstructed = self._reconstruct_tp_tensor( + gathered_tensors=valid_tensors, + shard_dim=shard_dim, + ) + if reconstructed is None: + continue + reconstructed_sha256 = _compute_sha256_from_materialized(reconstructed) + if reconstructed_sha256 == expected_sha256: + return reconstructed_sha256 + + return local_sha256s[0] + + def _get_tp_candidate_dims( + self, + reference_tensor: torch.Tensor | None, + ) -> list[int]: + if reference_tensor is None: + return [] + + candidate_dims: list[int] = [] + if isinstance(reference_tensor, DTensor): + for placement in reference_tensor.placements: + shard_dim = getattr(placement, "dim", None) + if isinstance(shard_dim, int): + candidate_dims.append(shard_dim) + + for attr in ("input_dim", "output_dim"): + shard_dim = getattr(reference_tensor, attr, None) + if isinstance(shard_dim, int): + candidate_dims.append(shard_dim) + + deduped_dims: list[int] = [] + for shard_dim in candidate_dims: + if shard_dim not in deduped_dims: + deduped_dims.append(shard_dim) + return deduped_dims + + def _gather_materialized_tensors_to_root( + self, + *, + materialized: torch.Tensor | None, + tp_rank: int, + tp_world_size: int, + tp_cpu_group, + ) -> list[torch.Tensor | None] | None: + gathered_tensors: list[torch.Tensor | None] | None = ( + [None] * tp_world_size if tp_rank == _ROOT_TP_RANK else None + ) + torch.distributed.gather_object( + materialized, + gathered_tensors, + dst=_ROOT_TP_RANK, + group=tp_cpu_group, + ) + return gathered_tensors + + def _reconstruct_tp_tensor( + self, + *, + gathered_tensors: list[torch.Tensor], + shard_dim: int, + ) -> torch.Tensor | None: + if not gathered_tensors: + return None + + first_tensor = gathered_tensors[0] + if first_tensor.ndim == 0: + return None + + shard_dim %= first_tensor.ndim + for tensor in gathered_tensors[1:]: + if tensor.ndim != first_tensor.ndim or tensor.dtype != first_tensor.dtype: + return None + if any( + lhs != rhs + for dim, (lhs, rhs) in enumerate(zip(first_tensor.shape, tensor.shape)) + if dim != shard_dim + ): + return None + + return torch.cat(gathered_tensors, dim=shard_dim).contiguous() + def _iter_transformer_named_tensors( self, transformer: torch.nn.Module, @@ -160,32 +417,6 @@ def _compare_manifests( f"Verified transformer update for {len(expected_transformer_sha256)} tensor(s).", ) - def _summarize_tp_results( - self, - gathered_results: list[tuple[int, bool, str] | None], - ) -> tuple[bool, str]: - failures = [ - (rank, message) - for result in gathered_results - if result is not None - for rank, success, message in [result] - if not success - ] - if failures: - rank, message = failures[0] - if len(failures) == 1: - return False, f"TP rank {rank}: {message}" - return ( - False, - f"{len(failures)} TP ranks failed update_weight_from_tensor_checker; " - f"first failure on rank {rank}: {message}", - ) - - return ( - True, - f"Verified transformer update across {len(gathered_results)} TP ranks.", - ) - def _format_tensor_names(self, names: list[str]) -> str: displayed = names[:_MAX_DISPLAY_TENSORS] formatted = ", ".join(displayed) diff --git a/test/test_update_weight_from_tensor_checker.py b/test/test_update_weight_from_tensor_checker.py index 8e776bfa184d..206d1e7dde91 100644 --- a/test/test_update_weight_from_tensor_checker.py +++ b/test/test_update_weight_from_tensor_checker.py @@ -1,4 +1,6 @@ import importlib.util +import os +import pickle import sys import tempfile import types @@ -7,6 +9,7 @@ import torch import torch.distributed as dist +import torch.multiprocessing as mp from torch import nn from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import Replicate, distribute_tensor @@ -96,6 +99,55 @@ def _iter_named_tensors(module: nn.Module): yield from module.named_buffers() +class _ShardedTransformer(nn.Module): + def __init__(self, local_weight: torch.Tensor, *, shard_attr: str, shard_dim: int): + super().__init__() + self.weight = nn.Parameter(local_weight) + setattr(self.weight, shard_attr, shard_dim) + + +def _verify_sharded_weight_across_tp_worker( + rank: int, + rendezvous_file: str, + shard_attr: str, + shard_dim: int, + results_dir: str, +): + os.environ.setdefault("GLOO_SOCKET_IFNAME", "lo") + dist.init_process_group( + backend="gloo", + init_method=f"file://{rendezvous_file}", + rank=rank, + world_size=2, + ) + try: + full_weight = torch.arange(16, dtype=torch.float32).reshape(4, 4) + shard_size = full_weight.shape[shard_dim] // 2 + local_weight = full_weight.narrow( + shard_dim, rank * shard_size, shard_size + ).clone() + + transformer = _ShardedTransformer( + local_weight, + shard_attr=shard_attr, + shard_dim=shard_dim, + ) + checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer)) + expected_transformer_sha256 = build_named_tensor_sha256([("weight", full_weight)]) + + result = checker.verify_across_tp( + expected_transformer_sha256=expected_transformer_sha256, + tp_rank=rank, + tp_world_size=2, + tp_cpu_group=dist.group.WORLD, + ) + result_path = Path(results_dir) / f"rank_{rank}.pkl" + with result_path.open("wb") as f: + pickle.dump(result, f) + finally: + dist.destroy_process_group() + + class UpdateWeightFromTensorCheckerTest(unittest.TestCase): def test_matches_live_transformer(self): transformer = _ToyTransformer() @@ -142,6 +194,7 @@ def test_supports_dtensor(self): self.skipTest("process group already initialized") with tempfile.NamedTemporaryFile() as rendezvous_file: + os.environ.setdefault("GLOO_SOCKET_IFNAME", "lo") dist.init_process_group( backend="gloo", init_method=f"file://{rendezvous_file.name}", @@ -164,6 +217,38 @@ def test_supports_dtensor(self): finally: dist.destroy_process_group() + def _run_tp_sharded_verification(self, *, shard_attr: str, shard_dim: int): + if dist.is_initialized(): + self.skipTest("process group already initialized") + + with tempfile.TemporaryDirectory() as tmpdir: + rendezvous_file = str(Path(tmpdir) / "dist_init") + Path(rendezvous_file).touch() + mp.spawn( + _verify_sharded_weight_across_tp_worker, + args=(rendezvous_file, shard_attr, shard_dim, tmpdir), + nprocs=2, + join=True, + ) + + results = {} + for rank in range(2): + result_path = Path(tmpdir) / f"rank_{rank}.pkl" + with result_path.open("rb") as f: + results[rank] = pickle.load(f) + + self.assertEqual(sorted(results.keys()), [0, 1]) + for rank in range(2): + success, message = results[rank] + self.assertTrue(success, f"rank {rank}: {message}") + self.assertIn("across 2 TP ranks", message) + + def test_verify_across_tp_gathers_output_sharded_tensor(self): + self._run_tp_sharded_verification(shard_attr="output_dim", shard_dim=0) + + def test_verify_across_tp_gathers_input_sharded_tensor(self): + self._run_tp_sharded_verification(shard_attr="input_dim", shard_dim=1) + if __name__ == "__main__": unittest.main() diff --git a/test/test_update_weights_from_tensor_checker_e2e.py b/test/test_update_weights_from_tensor_checker_e2e.py index 1bd508a8f360..d21b7d8d8c2e 100644 --- a/test/test_update_weights_from_tensor_checker_e2e.py +++ b/test/test_update_weights_from_tensor_checker_e2e.py @@ -510,10 +510,7 @@ def test_update_weights_from_tensor_checker_detects_corrupted_payload(self): ) self.assertEqual(check_status, 400, check_result) self.assertFalse(check_result.get("success", True), check_result) - self.assertIn( - "failed update_weight_from_tensor_checker", - check_result.get("message", ""), - ) + self.assertIn("checksum mismatch", check_result.get("message", "")) self.assertIn(selected_updated_tensors[0][0], check_result.get("message", ""))