diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py
index bda72df12a8f..3ea8884e1a48 100644
--- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py
+++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py
@@ -1,6 +1,7 @@
 """Request/response data structures for post-training APIs."""

 from dataclasses import dataclass
+from typing import Optional, Union


 @dataclass
@@ -12,6 +13,39 @@ class UpdateWeightFromDiskReqInput:
     target_modules: list[str] | None = None


+@dataclass
+class UpdateWeightFromTensorReqInput:
+    """Request to update model weights from tensor payloads for diffusion models."""
+
+    serialized_named_tensors: list[Union[str, bytes]]
+    load_format: Optional[str] = None
+    target_modules: list[str] | None = None
+    weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightFromTensorReqOutput:
+    """Response for update_weights_from_tensor request."""
+
+    success: bool
+    message: str
+
+
+@dataclass
+class UpdateWeightFromTensorCheckerReqInput:
+    """Request to verify live transformer weights against expected SHA-256 values."""
+
+    expected_transformer_sha256: list[dict[str, str]]
+
+
+@dataclass
+class UpdateWeightFromTensorCheckerReqOutput:
+    """Response for update_weights_from_tensor_checker request."""
+
+    success: bool
+    message: str
+
+
 @dataclass
 class GetWeightsChecksumReqInput:
     """Compute SHA-256 checksum of loaded module weights for verification."""
diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py
index 1b9312d8ea0f..6c711bdf5d63 100644
--- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py
+++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py
@@ -6,6 +6,8 @@
 from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
     GetWeightsChecksumReqInput,
     UpdateWeightFromDiskReqInput,
+    UpdateWeightFromTensorCheckerReqInput,
+    UpdateWeightFromTensorReqInput,
 )
 from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client

@@ -46,6 +48,76 @@ async def update_weights_from_disk(request: Request):
     )


+@router.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(request: Request):
+    """Update model weights from serialized tensor payloads."""
+    body = await request.json()
+    serialized_named_tensors = body.get("serialized_named_tensors")
+    if not serialized_named_tensors:
+        return ORJSONResponse(
+            {"success": False, "message": "serialized_named_tensors is required"},
+            status_code=400,
+        )
+
+    req = UpdateWeightFromTensorReqInput(
+        serialized_named_tensors=serialized_named_tensors,
+        load_format=body.get("load_format"),
+        target_modules=body.get("target_modules"),
+        weight_version=body.get("weight_version"),
+    )
+
+    try:
+        response = await async_scheduler_client.forward(req)
+    except Exception as e:
+        return ORJSONResponse(
+            {"success": False, "message": str(e)},
+            status_code=500,
+        )
+
+    result = response.output
+    success = result.get("success", False)
+    message = result.get("message", "Unknown status")
+    return ORJSONResponse(
+        {"success": success, "message": message},
+        status_code=200 if success else 400,
+    )
+
+
+@router.post("/update_weights_from_tensor_checker")
+async def update_weights_from_tensor_checker(request: Request):
+    """Verify live transformer weights against expected SHA-256 values."""
+    body = await request.json()
+    expected_transformer_sha256 = body.get("expected_transformer_sha256")
+    if (
+        not isinstance(expected_transformer_sha256, list)
+        or not expected_transformer_sha256
+    ):
+        return ORJSONResponse(
+            {"success": False, "message": "expected_transformer_sha256 is required"},
+            status_code=400,
+        )
+
+    req = UpdateWeightFromTensorCheckerReqInput(
+        expected_transformer_sha256=expected_transformer_sha256,
+    )
+
+    try:
+        response = await async_scheduler_client.forward(req)
+    except Exception as e:
+        return ORJSONResponse(
+            {"success": False, "message": str(e)},
+            status_code=500,
+        )
+
+    result = response.output
+    success = result.get("success", False)
+    message = result.get("message", "Unknown status")
+    return ORJSONResponse(
+        {"success": success, "message": message},
+        status_code=200 if success else 400,
+    )
+
+
 @router.post("/get_weights_checksum")
 async def get_weights_checksum(request: Request):
     """Return SHA-256 checksum of each requested module's weights."""
diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py
index f170809a738e..29335350547a 100644
--- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py
+++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py
@@ -42,6 +42,7 @@

 import gc
 from pathlib import Path
+from typing import Any

 import torch
 from torch.distributed.tensor import DTensor, distribute_tensor
@@ -57,6 +58,10 @@
 from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model
 from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)

 logger = init_logger(__name__)

@@ -291,3 +296,158 @@ def _rollback(self, updated_modules: list[str]) -> None:
                 continue
             weights_iter = _get_weights_iter(str(weights_dir))
             _load_weights_into_module(module, weights_iter)
+
+    def update_weights_from_tensor(
+        self,
+        named_tensors: Any,
+        load_format: str | None = None,
+        target_modules: list[str] | None = None,
+    ) -> tuple[bool, str]:
+        """Update module weights from in-memory tensor payloads."""
+        try:
+            modules_to_update = self._collect_modules(target_modules)
+        except ValueError as e:
+            logger.error(str(e))
+            return False, str(e)
+
+        if not modules_to_update:
+            error_msg = (
+                f"No matching modules found for update. "
+                f"Requested: {target_modules}. "
+                f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}"
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
+        try:
+            module_payloads = self._resolve_module_payloads(
+                named_tensors=named_tensors,
+                modules_to_update=modules_to_update,
+            )
+        except ValueError as e:
+            logger.error(str(e))
+            return False, str(e)
+
+        updated_modules: list[str] = []
+        for module_name, module in modules_to_update:
+            try:
+                payload = module_payloads[module_name]
+                weights_iter = self._materialize_weights_iter(payload, load_format)
+                _load_weights_into_module(module, weights_iter)
+                updated_modules.append(module_name)
+            except Exception as e:
+                rollback_list = updated_modules + [module_name]
+                logger.error(
+                    f"Tensor weight update failed for module '{module_name}': {e}. "
" + f"Rolling back {len(rollback_list)} module(s) " + f"(including partially-loaded '{module_name}'): " + f"{rollback_list}.", + exc_info=True, + ) + self._rollback(rollback_list) + return False, ( + f"Failed to update module '{module_name}' from tensor: {e}. " + f"All modules rolled back to original weights." + ) + + gc.collect() + torch.cuda.empty_cache() + names = ", ".join(updated_modules) + message = f"Updated {len(updated_modules)} modules from tensor ({names})." + logger.info(message) + return True, message + + def _resolve_module_payloads( + self, + named_tensors: Any, + modules_to_update: list[tuple[str, torch.nn.Module]], + ) -> dict[str, Any]: + """Resolve a generic tensor payload to per-module payload mapping.""" + module_names = [name for name, _ in modules_to_update] + + # Preferred format for multi-module update: + # { + # "transformer": , + # "vae": , + # } + if isinstance(named_tensors, dict): + missing = [name for name in module_names if name not in named_tensors] + if missing: + raise ValueError( + f"Missing tensor payload for module(s): {missing}. " + f"Provided modules: {list(named_tensors.keys())}" + ) + return {name: named_tensors[name] for name in module_names} + + # Single-module shortcut: allow direct payload when exactly one target module exists. + if len(module_names) == 1: + return {module_names[0]: named_tensors} + + raise ValueError( + "Ambiguous tensor payload for multi-module update. " + "Provide a dict mapping module_name -> module payload, " + f"requested modules: {module_names}." + ) + + def _materialize_weights_iter(self, module_payload: Any, load_format: str | None): + """Convert one module payload to an iterator of (param_name, tensor).""" + if load_format == "flattened_bucket": + if not isinstance(module_payload, dict): + raise ValueError( + "flattened_bucket payload must be a dict with " + "'flattened_tensor' and 'metadata'." + ) + flattened_tensor = module_payload.get("flattened_tensor") + metadata = module_payload.get("metadata") + if flattened_tensor is None or metadata is None: + raise ValueError( + "flattened_bucket payload missing 'flattened_tensor' or 'metadata'." + ) + return self._reconstruct_from_flattened_bucket(flattened_tensor, metadata) + + # Default/direct format: list/tuple of (name, tensor) + if isinstance(module_payload, (list, tuple)): + return iter(module_payload) + + raise ValueError( + f"Unsupported module payload type for load_format={load_format}: " + f"{type(module_payload).__name__}" + ) + + def _reconstruct_from_flattened_bucket(self, flattened_tensor: Any, metadata: Any): + """Reconstruct [(name, tensor), ...] from flattened-bucket payload.""" + if not isinstance(flattened_tensor, torch.Tensor): + raise ValueError( + "flattened_bucket 'flattened_tensor' must be a torch.Tensor." 
+            )
+        if not isinstance(metadata, list):
+            raise ValueError("flattened_bucket 'metadata' must be a list.")
+
+        converted_metadata: list[FlattenedTensorMetadata] = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=torch.Size(meta.shape),
+                dtype=self._normalize_torch_dtype(meta.dtype),
+                start_idx=int(meta.start_idx),
+                end_idx=int(meta.end_idx),
+                numel=int(meta.numel),
+            )
+            converted_metadata.append(converted_meta)
+
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor,
+            metadata=converted_metadata,
+        )
+        return bucket.reconstruct_tensors()
+
+    def _normalize_torch_dtype(self, dtype: Any) -> torch.dtype:
+        if isinstance(dtype, torch.dtype):
+            return dtype
+        if isinstance(dtype, str):
+            # Supports "torch.float16" and "float16".
+            name = dtype.split(".")[-1]
+            normalized = getattr(torch, name, None)
+            if isinstance(normalized, torch.dtype):
+                return normalized
+        raise ValueError(f"Unsupported dtype in flattened_bucket metadata: {dtype!r}")
diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
index 8757da74ee68..881c08bbac3d 100644
--- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
+++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
@@ -29,6 +29,10 @@
     get_ulysses_parallel_rank,
     get_ulysses_parallel_world_size,
 )
+from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
+    UpdateWeightFromTensorCheckerReqInput,
+    UpdateWeightFromTensorReqInput,
+)
 from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs
 from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum
 from sglang.multimodal_gen.runtime.loader.weights_updater import (
@@ -58,6 +62,10 @@
     PerformanceLogger,
     capture_memory_snapshot,
 )
+from sglang.multimodal_gen.runtime.utils.update_weight_from_tensor_checker import (
+    UpdateWeightFromTensorChecker,
+)
+from sglang.srt.utils import MultiprocessingSerializer
 from sglang.srt.utils.network import NetworkAddress

 logger = init_logger(__name__)
@@ -425,6 +433,57 @@ def update_weights_from_disk(
             self.pipeline.model_path = model_path
         return success, message

+    def update_weights_from_tensor(
+        self,
+        req: UpdateWeightFromTensorReqInput,
+    ) -> tuple[bool, str]:
+        """Update model weights from serialized tensor payloads."""
+        if not self.pipeline:
+            return False, "Pipeline is not initialized"
+
+        payloads, error = self._select_rank_scoped_payload(
+            payloads=req.serialized_named_tensors,
+            field_name="serialized_named_tensors",
+        )
+        if error is not None:
+            return False, error
+
+        try:
+            named_tensors = MultiprocessingSerializer.deserialize(payloads)
+        except Exception as e:
+            return False, f"Failed to deserialize serialized_named_tensors: {e}"
+
+        updater = WeightsUpdater(self.pipeline)
+
+        return updater.update_weights_from_tensor(
+            named_tensors=named_tensors,
+            load_format=req.load_format,
+            target_modules=req.target_modules,
+        )
+
+    def update_weight_from_tensor_checker(
+        self,
+        req: UpdateWeightFromTensorCheckerReqInput,
+    ) -> tuple[bool, str]:
+        """Verify the live transformer weights against expected SHA-256 values."""
+        if not self.pipeline:
+            return False, "Pipeline is not initialized"
+
+        expected_transformer_sha256, error = self._select_rank_scoped_payload(
+            payloads=req.expected_transformer_sha256,
+            field_name="expected_transformer_sha256",
+        )
+        if error is not None:
+            return False, error
+
+        checker = UpdateWeightFromTensorChecker(self.pipeline)
+        return checker.verify_across_tp(
+            expected_transformer_sha256,
+            tp_rank=get_tp_rank(),
+            tp_world_size=get_tp_world_size(),
+            tp_cpu_group=self.tp_cpu_group,
+        )
+
     def get_weights_checksum(
         self, module_names: list[str] | None = None
     ) -> dict[str, str]:
@@ -446,6 +505,27 @@ def get_weights_checksum(
         )
         return checksums

+    def _select_rank_scoped_payload(
+        self,
+        payloads: list,
+        field_name: str,
+    ) -> tuple[object | None, str | None]:
+        if not isinstance(payloads, list):
+            return None, f"{field_name} must be a list"
+        if not payloads:
+            return None, f"{field_name} is required"
+
+        tp_world_size = get_tp_world_size()
+        if len(payloads) not in (1, tp_world_size):
+            return (
+                None,
+                f"{field_name} size must be 1 or tp_size ({tp_world_size}), "
+                f"got {len(payloads)}",
+            )
+
+        payload_idx = get_tp_rank() if len(payloads) == tp_world_size else 0
+        return payloads[payload_idx], None
+

 OOM_MSG = f"""
 OOM detected. Possible solutions:
diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py
index c11c2c850224..e520a441a094 100644
--- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py
+++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py
@@ -17,6 +17,8 @@
 from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
     GetWeightsChecksumReqInput,
     UpdateWeightFromDiskReqInput,
+    UpdateWeightFromTensorCheckerReqInput,
+    UpdateWeightFromTensorReqInput,
 )
 from sglang.multimodal_gen.runtime.entrypoints.utils import (
     ListLorasReq,
@@ -95,6 +97,8 @@ def __init__(
             ListLorasReq: self._handle_list_loras,
             ShutdownReq: self._handle_shutdown,
             UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk,
+            UpdateWeightFromTensorReqInput: self._handle_update_weights_from_tensor,
+            UpdateWeightFromTensorCheckerReqInput: self._handle_update_weight_checker,
             GetWeightsChecksumReqInput: self._handle_get_weights_checksum,
         }

@@ -149,12 +153,36 @@ def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch:
             error=None if success else message,
         )

+    def _handle_update_weights_from_tensor(self, reqs: List[Any]) -> OutputBatch:
+        """Handle update_weights_from_tensor request for RL workflows."""
+        req = reqs[0]
+        success, message = self.worker.update_weights_from_tensor(req)
+
+        if self.server_args.tp_size > 1:
+            import torch
+
+            torch.distributed.barrier(group=self.worker.tp_cpu_group)
+
+        return OutputBatch(
+            output={"success": success, "message": message},
+            error=None if success else message,
+        )
+
     def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch:
         """Handle get_weights_checksum request."""
         req = reqs[0]
         checksums = self.worker.get_weights_checksum(module_names=req.module_names)
         return OutputBatch(output=checksums)

+    def _handle_update_weight_checker(self, reqs: List[Any]) -> OutputBatch:
+        """Handle update_weights_from_tensor_checker request."""
+        req = reqs[0]
+        success, message = self.worker.update_weight_from_tensor_checker(req)
+        return OutputBatch(
+            output={"success": success, "message": message},
+            error=None if success else message,
+        )
+
     def _handle_generation(self, reqs: List[Req]):
         warmup_reqs = [req for req in reqs if req.is_warmup]
         if warmup_reqs:
diff --git a/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py
new file mode 100644
index 000000000000..91677d3dc4c0
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/utils/update_weight_from_tensor_checker.py
@@ -0,0 +1,425 @@
+"""Verification helpers for diffusion update_weights_from_tensor workflows."""
+
+from __future__ import annotations
+
+import hashlib
+from collections.abc import Iterable
+
+import torch
+from torch.distributed.tensor import DTensor
+
+from sglang.multimodal_gen.runtime.utils.layerwise_offload import (
+    iter_materialized_weights,
+)
+
+_TRANSFORMER_MODULE_NAME = "transformer"
+_MAX_DISPLAY_TENSORS = 5
+_ROOT_TP_RANK = 0
+
+
+def _materialize_tensor_for_sha256(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        tensor = tensor._local_tensor
+    return tensor.detach().cpu().contiguous()
+
+
+def _compute_sha256_from_materialized(materialized: torch.Tensor) -> str:
+    materialized = materialized.contiguous()
+
+    hasher = hashlib.sha256()
+    hasher.update(str(materialized.dtype).encode("utf-8"))
+    hasher.update(repr(tuple(materialized.shape)).encode("utf-8"))
+    hasher.update(materialized.view(torch.uint8).numpy().tobytes())
+    return hasher.hexdigest()
+
+
+def compute_tensor_sha256(tensor: torch.Tensor) -> str:
+    return _compute_sha256_from_materialized(_materialize_tensor_for_sha256(tensor))
+
+
+def build_named_tensor_sha256(
+    named_tensors: Iterable[tuple[str, torch.Tensor]],
+) -> dict[str, str]:
+    sha256_by_name: dict[str, str] = {}
+    for name, tensor in named_tensors:
+        if name in sha256_by_name:
+            raise ValueError(f"Duplicate tensor name in SHA256 manifest: {name}")
+        sha256_by_name[name] = compute_tensor_sha256(tensor)
+    return sha256_by_name
+
+
+class UpdateWeightFromTensorChecker:
+    def __init__(self, pipeline):
+        self.pipeline = pipeline
+
+    def verify_across_tp(
+        self,
+        expected_transformer_sha256: dict[str, str],
+        tp_rank: int,
+        tp_world_size: int,
+        tp_cpu_group,
+    ) -> tuple[bool, str]:
+        if tp_world_size == 1:
+            try:
+                return self.verify(expected_transformer_sha256)
+            except Exception as e:
+                return (
+                    False,
+                    "Exception while verifying transformer update from tensor: "
+                    f"{type(e).__name__}: {e}",
+                )
+
+        return self._verify_on_tp_root(
+            expected_transformer_sha256=expected_transformer_sha256,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
+            tp_cpu_group=tp_cpu_group,
+        )
+
+    def verify(
+        self,
+        expected_transformer_sha256: dict[str, str],
+    ) -> tuple[bool, str]:
+        if not expected_transformer_sha256:
+            return False, "expected_transformer_sha256 is required"
+        if not isinstance(expected_transformer_sha256, dict):
+            return False, "expected_transformer_sha256 must be a dict[str, str]"
+
+        transformer = self.pipeline.get_module(_TRANSFORMER_MODULE_NAME)
+        if transformer is None:
+            return False, "Transformer module is not initialized"
+        if not isinstance(transformer, torch.nn.Module):
+            return False, "Transformer module is not a torch.nn.Module"
+
+        actual_transformer_sha256 = self._build_local_transformer_sha256(
+            transformer=transformer,
+            expected_transformer_sha256=expected_transformer_sha256,
+        )
+        return self._compare_manifests(
+            expected_transformer_sha256,
+            actual_transformer_sha256,
+        )
+
+    def _verify_on_tp_root(
+        self,
+        *,
+        expected_transformer_sha256: dict[str, str],
+        tp_rank: int,
+        tp_world_size: int,
+        tp_cpu_group,
+    ) -> tuple[bool, str]:
+        transformer, error_message = self._get_transformer_for_verification(
+            expected_transformer_sha256
+        )
+        gathered_errors: list[str | None] | None = (
+            [None] * tp_world_size if tp_rank == _ROOT_TP_RANK else None
+        )
+        torch.distributed.gather_object(
+            error_message,
+            gathered_errors,
+            dst=_ROOT_TP_RANK,
+            group=tp_cpu_group,
+        )
+
+        result: tuple[bool, str] | None = None
+        should_verify_holder = [False]
+        if tp_rank == _ROOT_TP_RANK:
+            assert gathered_errors is not None
+            failures = [
+                (rank, message)
+                for rank, message in enumerate(gathered_errors)
+                if message is not None
+            ]
+            if failures:
+                rank, message = failures[0]
+                if len(failures) == 1:
+                    result = (False, f"TP rank {rank}: {message}")
+                else:
+                    result = (
+                        False,
+                        f"{len(failures)} TP ranks failed update_weight_from_tensor_checker; "
+                        f"first failure on rank {rank}: {message}",
+                    )
+            else:
+                should_verify_holder[0] = True
+
+        torch.distributed.broadcast_object_list(
+            should_verify_holder,
+            src=_ROOT_TP_RANK,
+            group=tp_cpu_group,
+        )
+        if should_verify_holder[0]:
+            assert transformer is not None
+            actual_transformer_sha256 = self._build_root_transformer_sha256(
+                transformer=transformer,
+                expected_transformer_sha256=expected_transformer_sha256,
+                tp_rank=tp_rank,
+                tp_world_size=tp_world_size,
+                tp_cpu_group=tp_cpu_group,
+            )
+            if tp_rank == _ROOT_TP_RANK:
+                assert actual_transformer_sha256 is not None
+                success, message = self._compare_manifests(
+                    expected_transformer_sha256,
+                    actual_transformer_sha256,
+                )
+                if success:
+                    message = (
+                        f"Verified transformer update across {tp_world_size} TP ranks."
+                    )
+                result = (success, message)
+
+        result_holder = [result]
+        torch.distributed.broadcast_object_list(
+            result_holder,
+            src=_ROOT_TP_RANK,
+            group=tp_cpu_group,
+        )
+        final_result = result_holder[0]
+        assert final_result is not None
+        return final_result
+
+    def _get_transformer_for_verification(
+        self,
+        expected_transformer_sha256: dict[str, str],
+    ) -> tuple[torch.nn.Module | None, str | None]:
+        if not expected_transformer_sha256:
+            return None, "expected_transformer_sha256 is required"
+        if not isinstance(expected_transformer_sha256, dict):
+            return None, "expected_transformer_sha256 must be a dict[str, str]"
+
+        transformer = self.pipeline.get_module(_TRANSFORMER_MODULE_NAME)
+        if transformer is None:
+            return None, "Transformer module is not initialized"
+        if not isinstance(transformer, torch.nn.Module):
+            return None, "Transformer module is not a torch.nn.Module"
+        return transformer, None
+
+    def _build_local_transformer_sha256(
+        self,
+        *,
+        transformer: torch.nn.Module,
+        expected_transformer_sha256: dict[str, str],
+    ) -> dict[str, str]:
+        return build_named_tensor_sha256(
+            self._iter_transformer_named_tensors(
+                transformer, expected_transformer_sha256.keys()
+            )
+        )
+
+    def _build_root_transformer_sha256(
+        self,
+        *,
+        transformer: torch.nn.Module,
+        expected_transformer_sha256: dict[str, str],
+        tp_rank: int,
+        tp_world_size: int,
+        tp_cpu_group,
+    ) -> dict[str, str] | None:
+        local_named_tensors = dict(
+            self._iter_transformer_named_tensors(
+                transformer, expected_transformer_sha256.keys()
+            )
+        )
+        reference_tensors = dict(transformer.named_parameters())
+        reference_tensors.update(dict(transformer.named_buffers()))
+        actual_transformer_sha256: dict[str, str] | None = (
+            {} if tp_rank == _ROOT_TP_RANK else None
+        )
+        for name, expected_sha256 in expected_transformer_sha256.items():
+            gathered_tensors = self._gather_materialized_tensors_to_root(
+                materialized=(
+                    _materialize_tensor_for_sha256(local_named_tensors[name])
+                    if name in local_named_tensors
+                    else None
+                ),
+                tp_rank=tp_rank,
+                tp_world_size=tp_world_size,
+                tp_cpu_group=tp_cpu_group,
+            )
+            if tp_rank != _ROOT_TP_RANK:
+                continue
+
+            reconstructed_sha256 = self._compute_transformer_tensor_sha256_from_gathered(
+                gathered_tensors=gathered_tensors,
+                expected_sha256=expected_sha256,
+                reference_tensor=reference_tensors.get(name),
+            )
+            if reconstructed_sha256 is not None:
+                assert actual_transformer_sha256 is not None
+                actual_transformer_sha256[name] = reconstructed_sha256
+
+        return actual_transformer_sha256
+
+    def _compute_transformer_tensor_sha256_from_gathered(
+        self,
+        *,
+        gathered_tensors: list[torch.Tensor | None] | None,
+        expected_sha256: str,
+        reference_tensor: torch.Tensor | None,
+    ) -> str | None:
+        if gathered_tensors is None:
+            return None
+
+        valid_tensors = [tensor for tensor in gathered_tensors if tensor is not None]
+        if len(valid_tensors) != len(gathered_tensors):
+            return None
+
+        local_sha256s = [
+            _compute_sha256_from_materialized(materialized)
+            for materialized in valid_tensors
+        ]
+        if all(local_sha256 == expected_sha256 for local_sha256 in local_sha256s):
+            return expected_sha256
+
+        candidate_dims = self._get_tp_candidate_dims(reference_tensor)
+        for shard_dim in candidate_dims:
+            reconstructed = self._reconstruct_tp_tensor(
+                gathered_tensors=valid_tensors,
+                shard_dim=shard_dim,
+            )
+            if reconstructed is None:
+                continue
+            reconstructed_sha256 = _compute_sha256_from_materialized(reconstructed)
+            if reconstructed_sha256 == expected_sha256:
+                return reconstructed_sha256
+
+        return local_sha256s[0]
+
+    def _get_tp_candidate_dims(
+        self,
+        reference_tensor: torch.Tensor | None,
+    ) -> list[int]:
+        if reference_tensor is None:
+            return []
+
+        candidate_dims: list[int] = []
+        if isinstance(reference_tensor, DTensor):
+            for placement in reference_tensor.placements:
+                shard_dim = getattr(placement, "dim", None)
+                if isinstance(shard_dim, int):
+                    candidate_dims.append(shard_dim)
+
+        for attr in ("input_dim", "output_dim"):
+            shard_dim = getattr(reference_tensor, attr, None)
+            if isinstance(shard_dim, int):
+                candidate_dims.append(shard_dim)
+
+        deduped_dims: list[int] = []
+        for shard_dim in candidate_dims:
+            if shard_dim not in deduped_dims:
+                deduped_dims.append(shard_dim)
+        return deduped_dims
+
+    def _gather_materialized_tensors_to_root(
+        self,
+        *,
+        materialized: torch.Tensor | None,
+        tp_rank: int,
+        tp_world_size: int,
+        tp_cpu_group,
+    ) -> list[torch.Tensor | None] | None:
+        gathered_tensors: list[torch.Tensor | None] | None = (
+            [None] * tp_world_size if tp_rank == _ROOT_TP_RANK else None
+        )
+        torch.distributed.gather_object(
+            materialized,
+            gathered_tensors,
+            dst=_ROOT_TP_RANK,
+            group=tp_cpu_group,
+        )
+        return gathered_tensors
+
+    def _reconstruct_tp_tensor(
+        self,
+        *,
+        gathered_tensors: list[torch.Tensor],
+        shard_dim: int,
+    ) -> torch.Tensor | None:
+        if not gathered_tensors:
+            return None
+
+        first_tensor = gathered_tensors[0]
+        if first_tensor.ndim == 0:
+            return None
+
+        shard_dim %= first_tensor.ndim
+        for tensor in gathered_tensors[1:]:
+            if tensor.ndim != first_tensor.ndim or tensor.dtype != first_tensor.dtype:
+                return None
+            if any(
+                lhs != rhs
+                for dim, (lhs, rhs) in enumerate(zip(first_tensor.shape, tensor.shape))
+                if dim != shard_dim
+            ):
+                return None
+
+        return torch.cat(gathered_tensors, dim=shard_dim).contiguous()
+
+    def _iter_transformer_named_tensors(
+        self,
+        transformer: torch.nn.Module,
+        expected_names: Iterable[str],
+    ):
+        expected_name_set = set(expected_names)
+        seen_names: set[str] = set()
+
+        for name, tensor in iter_materialized_weights(transformer):
+            if name not in expected_name_set:
+                continue
+            seen_names.add(name)
+            yield name, tensor
+
+        for name, tensor in transformer.named_buffers():
+            if name in seen_names or name not in expected_name_set:
+                continue
+            seen_names.add(name)
+            yield name, tensor
+
+    def _compare_manifests(
+        self,
+        expected_transformer_sha256: dict[str, str],
+        actual_transformer_sha256: dict[str, str],
+    ) -> tuple[bool, str]:
+        missing_names = sorted(
+            name
+            for name in expected_transformer_sha256
+            if name not in actual_transformer_sha256
+        )
+        mismatched_names = sorted(
+            name
+            for name, expected_sha256 in expected_transformer_sha256.items()
+            if name in actual_transformer_sha256
+            and actual_transformer_sha256[name] != expected_sha256
+        )
+
+        if missing_names or mismatched_names:
+            parts: list[str] = []
+            if missing_names:
+                parts.append(
+                    "missing "
+                    f"{len(missing_names)} tensor(s): "
+                    f"{self._format_tensor_names(missing_names)}"
+                )
+            if mismatched_names:
+                parts.append(
+                    "checksum mismatch for "
+                    f"{len(mismatched_names)} tensor(s): "
+                    f"{self._format_tensor_names(mismatched_names)}"
+                )
+            return (
+                False,
+                "Transformer update weight check failed: " + "; ".join(parts),
+            )
+
+        return (
+            True,
+            f"Verified transformer update for {len(expected_transformer_sha256)} tensor(s).",
+        )
+
+    def _format_tensor_names(self, names: list[str]) -> str:
+        displayed = names[:_MAX_DISPLAY_TENSORS]
+        formatted = ", ".join(displayed)
+        if len(names) > _MAX_DISPLAY_TENSORS:
+            formatted += f", ... (+{len(names) - _MAX_DISPLAY_TENSORS} more)"
+        return formatted
diff --git a/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md b/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md
new file mode 100644
index 000000000000..51c4c2706a45
--- /dev/null
+++ b/python/sglang/multimodal_gen/test/server/README_update_weights_from_tensor.md
@@ -0,0 +1,55 @@
+# Diffusion `update_weights_from_tensor` README
+
+This document describes the tensor-based in-place weight update flow for diffusion models in `sglang.multimodal_gen`.
+
+## Endpoint
+
+- `POST /update_weights_from_tensor`
+
+## Request Schema
+
+- `serialized_named_tensors: List[Union[str, bytes]]` (required)
+- `load_format: Optional[str]` (`None` for the direct format, or `"flattened_bucket"`)
+- `target_modules: Optional[List[str]]` (for example: `["transformer"]`, `["transformer", "vae"]`)
+- `weight_version: Optional[str]`
+
+Notes:
+- `flush_cache` is intentionally not part of the tensor request.
+- Responses have the shape `{"success": bool, "message": str}`.
+
+## TP Payload Rules
+
+- `len(serialized_named_tensors)` must be either:
+  - `1`, or
+  - `tp_size`
+- If the length is `tp_size`, each TP rank consumes the payload at its own index.
+- If the length is `1`, all TP ranks consume index `0`.
+
+## Module Payload Rules
+
+- Single-module update (`target_modules` names one module):
+  - The module payload can be passed directly.
+- Multi-module update:
+  - The payload must be a dict keyed by module name:
+
+```python
+{
+    "transformer": <module payload>,
+    "vae": <module payload>,
+}
+```
+
+## Supported Module Payload Formats
+
+- `load_format=None` (direct-style payload):
+  - `[(param_name, tensor), ...]`
+- `load_format="flattened_bucket"`:
+  - `{"flattened_tensor": tensor, "metadata": [...]}`
+
+## Safety Semantics
+
+- Only the selected modules are updated (`target_modules` aware).
+- The update is all-or-nothing across the requested modules:
+  - on failure, already-updated modules are rolled back to the previous on-disk weights.
+- A TP synchronization barrier is used in the scheduler path to avoid mixed-rank model state.
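+
+## Usage Sketch
+
+A minimal client-side sketch of the single-module direct format, followed by a
+verification call. The URL/port and the parameter name below are illustrative
+placeholders, not part of the API; the serialization mirrors the e2e tests in
+`test/test_update_weights_from_tensor_checker_e2e.py`.
+
+```python
+import base64
+import pickle
+
+import requests
+import torch
+
+from sglang.multimodal_gen.runtime.utils.update_weight_from_tensor_checker import (
+    build_named_tensor_sha256,
+)
+
+# Hypothetical parameter name/shape; use real names from the served transformer.
+named_tensors = [("blocks.0.norm1.weight", torch.ones(1024, dtype=torch.bfloat16))]
+
+# CPU tensors can be shipped as a base64-encoded pickle; for CUDA tensors the e2e
+# tests use MultiprocessingSerializer.serialize(..., output_str=True) instead.
+payload = base64.b64encode(pickle.dumps(named_tensors)).decode("utf-8")
+
+# A single payload entry is consumed by every TP rank (see "TP Payload Rules").
+resp = requests.post(
+    "http://127.0.0.1:30000/update_weights_from_tensor",
+    json={
+        "serialized_named_tensors": [payload],
+        "target_modules": ["transformer"],
+    },
+    timeout=300,
+)
+assert resp.json()["success"], resp.json()["message"]
+
+# Verify the live weights; the manifest hashes dtype + shape + raw bytes, matching
+# the server-side build_named_tensor_sha256.
+manifest = build_named_tensor_sha256(named_tensors)
+check = requests.post(
+    "http://127.0.0.1:30000/update_weights_from_tensor_checker",
+    json={"expected_transformer_sha256": [manifest]},
+    timeout=300,
+)
+assert check.json()["success"], check.json()["message"]
+```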
diff --git a/test/test_update_weight_from_tensor_checker.py b/test/test_update_weight_from_tensor_checker.py
new file mode 100644
index 000000000000..206d1e7dde91
--- /dev/null
+++ b/test/test_update_weight_from_tensor_checker.py
@@ -0,0 +1,254 @@
+import importlib.util
+import os
+import pickle
+import sys
+import tempfile
+import types
+import unittest
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch import nn
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import Replicate, distribute_tensor
+
+UTIL_PATH = (
+    Path(__file__).resolve().parents[1]
+    / "python"
+    / "sglang"
+    / "multimodal_gen"
+    / "runtime"
+    / "utils"
+    / "update_weight_from_tensor_checker.py"
+)
+
+
+def _install_layerwise_offload_stub():
+    package_names = [
+        "sglang",
+        "sglang.multimodal_gen",
+        "sglang.multimodal_gen.runtime",
+        "sglang.multimodal_gen.runtime.utils",
+    ]
+    original_modules: dict[str, types.ModuleType | None] = {
+        name: sys.modules.get(name) for name in package_names
+    }
+    for name in package_names:
+        sys.modules.setdefault(name, types.ModuleType(name))
+
+    layerwise_offload = types.ModuleType(
+        "sglang.multimodal_gen.runtime.utils.layerwise_offload"
+    )
+
+    def iter_materialized_weights(module: nn.Module):
+        yield from module.named_parameters()
+
+    layerwise_offload.iter_materialized_weights = iter_materialized_weights
+    original_modules[layerwise_offload.__name__] = sys.modules.get(
+        layerwise_offload.__name__
+    )
+    sys.modules[layerwise_offload.__name__] = layerwise_offload
+    return original_modules
+
+
+def _load_checker_module():
+    original_modules = _install_layerwise_offload_stub()
+    try:
+        spec = importlib.util.spec_from_file_location(
+            "update_weight_from_tensor_checker_test_module",
+            UTIL_PATH,
+        )
+        module = importlib.util.module_from_spec(spec)
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
+        return module
+    finally:
+        for name, original_module in original_modules.items():
+            if original_module is None:
+                sys.modules.pop(name, None)
+            else:
+                sys.modules[name] = original_module
+
+
+_CHECKER_MODULE = _load_checker_module()
+UpdateWeightFromTensorChecker = _CHECKER_MODULE.UpdateWeightFromTensorChecker
+build_named_tensor_sha256 = _CHECKER_MODULE.build_named_tensor_sha256
+
+
+class _ToyTransformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(4, 4, bias=False)
+        self.register_buffer("running_scale", torch.ones(4))
+
+
+class _FakePipeline:
+    def __init__(self, transformer: nn.Module):
+        self._transformer = transformer
+
+    def get_module(self, name: str):
+        if name == "transformer":
+            return self._transformer
+        return None
+
+
+def _iter_named_tensors(module: nn.Module):
+    yield from module.named_parameters()
+    yield from module.named_buffers()
+
+
+class _ShardedTransformer(nn.Module):
+    def __init__(self, local_weight: torch.Tensor, *, shard_attr: str, shard_dim: int):
+        super().__init__()
+        self.weight = nn.Parameter(local_weight)
+        setattr(self.weight, shard_attr, shard_dim)
+
+
+def _verify_sharded_weight_across_tp_worker(
+    rank: int,
+    rendezvous_file: str,
+    shard_attr: str,
+    shard_dim: int,
+    results_dir: str,
+):
+    os.environ.setdefault("GLOO_SOCKET_IFNAME", "lo")
+    dist.init_process_group(
+        backend="gloo",
+        init_method=f"file://{rendezvous_file}",
+        rank=rank,
+        world_size=2,
+    )
+    try:
+        full_weight = torch.arange(16, dtype=torch.float32).reshape(4, 4)
+        shard_size = full_weight.shape[shard_dim] // 2
+        local_weight = full_weight.narrow(
+            shard_dim, rank * shard_size, shard_size
+        ).clone()
+
+        transformer = _ShardedTransformer(
+            local_weight,
+            shard_attr=shard_attr,
+            shard_dim=shard_dim,
+        )
+        checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer))
+        expected_transformer_sha256 = build_named_tensor_sha256([("weight", full_weight)])
+
+        result = checker.verify_across_tp(
+            expected_transformer_sha256=expected_transformer_sha256,
+            tp_rank=rank,
+            tp_world_size=2,
+            tp_cpu_group=dist.group.WORLD,
+        )
+        result_path = Path(results_dir) / f"rank_{rank}.pkl"
+        with result_path.open("wb") as f:
+            pickle.dump(result, f)
+    finally:
+        dist.destroy_process_group()
+
+
+class UpdateWeightFromTensorCheckerTest(unittest.TestCase):
+    def test_matches_live_transformer(self):
+        transformer = _ToyTransformer()
+        checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer))
+        expected_transformer_sha256 = build_named_tensor_sha256(
+            _iter_named_tensors(transformer)
+        )
+
+        success, message = checker.verify(expected_transformer_sha256)
+
+        self.assertTrue(success)
+        self.assertEqual(message, "Verified transformer update for 2 tensor(s).")
+
+    def test_detects_modified_tensor(self):
+        transformer = _ToyTransformer()
+        checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer))
+        expected_transformer_sha256 = build_named_tensor_sha256(
+            _iter_named_tensors(transformer)
+        )
+
+        with torch.no_grad():
+            transformer.linear.weight.add_(1)
+
+        success, message = checker.verify(expected_transformer_sha256)
+
+        self.assertFalse(success)
+        self.assertIn("checksum mismatch for 1 tensor(s): linear.weight", message)
+
+    def test_detects_missing_tensor(self):
+        transformer = _ToyTransformer()
+        checker = UpdateWeightFromTensorChecker(_FakePipeline(transformer))
+
+        success, message = checker.verify({"missing.weight": "deadbeef"})
+
+        self.assertFalse(success)
+        self.assertEqual(
+            message,
+            "Transformer update weight check failed: "
+            "missing 1 tensor(s): missing.weight",
+        )
+
+    def test_supports_dtensor(self):
+        if dist.is_initialized():
+            self.skipTest("process group already initialized")
+
+        with tempfile.NamedTemporaryFile() as rendezvous_file:
+            os.environ.setdefault("GLOO_SOCKET_IFNAME", "lo")
+            dist.init_process_group(
+                backend="gloo",
+                init_method=f"file://{rendezvous_file.name}",
+                rank=0,
+                world_size=1,
+            )
+            try:
+                local_tensor = torch.arange(4, dtype=torch.float32)
+                device_mesh = init_device_mesh("cpu", (1,))
+                distributed_tensor = distribute_tensor(
+                    local_tensor, device_mesh, [Replicate()]
+                )
+
+                expected_sha256 = build_named_tensor_sha256([("weight", local_tensor)])
+                actual_sha256 = build_named_tensor_sha256(
+                    [("weight", distributed_tensor)]
+                )
+
+                self.assertEqual(actual_sha256, expected_sha256)
+            finally:
+                dist.destroy_process_group()
+
+    def _run_tp_sharded_verification(self, *, shard_attr: str, shard_dim: int):
+        if dist.is_initialized():
+            self.skipTest("process group already initialized")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            rendezvous_file = str(Path(tmpdir) / "dist_init")
+            Path(rendezvous_file).touch()
+            mp.spawn(
+                _verify_sharded_weight_across_tp_worker,
+                args=(rendezvous_file, shard_attr, shard_dim, tmpdir),
+                nprocs=2,
+                join=True,
+            )
+
+            results = {}
+            for rank in range(2):
+                result_path = Path(tmpdir) / f"rank_{rank}.pkl"
+                with result_path.open("rb") as f:
+                    results[rank] = pickle.load(f)
+
+            self.assertEqual(sorted(results.keys()), [0, 1])
+            for rank in range(2):
+                success, message = results[rank]
+                self.assertTrue(success, f"rank {rank}: {message}")
f"rank {rank}: {message}") + self.assertIn("across 2 TP ranks", message) + + def test_verify_across_tp_gathers_output_sharded_tensor(self): + self._run_tp_sharded_verification(shard_attr="output_dim", shard_dim=0) + + def test_verify_across_tp_gathers_input_sharded_tensor(self): + self._run_tp_sharded_verification(shard_attr="input_dim", shard_dim=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_update_weights_from_tensor_checker_e2e.py b/test/test_update_weights_from_tensor_checker_e2e.py new file mode 100644 index 000000000000..d21b7d8d8c2e --- /dev/null +++ b/test/test_update_weights_from_tensor_checker_e2e.py @@ -0,0 +1,518 @@ +import base64 +import hashlib +import os +import pickle +import signal +import socket +import subprocess +import tempfile +import time +import unittest +from pathlib import Path + +import requests +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import safe_open + +from sglang.srt.utils import MultiprocessingSerializer + +BASE_MODEL = "Tongyi-MAI/Z-Image-Turbo" +TRANSFORMER_MODULE = "transformer" +NUM_TENSORS_TO_UPDATE = 4 +SERVER_READY_MESSAGE = "Application startup complete." + + +def _get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _list_safetensor_files(model_dir: str) -> list[Path]: + return sorted(Path(model_dir).glob("*.safetensors")) + + +def _iter_transformer_weights_from_disk(model_path: str): + local_model_path = snapshot_download( + repo_id=model_path, + allow_patterns=[f"{TRANSFORMER_MODULE}/*"], + ) + weights_dir = Path(local_model_path) / TRANSFORMER_MODULE + safetensor_files = _list_safetensor_files(str(weights_dir)) + if not safetensor_files: + raise FileNotFoundError( + f"No safetensors files found for module '{TRANSFORMER_MODULE}' in {weights_dir}" + ) + + for safetensor_file in safetensor_files: + with safe_open(str(safetensor_file), framework="pt", device="cpu") as handle: + for name in handle.keys(): + yield name, handle.get_tensor(name) + + +def _compute_tensor_sha256(tensor: torch.Tensor) -> str: + materialized = tensor.detach().cpu().contiguous() + hasher = hashlib.sha256() + hasher.update(str(materialized.dtype).encode("utf-8")) + hasher.update(repr(tuple(materialized.shape)).encode("utf-8")) + hasher.update(materialized.view(torch.uint8).numpy().tobytes()) + return hasher.hexdigest() + + +def _build_named_tensor_sha256( + named_tensors: list[tuple[str, torch.Tensor]], +) -> dict[str, str]: + return {name: _compute_tensor_sha256(tensor) for name, tensor in named_tensors} + + +def _select_transformer_tensors( + model_path: str, + max_tensors: int = NUM_TENSORS_TO_UPDATE, +) -> list[tuple[str, torch.Tensor]]: + selected_named_tensors: list[tuple[str, torch.Tensor]] = [] + for name, tensor in _iter_transformer_weights_from_disk(model_path): + if not tensor.is_floating_point(): + continue + selected_named_tensors.append((name, tensor.to(torch.bfloat16).clone())) + if len(selected_named_tensors) == max_tensors: + break + + if not selected_named_tensors: + raise AssertionError("Expected at least one floating-point transformer tensor") + + return selected_named_tensors + + +def _build_shifted_named_tensors( + named_tensors: list[tuple[str, torch.Tensor]], + delta: float, +) -> list[tuple[str, torch.Tensor]]: + shifted_named_tensors: list[tuple[str, torch.Tensor]] = [] + for index, (name, tensor) in enumerate(named_tensors): + shifted_tensor = tensor.clone() + if index == 
+            shifted_tensor.add_(delta)
+        shifted_named_tensors.append((name, shifted_tensor))
+    return shifted_named_tensors
+
+
+def _serialize_named_tensors(named_tensors: list[tuple[str, torch.Tensor]]) -> str:
+    return base64.b64encode(pickle.dumps(named_tensors)).decode("utf-8")
+
+
+def _serialize_named_tensors_multiprocessing(
+    named_tensors: list[tuple[str, torch.Tensor]],
+) -> str:
+    return MultiprocessingSerializer.serialize(named_tensors, output_str=True)
+
+
+class _ServerRunner:
+    def __init__(
+        self,
+        model_path: str,
+        *,
+        num_gpus: int = 1,
+        tp_size: int | None = None,
+    ):
+        self.model_path = model_path
+        self.num_gpus = num_gpus
+        self.tp_size = tp_size
+        self.port = _get_free_port()
+        self.process: subprocess.Popen | None = None
+        self.log_file = None
+        self.log_path = Path(tempfile.gettempdir()) / (
+            f"sglang_update_weight_checker_{self.port}.log"
+        )
+        self.log_path.unlink(missing_ok=True)
+
+    @property
+    def base_url(self) -> str:
+        return f"http://127.0.0.1:{self.port}"
+
+    def start(self) -> None:
+        command = [
+            "sglang",
+            "serve",
+            "--model-path",
+            self.model_path,
+            "--port",
+            str(self.port),
+            "--log-level=debug",
+            "--num-gpus",
+            str(self.num_gpus),
+        ]
+        if self.tp_size is not None:
+            command.extend(["--tp-size", str(self.tp_size)])
+        env = os.environ.copy()
+        env["SGLANG_DIFFUSION_STAGE_LOGGING"] = "1"
+
+        self.log_file = self.log_path.open("w", encoding="utf-8")
+        self.process = subprocess.Popen(
+            command,
+            stdout=self.log_file,
+            stderr=subprocess.STDOUT,
+            text=True,
+            start_new_session=True,
+            env=env,
+        )
+        try:
+            self._wait_for_ready()
+        except Exception:
+            self.stop()
+            raise
+
+    def stop(self) -> None:
+        if self.process is None:
+            if self.log_file is not None:
+                self.log_file.close()
+                self.log_file = None
+            return
+        if self.process.poll() is None:
+            try:
+                os.killpg(self.process.pid, signal.SIGTERM)
+                self.process.wait(timeout=30)
+            except Exception:
+                os.killpg(self.process.pid, signal.SIGKILL)
+                self.process.wait(timeout=30)
+        if self.log_file is not None:
+            self.log_file.close()
+            self.log_file = None
+        self.process = None
+
+    def _wait_for_ready(self) -> None:
+        deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "1200"))
+        start_time = time.time()
+
+        while time.time() - start_time < deadline:
+            if self.process is None:
+                raise RuntimeError("Server process was not started")
+            if self.process.poll() is not None:
+                raise RuntimeError(
+                    f"Server exited early with code {self.process.returncode}.\n"
+                    f"{self._get_log_tail()}"
+                )
+            if self.log_path.exists():
+                log_content = self.log_path.read_text(encoding="utf-8", errors="ignore")
+                if SERVER_READY_MESSAGE in log_content:
+                    return
+            time.sleep(1)
+
+        raise TimeoutError(
+            f"Server did not become ready within {deadline}s.\n{self._get_log_tail()}"
+        )
+
+    def _get_log_tail(self, lines: int = 200) -> str:
+        if not self.log_path.exists():
+            return ""
+        log_content = self.log_path.read_text(encoding="utf-8", errors="ignore")
+        return "\n".join(log_content.splitlines()[-lines:])
+
+
+class UpdateWeightFromTensorCheckerE2ETest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        snapshot_download(repo_id=BASE_MODEL)
+        cls.base_transformer_tensors = _select_transformer_tensors(BASE_MODEL)
+        cls.expected_updated_tensors = _build_shifted_named_tensors(
+            cls.base_transformer_tensors,
+            delta=1.0,
+        )
+        cls.server = _ServerRunner(BASE_MODEL)
+        cls.server.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.server.stop()
+
+    def _update_weights_from_disk(self, model_path: str) -> tuple[dict, int]:
+        response = requests.post(
+            f"{self.server.base_url}/update_weights_from_disk",
+            json={"model_path": model_path, "target_modules": [TRANSFORMER_MODULE]},
+            timeout=300,
+        )
+        return response.json(), response.status_code
+
+    def _update_weights_from_tensor(
+        self,
+        named_tensors: list[tuple[str, torch.Tensor]],
+        *,
+        serializer=_serialize_named_tensors,
+    ) -> tuple[dict, int]:
+        response = requests.post(
+            f"{self.server.base_url}/update_weights_from_tensor",
+            json={
+                "serialized_named_tensors": [serializer(named_tensors)],
+                "target_modules": [TRANSFORMER_MODULE],
+            },
+            timeout=300,
+        )
+        return response.json(), response.status_code
+
+    def _check_updated_weights_from_tensor(
+        self,
+        expected_transformer_sha256: dict[str, str],
+    ) -> tuple[dict, int]:
+        response = requests.post(
+            f"{self.server.base_url}/update_weights_from_tensor_checker",
+            json={"expected_transformer_sha256": [expected_transformer_sha256]},
+            timeout=300,
+        )
+        return response.json(), response.status_code
+
+    def test_update_weights_from_tensor_checker_success(self):
+        reset_result, reset_status = self._update_weights_from_disk(BASE_MODEL)
+        self.assertEqual(reset_status, 200, reset_result)
+        self.assertTrue(reset_result.get("success"), reset_result)
+
+        update_result, update_status = self._update_weights_from_tensor(
+            self.expected_updated_tensors
+        )
+        self.assertEqual(update_status, 200, update_result)
+        self.assertTrue(update_result.get("success"), update_result)
+
+        expected_transformer_sha256 = _build_named_tensor_sha256(
+            self.expected_updated_tensors
+        )
+        check_result, check_status = self._check_updated_weights_from_tensor(
+            expected_transformer_sha256
+        )
+        self.assertEqual(check_status, 200, check_result)
+        self.assertTrue(check_result.get("success"), check_result)
+        self.assertIn("Verified transformer update", check_result.get("message", ""))
+
+    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
+    def test_update_weights_from_tensor_checker_with_multiprocessing_serializer(self):
+        cuda_named_tensors = [
+            (name, tensor.to(device="cuda", non_blocking=True))
+            for name, tensor in self.expected_updated_tensors
+        ]
+        try:
+            update_result, update_status = self._update_weights_from_tensor(
+                cuda_named_tensors,
+                serializer=_serialize_named_tensors_multiprocessing,
+            )
+            self.assertEqual(update_status, 200, update_result)
+            self.assertTrue(update_result.get("success"), update_result)
+
+            expected_transformer_sha256 = _build_named_tensor_sha256(
+                self.expected_updated_tensors
+            )
+            check_result, check_status = self._check_updated_weights_from_tensor(
+                expected_transformer_sha256
+            )
+            self.assertEqual(check_status, 200, check_result)
+            self.assertTrue(check_result.get("success"), check_result)
+            self.assertIn(
+                "Verified transformer update",
+                check_result.get("message", ""),
+            )
+        finally:
+            del cuda_named_tensors
+            torch.cuda.empty_cache()
+
+    def test_update_weights_from_tensor_checker_detects_corrupted_payload(self):
+        reset_result, reset_status = self._update_weights_from_disk(BASE_MODEL)
+        self.assertEqual(reset_status, 200, reset_result)
+        self.assertTrue(reset_result.get("success"), reset_result)
+
+        expected_transformer_sha256 = _build_named_tensor_sha256(
+            self.expected_updated_tensors
+        )
+        corrupted_named_tensors = _build_shifted_named_tensors(
+            self.base_transformer_tensors,
+            delta=2.0,
+        )
+
+        update_result, update_status = self._update_weights_from_tensor(
+            corrupted_named_tensors
+        )
+        self.assertEqual(update_status, 200, update_result)
self.assertTrue(update_result.get("success"), update_result) + + check_result, check_status = self._check_updated_weights_from_tensor( + expected_transformer_sha256 + ) + self.assertEqual(check_status, 400, check_result) + self.assertFalse(check_result.get("success", True), check_result) + self.assertIn("checksum mismatch", check_result.get("message", "")) + self.assertIn( + self.expected_updated_tensors[0][0], + check_result.get("message", ""), + ) + + +def _select_tp_candidate_tensors( + model_path: str, + max_tensors: int = 24, +) -> list[tuple[str, torch.Tensor]]: + norm_candidates: list[tuple[str, torch.Tensor]] = [] + other_candidates: list[tuple[str, torch.Tensor]] = [] + for name, tensor in _iter_transformer_weights_from_disk(model_path): + if not tensor.is_floating_point() or tensor.ndim != 1: + continue + candidate = (name, tensor.to(torch.bfloat16).clone()) + if "norm" in name: + norm_candidates.append(candidate) + elif name.endswith(".bias"): + other_candidates.append(candidate) + + if len(norm_candidates) + len(other_candidates) >= max_tensors: + break + + candidates = norm_candidates + other_candidates + if not candidates: + raise AssertionError("Expected at least one 1D transformer tensor candidate") + return candidates[:max_tensors] + + +@unittest.skipUnless(torch.cuda.device_count() >= 2, "requires at least 2 GPUs") +class UpdateWeightFromTensorChecker2GPUTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + snapshot_download(repo_id=BASE_MODEL) + cls.tp_candidates = _select_tp_candidate_tensors(BASE_MODEL) + cls.selected_updated_tensors = None + + def setUp(self): + self.server = _ServerRunner(BASE_MODEL, num_gpus=2, tp_size=2) + self.server.start() + + def tearDown(self): + self.server.stop() + self.server = None + + def _update_weights_from_tensor( + self, + named_tensors: list[tuple[str, torch.Tensor]], + *, + serializer=_serialize_named_tensors, + ) -> tuple[dict, int]: + serialized_payload = serializer(named_tensors) + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [serialized_payload, serialized_payload], + "target_modules": [TRANSFORMER_MODULE], + }, + timeout=300, + ) + return response.json(), response.status_code + + def _check_updated_weights_from_tensor( + self, + expected_transformer_sha256: dict[str, str], + ) -> tuple[dict, int]: + response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor_checker", + json={ + "expected_transformer_sha256": [ + expected_transformer_sha256, + expected_transformer_sha256, + ] + }, + timeout=300, + ) + return response.json(), response.status_code + + def _build_selected_updated_tensors(self): + cls = type(self) + if cls.selected_updated_tensors is not None: + return cls.selected_updated_tensors + + errors: list[str] = [] + for name, tensor in cls.tp_candidates: + updated_tensors = [(name, tensor.clone().add_(1.0))] + + serialized_payload = _serialize_named_tensors(updated_tensors) + update_response = requests.post( + f"{self.server.base_url}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [serialized_payload, serialized_payload], + "target_modules": [TRANSFORMER_MODULE], + }, + timeout=300, + ) + if ( + update_response.status_code != 200 + or not update_response.json().get("success") + ): + errors.append( + f"{name}: update failed with " + f"{update_response.status_code} {update_response.text}" + ) + continue + + expected_transformer_sha256 = _build_named_tensor_sha256(updated_tensors) + check_response = 
+                f"{self.server.base_url}/update_weights_from_tensor_checker",
+                json={
+                    "expected_transformer_sha256": [
+                        expected_transformer_sha256,
+                        expected_transformer_sha256,
+                    ]
+                },
+                timeout=300,
+            )
+            if (
+                check_response.status_code == 200
+                and check_response.json().get("success")
+            ):
+                cls.selected_updated_tensors = updated_tensors
+                return cls.selected_updated_tensors
+
+            errors.append(
+                f"{name}: checker failed with "
+                f"{check_response.status_code} {check_response.text}"
+            )
+
+        displayed_errors = "; ".join(errors[:5])
+        raise AssertionError(
+            "Could not find a TP-compatible transformer tensor candidate for 2 GPU "
+            f"update_weights_from_tensor checker test. First errors: {displayed_errors}"
+        )
+
+    def test_update_weights_from_tensor_checker_success(self):
+        selected_updated_tensors = self._build_selected_updated_tensors()
+        update_result, update_status = self._update_weights_from_tensor(
+            selected_updated_tensors
+        )
+        self.assertEqual(update_status, 200, update_result)
+        self.assertTrue(update_result.get("success"), update_result)
+
+        expected_transformer_sha256 = _build_named_tensor_sha256(
+            selected_updated_tensors
+        )
+        check_result, check_status = self._check_updated_weights_from_tensor(
+            expected_transformer_sha256
+        )
+        self.assertEqual(check_status, 200, check_result)
+        self.assertTrue(check_result.get("success"), check_result)
+        self.assertIn("across 2 TP ranks", check_result.get("message", ""))
+
+    def test_update_weights_from_tensor_checker_detects_corrupted_payload(self):
+        selected_updated_tensors = self._build_selected_updated_tensors()
+        expected_transformer_sha256 = _build_named_tensor_sha256(
+            selected_updated_tensors
+        )
+        corrupted_named_tensors = [
+            (name, tensor.clone().add_(1.0))
+            for name, tensor in selected_updated_tensors
+        ]
+
+        update_result, update_status = self._update_weights_from_tensor(
+            corrupted_named_tensors
+        )
+        self.assertEqual(update_status, 200, update_result)
+        self.assertTrue(update_result.get("success"), update_result)
+
+        check_result, check_status = self._check_updated_weights_from_tensor(
+            expected_transformer_sha256
+        )
+        self.assertEqual(check_status, 400, check_result)
+        self.assertFalse(check_result.get("success", True), check_result)
+        self.assertIn("checksum mismatch", check_result.get("message", ""))
+        self.assertIn(selected_updated_tensors[0][0], check_result.get("message", ""))
+
+
+if __name__ == "__main__":
+    unittest.main()
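
For the `"flattened_bucket"` load format documented in the README above, a per-module payload can be assembled client-side roughly as follows. This is a sketch, not an API shipped by this patch: `build_flattened_bucket_payload` is a hypothetical helper, and it assumes all tensors in one bucket share a dtype, since they are concatenated into a single flat buffer (the inverse of what `_reconstruct_from_flattened_bucket` does on the server).

```python
import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorMetadata


def build_flattened_bucket_payload(named_tensors):
    """Pack [(param_name, tensor), ...] into a flattened_bucket payload dict."""
    metadata, chunks, offset = [], [], 0
    for name, tensor in named_tensors:
        flat = tensor.detach().contiguous().view(-1)
        metadata.append(
            FlattenedTensorMetadata(
                name=name,
                shape=tensor.shape,
                dtype=tensor.dtype,
                start_idx=offset,
                end_idx=offset + flat.numel(),
                numel=flat.numel(),
            )
        )
        chunks.append(flat)
        offset += flat.numel()
    # Field layout mirrors how _materialize_weights_iter consumes the payload.
    return {"flattened_tensor": torch.cat(chunks), "metadata": metadata}
```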