diff --git a/docs/advanced_features/sglang_for_rl.md b/docs/advanced_features/sglang_for_rl.md index 2fd84c90de69..12eb41540339 100644 --- a/docs/advanced_features/sglang_for_rl.md +++ b/docs/advanced_features/sglang_for_rl.md @@ -106,6 +106,29 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates **Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)` +**Diffusion engine (SGLang-Diffusion):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior: + +- **All-or-nothing with rollback:** if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from the original model path. No partial updates are left behind. If rollback itself fails, the exception propagates so the caller knows the model is in an inconsistent state. +- **Offload-aware:** when layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state. +- **DTensor-aware:** parameters distributed via `torch.distributed.tensor` (tensor parallelism) are updated through `distribute_tensor` so that each shard is correctly placed on the right device mesh. + +**Request body:** + +| Field | Description | Defaults | Options | +| --- | --- | --- | --- | +| `model_path` | The model path with the new weights. 
| Required | Type: str | +| `flush_cache` | Flush TeaCache state after update. | `True` | Type: bool | +| `target_modules` | List of module names to update (e.g. `["transformer"]`). If omitted, all `nn.Module` components are updated. | `None` | Type: list[str] | + +**Response body:** + +| Field | Description | Defaults | Options | +| --- | --- | --- | --- | +| `success` | Whether the update succeeded. | - | Type: bool | +| `message` | Status / error message. | - | Type: str | + +> **Note:** The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently. + ### Update Weights from Tensor **When to use:** diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index ba96d8b8392a..30a60b35adcf 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -17,6 +17,7 @@ VertexGenerateReqInput, ) from sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params +from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api from sglang.multimodal_gen.runtime.entrypoints.utils import ( prepare_request, save_outputs, @@ -214,6 +215,7 @@ def create_app(server_args: ServerArgs): app.include_router(common_api.router) app.include_router(image_api.router) app.include_router(video_api.router) + app.include_router(weights_api.router) app.state.server_args = server_args return app diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py new file mode 100644 index 000000000000..bda72df12a8f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py @@ -0,0 +1,19 @@ +"""Request/response data structures for post-training APIs.""" + +from dataclasses import dataclass + + +@dataclass +class UpdateWeightFromDiskReqInput: + """Request to update model weights from disk for diffusion models.""" + + model_path: str + flush_cache: bool = True + target_modules: list[str] | None = None + + +@dataclass +class GetWeightsChecksumReqInput: + """Compute SHA-256 checksum of loaded module weights for verification.""" + + module_names: list[str] | None = None diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py new file mode 100644 index 000000000000..1b9312d8ea0f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -0,0 +1,62 @@ +"""Weight update API for the diffusion engine.""" + +from fastapi import APIRouter, Request +from fastapi.responses import ORJSONResponse + +from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + GetWeightsChecksumReqInput, + UpdateWeightFromDiskReqInput, +) +from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client + +router = APIRouter() + + +@router.post("/update_weights_from_disk") +async def update_weights_from_disk(request: Request): + """Update model weights from disk inplace without restarting the server.""" + body = await request.json() + model_path = body.get("model_path") + if not model_path: + return ORJSONResponse( + {"success": False, "message": "model_path is required"}, + status_code=400, + ) + + req = UpdateWeightFromDiskReqInput( + model_path=model_path, + 
flush_cache=body.get("flush_cache", True), + target_modules=body.get("target_modules"), + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse( + {"success": False, "message": str(e)}, + status_code=500, + ) + + result = response.output + success = result.get("success", False) + message = result.get("message", "Unknown status") + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else 400, + ) + + +@router.post("/get_weights_checksum") +async def get_weights_checksum(request: Request): + """Return SHA-256 checksum of each requested module's weights.""" + body = await request.json() + req = GetWeightsChecksumReqInput( + module_names=body.get("module_names"), + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse({"error": str(e)}, status_code=500) + + return ORJSONResponse(response.output, status_code=200) diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index e74c40c1756a..7507dc10833d 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -2,19 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py -"""Utilities for downloading and initializing model weights.""" +"""Utilities for downloading, loading, initializing and verifying model weights.""" import hashlib import json import os import tempfile -from collections.abc import Generator +from collections.abc import Generator, Iterable from pathlib import Path import filelock import huggingface_hub.constants import torch from safetensors.torch import safe_open +from torch.distributed.tensor import DTensor from tqdm.auto import tqdm try: @@ -336,3 +337,23 @@ def 
maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: # If there were no matches, return the untouched param name return name + + +def compute_weights_checksum( + named_params: Iterable[tuple[str, torch.Tensor]], +) -> str: + """Compute a SHA-256 checksum for a set of (name, tensor) pairs. + + Used to verify the correctness of weight refitting. After a refit, + compare the checksum of the in-GPU model weights against the checksum + of the on-disk tensors or the tensors in the training engine. + """ + hasher = hashlib.sha256() + for name, tensor in sorted(named_params, key=lambda x: x[0]): + hasher.update(name.encode()) + t = tensor.detach() + # DTensor doesn't support .numpy(); extract the local tensor. + if isinstance(t, DTensor): + t = t._local_tensor + hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data) + return hasher.hexdigest() diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py new file mode 100644 index 000000000000..f170809a738e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -0,0 +1,293 @@ +""" +In-place weight updates for diffusion pipeline modules. + +This module provides WeightsUpdater, which swaps model weights at runtime +without restarting the server. It is the diffusion-engine counterpart of the +LLM engine's ModelRunner.update_weights_from_disk. + +Detailed usage of higher level API can be found in + +/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py + +Key design decisions: + +- All-or-nothing with rollback: modules are updated sequentially. If + any module fails (shape mismatch, corrupted file, etc.), every module + that was already updated is rolled back by reloading its weights from + pipeline.model_path (the last successfully-loaded checkpoint). 
On + success, pipeline.model_path is updated to the new model_path so + that future rollbacks target the latest good checkpoint, not the + originally-launched model. + +- Rollback failures propagate: if rollback itself fails, the exception is + not caught so the caller knows the model is in an inconsistent state. + This matches the LLM engine behaviour. + +- Offload-aware: the diffusion LayerwiseOffloadManager replaces GPU + parameters with torch.empty((1,)) placeholders while real weights live + in consolidated pinned CPU buffers. A naive param.data.copy_() would + fail with a shape mismatch. Instead, the updater dynamically detects + active offload managers and writes new weights directly into their CPU + buffers via update_cpu_weights(), bypassing the placeholders entirely. + For any layer that happens to be prefetched on GPU at update time, the + live GPU tensor is also updated so the change takes effect immediately. + This requires no extra GPU memory and does not disturb the offload state. + +- DTensor-aware: parameters that have been distributed via + torch.distributed.tensor are updated through distribute_tensor + so that each shard is correctly placed on the right device mesh. 
+""" + +from __future__ import annotations + +import gc +from pathlib import Path + +import torch +from torch.distributed.tensor import DTensor, distribute_tensor + +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]: + """Return updatable nn.Module components for the given pipeline. + + Works with both the native ComposedPipelineBase backend and the + DiffusersPipeline wrapper. + """ + if isinstance(pipeline, DiffusersPipeline): + diffusers_pipe = pipeline.get_module("diffusers_pipeline") + if diffusers_pipe is not None and diffusers_pipe.components is not None: + raw = diffusers_pipe.components + else: + raw = {} + else: + raw = pipeline.modules + return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)} + + +def _get_weights_iter(weights_dir: str): + """Return a (name, tensor) iterator over safetensors in weights_dir.""" + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {weights_dir}") + return safetensors_weights_iterator(safetensors_files) + + +def _validate_weight_files( + local_model_path: str, + modules_to_update: list[tuple[str, torch.nn.Module]], +) -> tuple[dict[str, str], list[str]]: + """Check that every module has a weights directory with safetensors files. 
+ + Returns: + (weights_map, missing) where weights_map maps module name to its + weights directory and missing lists modules without weight files. + """ + weights_map: dict[str, str] = {} + missing: list[str] = [] + for module_name, _ in modules_to_update: + weights_dir = Path(local_model_path) / module_name + if weights_dir.exists() and _list_safetensors_files(str(weights_dir)): + weights_map[module_name] = str(weights_dir) + else: + missing.append(module_name) + return weights_map, missing + + +def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: + """Load weights into a module, handling offload-managed parameters. + + For offloaded modules, updates CPU buffers directly via + update_cpu_weights(); non-offloaded parameters use in-place copy. + """ + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + + if offload_managers: + weight_dict = dict(weights_iter) + offloaded_names: set[str] = set() + for manager in offload_managers: + offloaded_names.update(manager.update_cpu_weights(weight_dict)) + remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) + load_weights_into_model(remaining, dict(module.named_parameters())) + else: + load_weights_into_model(weights_iter, dict(module.named_parameters())) + + +def load_weights_into_model(weights_iter, model_params: dict) -> None: + """Copy weights from weights_iter into model_params in-place.""" + for name, loaded_weight in weights_iter: + if name not in model_params: + continue + param = model_params[name] + if param.shape != loaded_weight.shape: + raise ValueError( + f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" + ) + if isinstance(param, DTensor): + distributed_weight = distribute_tensor( + loaded_weight.to(param.dtype), + param.device_mesh, + param.placements, + ) + 
param._local_tensor.copy_(distributed_weight._local_tensor) + else: + param.data.copy_(loaded_weight.to(param.dtype)) + + +class WeightsUpdater: + """In-place weight updates for diffusion pipeline modules. + + Args: + pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance + whose modules will be updated. The pipeline's model_path + attribute is used for rollback on failure. + """ + + def __init__(self, pipeline): + self.pipeline = pipeline + + def update_weights_from_disk( + self, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update model weights from disk without restarting the server.""" + logger.info(f"Updating weights from disk: {model_path}") + + try: + modules_to_update = self._collect_modules(target_modules) + except ValueError as e: + logger.error(str(e)) + return False, str(e) + + if not modules_to_update: + error_msg = ( + f"No matching modules found for update. " + f"Requested: {target_modules}. " + f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}" + ) + logger.error(error_msg) + return False, error_msg + + try: + local_model_path = maybe_download_model(model_path) + except Exception as e: + return False, f"Failed to download model: {e}" + + weights_map, missing = _validate_weight_files( + local_model_path, modules_to_update + ) + if missing: + error_msg = ( + f"Cannot update weights: missing weight files for modules: {missing}. " + f"No partial updates allowed." 
+ ) + logger.error(error_msg) + return False, error_msg + + logger.info( + f"Updating {len(weights_map)} modules: " + + ", ".join(f"{n} <- {p}" for n, p in weights_map.items()) + ) + + success, message = self._apply_weights(modules_to_update, weights_map) + + gc.collect() + torch.cuda.empty_cache() + + if success and flush_cache: + for _, module in modules_to_update: + if isinstance(module, TeaCacheMixin): + module.reset_teacache_state() + + logger.info(message) + return success, message + + def _collect_modules( + self, target_modules: list[str] | None + ) -> list[tuple[str, torch.nn.Module]]: + """Resolve target_modules to (name, module) pairs. + + Raises: + ValueError: If target_modules contains names not found in the pipeline. + """ + components = get_updatable_modules(self.pipeline) + + if target_modules is None: + names = list(components.keys()) + else: + unknown = [n for n in target_modules if n not in components] + if unknown: + raise ValueError( + f"Module(s) requested for update not found in pipeline: {unknown}. " + f"Available Module(s): {list(components.keys())}" + ) + names = target_modules + + return [(name, components[name]) for name in names] + + def _apply_weights( + self, + modules_to_update: list[tuple[str, torch.nn.Module]], + weights_map: dict[str, str], + ) -> tuple[bool, str]: + """Load weights into each module; rollback on first failure.""" + updated_modules: list[str] = [] + + for module_name, module in modules_to_update: + try: + weights_iter = _get_weights_iter(weights_map[module_name]) + _load_weights_into_module(module, weights_iter) + updated_modules.append(module_name) + except Exception as e: + rollback_list = updated_modules + [module_name] + logger.error( + f"Weight update failed for module '{module_name}': {e}. 
" + f"Rolling back {len(rollback_list)} module(s) " + f"(including partially-loaded '{module_name}'): " + f"{rollback_list}.", + exc_info=True, + ) + self._rollback(rollback_list) + return False, ( + f"Failed to update module '{module_name}': {e}. " + f"All modules rolled back to original weights." + ) + + names = ", ".join(updated_modules) + return True, f"Updated {len(updated_modules)} modules ({names})." + + def _rollback(self, updated_modules: list[str]) -> None: + """Restore updated_modules to original weights. + + If rollback itself fails the exception propagates so the caller + knows the model is in an inconsistent state. + """ + if not updated_modules: + return + original_path = maybe_download_model(self.pipeline.model_path) + for name in updated_modules: + module = self.pipeline.get_module(name) + if module is None: + continue + weights_dir = Path(original_path) / name + if not weights_dir.exists(): + continue + weights_iter = _get_weights_iter(str(weights_dir)) + _load_weights_into_module(module, weights_iter) diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index ef04a63c28df..9cbf0c211597 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -29,6 +29,11 @@ get_ulysses_parallel_world_size, ) from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs +from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + WeightsUpdater, + get_updatable_modules, +) from sglang.multimodal_gen.runtime.pipelines_core import ( ComposedPipelineBase, LoRAPipeline, @@ -39,7 +44,10 @@ from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch 
-from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.layerwise_offload import ( + OffloadableDiTMixin, + iter_materialized_weights, +) from sglang.multimodal_gen.runtime.utils.logging_utils import ( configure_logger, globally_suppress_loggers, @@ -371,6 +379,48 @@ def list_loras(self) -> OutputBatch: status = self.pipeline.get_lora_status() return OutputBatch(output=status) + def update_weights_from_disk( + self, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update model weights from disk inplace without restarting the server.""" + if not self.pipeline: + return False, "Pipeline is not initialized" + + updater = WeightsUpdater(self.pipeline) + success, message = updater.update_weights_from_disk( + model_path, + flush_cache=flush_cache, + target_modules=target_modules, + ) + if success: + self.server_args.model_path = model_path + self.pipeline.model_path = model_path + return success, message + + def get_weights_checksum( + self, module_names: list[str] | None = None + ) -> dict[str, str]: + """Compute SHA-256 checksum of each module's weights.""" + if not self.pipeline: + return {"error": "Pipeline is not initialized"} + + all_modules = get_updatable_modules(self.pipeline) + names = module_names if module_names is not None else list(all_modules.keys()) + + checksums: dict[str, str] = {} + for name in names: + module = all_modules.get(name) + if module is None: + checksums[name] = "not_found" + continue + checksums[name] = compute_weights_checksum( + iter_materialized_weights(module) + ) + return checksums + OOM_MSG = f""" OOM detected. 
Possible solutions: diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index acaa96da2ef3..e1fef1df2d83 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -15,6 +15,10 @@ _parse_size, save_image_to_path, ) +from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + GetWeightsChecksumReqInput, + UpdateWeightFromDiskReqInput, +) from sglang.multimodal_gen.runtime.entrypoints.utils import ( ListLorasReq, MergeLoraWeightsReq, @@ -91,6 +95,8 @@ def __init__( List[Req]: self._handle_generation, ListLorasReq: self._handle_list_loras, ShutdownReq: self._handle_shutdown, + UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk, + GetWeightsChecksumReqInput: self._handle_get_weights_checksum, } # FIFO, new reqs are appended @@ -131,6 +137,25 @@ def _handle_shutdown(self, _reqs: List[Any]) -> OutputBatch: self._running = False return OutputBatch() + def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: + """Handle update_weights_from_disk request for RL workflows.""" + req = reqs[0] + success, message = self.worker.update_weights_from_disk( + model_path=req.model_path, + flush_cache=req.flush_cache, + target_modules=req.target_modules, + ) + return OutputBatch( + output={"success": success, "message": message}, + error=None if success else message, + ) + + def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch: + """Handle get_weights_checksum request.""" + req = reqs[0] + checksums = self.worker.get_weights_checksum(module_names=req.module_names) + return OutputBatch(output=checksums) + def _handle_generation(self, reqs: List[Req]): warmup_reqs = [req for req in reqs if req.is_warmup] if warmup_reqs: diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py 
b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 8af6ad1a69ed..0bdbd841db2f 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -276,6 +276,86 @@ def sync_all_layers_to_cpu(self) -> None: for layer_idx in list(self._gpu_layers): self.sync_layer_to_cpu(layer_idx) + @torch.compiler.disable + def update_cpu_weights( + self, weight_dict: Dict[str, torch.Tensor] + ) -> Set[str] | None: + """Update consolidated CPU buffers with new weights. + + When layerwise offload (--dit-layerwise-offload) is enabled, the + offload manager replaces GPU parameters with small torch.empty((1,)) + placeholders while real weights live in consolidated pinned CPU + buffers. + + The refit process writes new weights directly into the CPU buffers, + bypassing the placeholders. For any layer that happens to be resident + on the GPU at update time, the live GPU tensor is also updated. + + Args: + weight_dict: Mapping of parameter name to new weight tensor. + + Returns: + Set of parameter names that were successfully updated. + + Raises: + ValueError: If a weight's shape does not match the recorded + metadata (i.e., the real shape, not the placeholder shape). 
+ """ + if not self.enabled: + return None + + updated_names: Set[str] = set() + for name, loaded_weight in weight_dict.items(): + layer_idx = self._match_layer_idx(name) + if layer_idx is None: + continue + meta_layer = self._weight_metadata.get(layer_idx) + if meta_layer is None or name not in meta_layer: + continue + + meta = meta_layer[name] + if tuple(meta["shape"]) != tuple(loaded_weight.shape): + raise ValueError( + f"Shape mismatch for {name}: " + f"expected={tuple(meta['shape'])}, " + f"loaded={tuple(loaded_weight.shape)}" + ) + + dtype = meta["dtype"] + offset = meta["offset"] + numel = meta["numel"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + cpu_buffer[offset : offset + numel].copy_( + loaded_weight.to(dtype=dtype).flatten() + ) + + # If this layer is currently on GPU, update the live parameter. + if layer_idx in self._gpu_layers: + target = self.get_target_with_name(name) + target.data.copy_(loaded_weight.to(dtype=target.dtype)) + + updated_names.add(name) + + return updated_names + + def iter_cpu_weights(self): + """Yield (name, tensor) pairs from consolidated CPU buffers. + + This reconstructs the original weight tensors (with correct shapes) + from the flat CPU buffers using stored metadata. Unlike + model.named_parameters(), which returns (1,) placeholders + when offload is enabled, this method returns the real weights and + can be used for checksum computation. 
+ """ + for layer_idx in sorted(self._weight_metadata): + for name, meta in self._weight_metadata[layer_idx].items(): + dtype = meta["dtype"] + offset = meta["offset"] + numel = meta["numel"] + shape = meta["shape"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + yield name, cpu_buffer[offset : offset + numel].reshape(shape) + def register_forward_hooks(self) -> None: if not self.enabled: return @@ -383,3 +463,32 @@ def enable_offload(self) -> None: manager.sync_all_layers_to_cpu() manager.release_all() manager.register_forward_hooks() + + +def iter_materialized_weights(module: torch.nn.Module): + """Yield (name, tensor) pairs with materialized weights, even under offload. + + When layerwise offload is active, module.named_parameters() returns + (1,) placeholders for offloaded layers. This function reads the + actual data from the offload manager's CPU buffers and chains it with + the non-offloaded parameters. + """ + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + + if not offload_managers: + yield from module.named_parameters() + return + + # Collect offloaded names and their real tensors from CPU buffers. + offloaded_names: set[str] = set() + for manager in offload_managers: + for name, tensor in manager.iter_cpu_weights(): + offloaded_names.add(name) + yield name, tensor + + # Yield non-offloaded parameters (e.g. final norms, embeddings). 
+ for name, param in module.named_parameters(): + if name not in offloaded_names: + yield name, param diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index 6610a4cfbe96..fc52247749dc 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -10,6 +10,7 @@ import argparse import os +import random import subprocess import sys from pathlib import Path @@ -20,6 +21,13 @@ logger = init_logger(__name__) +_UPDATE_WEIGHTS_FROM_DISK_TEST_FILE = "test_update_weights_from_disk.py" +_UPDATE_WEIGHTS_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" +_UPDATE_WEIGHTS_MODEL_PAIR_IDS = ( + "FLUX.2-klein-base-4B", + "Qwen-Image", +) + SUITES = { "1-gpu": [ "test_server_a.py", @@ -29,6 +37,7 @@ "../cli/test_generate_t2i_perf.py", # unit tests (no server needed) "../test_sampling_params_validate.py", + "test_update_weights_from_disk.py", # add new 1-gpu test files here ], "2-gpu": [ @@ -225,6 +234,27 @@ def run_pytest(files, filter_expr=None): return returncode +def _is_in_ci() -> bool: + return os.environ.get("SGLANG_IS_IN_CI", "").lower() in ("1", "true", "yes", "on") + + +def _maybe_pin_update_weights_model_pair(suite_files_rel: list[str]) -> None: + if not _is_in_ci(): + return + if _UPDATE_WEIGHTS_FROM_DISK_TEST_FILE not in suite_files_rel: + return + if os.environ.get(_UPDATE_WEIGHTS_MODEL_PAIR_ENV): + print( + f"Using preset {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}=" + f"{os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV]}" + ) + return + + selected_pair = random.choice(_UPDATE_WEIGHTS_MODEL_PAIR_IDS) + os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV] = selected_pair + print(f"Selected {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}={selected_pair} for this CI run") + + def main(): args = parse_args() @@ -239,6 +269,7 @@ def main(): # 2. 
get files from suite suite_files_rel = SUITES[args.suite] + _maybe_pin_update_weights_model_pair(suite_files_rel) suite_files_abs = [] for f_rel in suite_files_rel: diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py new file mode 100644 index 000000000000..68700e93d016 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -0,0 +1,667 @@ +"""Tests for diffusion `update_weights_from_disk`. + +This module verifies the ability to update model weights in place without restarting +the server, which is critical for RL workflows and iterative fine-tuning scenarios. + +Author: + +Menyang Liu, https://github.com/dreamyang-liu +Chenyang Zhao, https://github.com/zhaochenyang20 + +We use two model pairs for testing (base model / instruct model pairs): + +- FLUX.2-klein-base-4B / FLUX.2-klein-4B +- Qwen/Qwen-Image / Qwen/Qwen-Image-2512 + +These model pairs share the same architecture but differ in transformer +weights. The basic testing logic is to refit the instruct model into the +base model and verify that the checksums of the transformer weights are the same, +which simulates the real-world RL scenario. However, since these two model +pairs only differ in transformer weights, and we want to verify updating a +specific module with the update_weights_from_disk API, we need to create a perturbed +instruct model that adds noise to the vae weights. In this sense, the instruct +model differs from the base model in vae and transformer weights, while the text +encoder is still the same. + +To strictly verify the correctness of the refit API, we compare the checksum in +SHA-256 on the disk and the server. + +NOTE and TODO: In the refit-a-specific-module test, we randomly select one module +from the transformer and vae to refit the server and keep other modules the same. +As described above, the vae's weights are perturbed.
If we select the vae to be the +target module, ideally speaking, we should assert that the refitted vae's checksum +is the same as directly computed from the perturbed vae weights in the disk. However, +since there is complex weight-name remapping and QKV merge during model loading, +it is not easy to compare the server-disk checksum for vae and text encoder directly. +Therefore, if the target module is vae, we only verify that the refitted vae's checksum +is different from the base model's vae's checksum. + +It would be a good issue for the community to solve by adding comparison of the server-disk +checksum for vae and text encoder in this test. + +============================================================================= + +Test organization: + +7 test cases in 2 classes; +two model pairs are tested locally, one in CI. + +============================================================================= + +Class 1: TestUpdateWeightsFromDisk (6 tests) — API contract, checksum & rollback +Class 2: TestUpdateWeightsFromDiskWithOffload (1 test) — Offload-aware update + checksum + +----------------------------------------------------------------------------- + +Class 1: TestUpdateWeightsFromDisk + +Validate the update_weights_from_disk API contract, request/response shape, +error handling, checksum verification, and corrupted-weight rollback. + +All tests share one class-scoped server (same process, same in-memory weights). +Tests that require "base model then update" should be explicitly reset to +base model first so behavior is order-independent and updates are real +(base -> perturbed), not no-ops (perturbed -> perturbed). + + • test_update_weights_from_disk_default + + base model -> perturbed model with flush_cache=True. + Verifies after-update transformer checksum == perturbed model's + transformer disk checksum + + + • test_update_weights_specific_modules + + base -> perturbed with flush_cache=False.
Randomly selects one module + from _DIFFERING_MODULES (transformer and vae) as target_modules, updates + only that module. Verifies that: + (1) targeted module's in-memory checksum changed; + (2) non-targeted modules' in-memory checksums are unchanged. + + • test_update_weights_nonexistent_model + + model_path set to a non-existent path; must fail (400, success=False). + + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. + + • test_update_weights_missing_model_path + + Request body empty (no model_path); must fail (400, success=False). + + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. + + • test_update_weights_nonexistent_module + + target_modules=["nonexistent_module"]; must fail (400, success=False). + + Verify server is healthy after failed update and server's checksums + equal base model's transformer disk checksum. + + • test_corrupted_weights_rollback + + All-or-nothing rollback: We first refit the server from base model -> + perturbed model. We manually truncate the vae weights of the base + model to get a corrupted model. We then call the refit to update + the server from the perturbed model -> corrupted model. Verify that: + + 1. The update fails due to truncated vae, server should roll back to the + perturbed model, i.e., server's transformer weights == perturbed model's + transformer weights != base model's transformer weights. + + 2. After the rollback, server's vae weights == perturbed model's vae + weights != base model's vae weights. + + 3. After the rollback, server's text encoder weights == base model's + text encoder weights == perturbed model's text encoder weights. 
+ +----------------------------------------------------------------------------- + +Class 2: TestUpdateWeightsFromDiskWithOffload + + +Ensure weight updates and checksum verification work when layerwise offload is enabled +(--dit-layerwise-offload). With offload, parameters live in CPU buffers and only +small torch.empty((1,)) placeholders are left on GPU; the updater must write into CPU buffers +and update prefetched GPU tensors without shape mismatch. + + • test_update_weights_with_offload_enabled + + Server with --dit-layerwise-offload (base). Load perturbed checkpoint; + must succeed (200, success=True), no "Shape mismatch". The server's transformer checksum + matches the perturbed model's transformer disk checksum. +""" + +from __future__ import annotations + +import functools +import os +import random +import shutil +import tempfile +import threading +from collections.abc import Callable + +import pytest +import requests +from safetensors.torch import load_file, save_file + +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + compute_weights_checksum, + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_utils import ( + ServerManager, +) +from sglang.multimodal_gen.test.test_utils import get_dynamic_server_port, is_in_ci + +logger = init_logger(__name__) + + +_TRANSFORMER_MODULE = "transformer" +_VAE_MODULE = "vae" +_TEXT_ENCODER_MODULE_PREFIX = "text_encoder" + + +# Modules whose weights differ between the base model and the +# perturbed checkpoint +_DIFFERING_MODULES: list[str] = [_TRANSFORMER_MODULE, _VAE_MODULE] + +_ALL_MODEL_PAIRS: list[tuple[str, str]] = [ + ( + "black-forest-labs/FLUX.2-klein-base-4B", + "black-forest-labs/FLUX.2-klein-4B", + ), + ( +
"Qwen/Qwen-Image", + "Qwen/Qwen-Image-2512", + ), +] + + +_CI_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" + + +def _resolve_active_model_pairs() -> list[tuple[str, str]]: + if not is_in_ci(): + return _ALL_MODEL_PAIRS + + pair_by_id = {pair[0].split("/")[-1]: pair for pair in _ALL_MODEL_PAIRS} + selected_pair_id = os.environ.get(_CI_MODEL_PAIR_ENV) + if selected_pair_id is None: + return [random.choice(_ALL_MODEL_PAIRS)] + + selected_pair = pair_by_id.get(selected_pair_id) + if selected_pair is None: + valid_ids = ", ".join(sorted(pair_by_id)) + raise ValueError( + f"Invalid {_CI_MODEL_PAIR_ENV}={selected_pair_id!r}. " + f"Expected one of: {valid_ids}." + ) + return [selected_pair] + + +_ACTIVE_MODEL_PAIRS = _resolve_active_model_pairs() +_PAIR_IDS = [p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS] + + +@functools.lru_cache(maxsize=None) +def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: + """Compute SHA-256 checksum from safetensors files on disk. + + Uses the same compute_weights_checksum function as the server, + so the checksums are directly comparable. + + Results are cached (keyed on model_path and module_name) because the + same disk checksum is requested multiple times across tests. + """ + local_path = maybe_download_model(model_path) + weights_dir = os.path.join(local_path, module_name) + assert os.path.exists( + weights_dir + ), f"No weights dir for {module_name} in {local_path}" + + safetensors_files = _list_safetensors_files(weights_dir) + assert safetensors_files, f"No safetensors files in {weights_dir}" + + return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) + + +def _clone_model_with_modified_module( + src_model: str, + dst_model: str, + target_module: str, + transform_safetensor: Callable[[str, str], None], +) -> None: + # Symlink root-level files (model_index.json, etc.). 
+ for fname in os.listdir(src_model): + src_path = os.path.join(src_model, fname) + dst_path = os.path.join(dst_model, fname) + if os.path.isfile(src_path) and not os.path.exists(dst_path): + os.symlink(src_path, dst_path) + + for module_dir in sorted(os.listdir(src_model)): + src_dir = os.path.join(src_model, module_dir) + dst_dir = os.path.join(dst_model, module_dir) + if not os.path.isdir(src_dir): + continue + + if module_dir != target_module: + if not os.path.exists(dst_dir): + os.symlink(src_dir, dst_dir) + continue + + os.makedirs(dst_dir, exist_ok=True) + transformed = False + for fname in sorted(os.listdir(src_dir)): + src_file = os.path.join(src_dir, fname) + dst_file = os.path.join(dst_dir, fname) + if not os.path.isfile(src_file): + continue + + if not fname.endswith(".safetensors") or transformed: + if not os.path.exists(dst_file): + os.symlink(src_file, dst_file) + continue + + transform_safetensor(src_file, dst_file) + transformed = True + + +def _truncate_safetensor(src_file: str, dst_file: str) -> None: + shutil.copy2(src_file, dst_file) + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 2) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 2, + ) + + +def _perturb_safetensor(src_file: str, dst_file: str) -> None: + + tensors = load_file(src_file) + perturbed = { + k: (t + 0.01 if t.is_floating_point() else t) for k, t in tensors.items() + } + save_file(perturbed, dst_file) + logger.info("Created perturbed safetensors: %s", dst_file) + + +class _UpdateWeightsApiMixin: + def _update_weights( + self, + base_url: str, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + timeout: int = 300, + ) -> tuple[dict, int]: + payload = {"model_path": model_path, "flush_cache": flush_cache} + if target_modules is not None: + payload["target_modules"] = target_modules + response = requests.post( + f"{base_url}/update_weights_from_disk", + 
json=payload, + timeout=timeout, + ) + return response.json(), response.status_code + + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + payload = {} + if module_names is not None: + payload["module_names"] = module_names + response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, + ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() + + def _assert_server_matches_model( + self, + base_url: str, + expected_model: str, + ) -> None: + server_checksums = self._get_weights_checksum( + base_url, module_names=[_TRANSFORMER_MODULE] + ) + expected_cs = _compute_checksum_from_disk(expected_model, _TRANSFORMER_MODULE) + server_cs = server_checksums.get(_TRANSFORMER_MODULE) + assert server_cs == expected_cs, ( + f"Checksum mismatch on '{_TRANSFORMER_MODULE}'\n" + f" expected({expected_model}): {expected_cs}\n" + f" server: {server_cs}" + ) + + +class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): + + @pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=_PAIR_IDS, + ) + def diffusion_server_no_offload(self, request): + default_model, source_model = request.param + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=default_model, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + # Ensure models are local before spawning threads that need the paths. + local_default = maybe_download_model(default_model) + local_source = maybe_download_model(source_model) + + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + corrupted_vae_model_dir = tempfile.mkdtemp(prefix="sglang_corrupted_") + + # Run all disk I/O in background while the server boots. 
+ bg_threads = [ + threading.Thread( + target=_compute_checksum_from_disk, args=(default_model, module) + ) + for module in _DIFFERING_MODULES + ] + [ + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _VAE_MODULE, + _perturb_safetensor, + ), + ), + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_default, + corrupted_vae_model_dir, + _VAE_MODULE, + _truncate_safetensor, + ), + ), + ] + for t in bg_threads: + t.start() + + ctx = manager.start() + for t in bg_threads: + t.join() + + # Sanity: all _DIFFERING_MODULES should differ between base and perturbed. + for module in _DIFFERING_MODULES: + assert _compute_checksum_from_disk( + default_model, module + ) != _compute_checksum_from_disk(perturbed_vae_model_dir, module), ( + f"Assumption violated: {module} should differ between " + f"{default_model} and {perturbed_vae_model_dir}" + ) + + try: + yield ctx, default_model, perturbed_vae_model_dir, corrupted_vae_model_dir + finally: + ctx.cleanup() + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) + shutil.rmtree(corrupted_vae_model_dir, ignore_errors=True) + + def test_update_weights_from_disk_default(self, diffusion_server_no_offload): + """Default update (target_modules=None, flush_cache=True): all changed modules updated.""" + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model, flush_cache=True) + + result, status_code = self._update_weights( + base_url, perturbed_model_dir, flush_cache=True + ) + assert status_code == 200 + assert result.get("success", False), f"Update failed: {result.get('message')}" + + self._assert_server_matches_model(base_url, perturbed_model_dir) + + def test_update_weights_specific_modules(self, diffusion_server_no_offload): + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = 
f"http://localhost:{ctx.port}" + + # Reset server to default_model. + self._update_weights(base_url, default_model) + before_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES + ) + + target_modules = [random.choice(_DIFFERING_MODULES)] + result, status_code = self._update_weights( + base_url, + perturbed_model_dir, + target_modules=target_modules, + flush_cache=False, + ) + assert status_code == 200, f"Update failed: {result}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + after_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES + ) + + # Targeted module should have changed. + for name in target_modules: + assert after_checksums.get(name) != before_checksums.get(name), ( + f"Targeted module '{name}' checksum should change after update\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) + + # Non-targeted modules should be unchanged. + for name, cs in after_checksums.items(): + if name in target_modules or cs == "not_found": + continue + assert cs == before_checksums.get(name), ( + f"Non-targeted module '{name}' should be unchanged\n" + f" before: {before_checksums.get(name)}\n" + f" after: {cs}" + ) + + def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): + """Nonexistent model path must fail (400). 
Server healthy, checksums == base disk.""" + ctx, default_model, _, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + result, status_code = self._update_weights( + base_url, + "/nonexistent/path/to/model", + timeout=60, + ) + logger.info(f"Update result for nonexistent model: {result}") + + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent model" + self._assert_server_matches_model(base_url, default_model) + + def test_update_weights_missing_model_path(self, diffusion_server_no_offload): + """Request without model_path must fail (400). Server healthy, checksums == base disk.""" + ctx, default_model, _, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + response = requests.post( + f"{base_url}/update_weights_from_disk", + json={}, + timeout=30, + ) + + assert response.status_code == 400, f"Expected 400, got {response.status_code}" + result = response.json() + assert not result.get("success", True), "Should fail when model_path is missing" + self._assert_server_matches_model(base_url, default_model) + + def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): + """Nonexistent module must fail (400). 
Server healthy, checksums == base disk.""" + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + result, status_code = self._update_weights( + base_url, + perturbed_model_dir, + target_modules=["nonexistent_module"], + timeout=60, + ) + logger.info(f"Update nonexistent module result: {result}") + + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent module" + assert "not found in pipeline" in result.get("message", "") + self._assert_server_matches_model(base_url, default_model) + + def test_corrupted_weights_rollback(self, diffusion_server_no_offload): + ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( + diffusion_server_no_offload + ) + base_url = f"http://localhost:{ctx.port}" + + # base → perturbed + self._update_weights(base_url, default_model) + base_checksums = self._get_weights_checksum(base_url) + + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200 and result.get("success") + perturbed_checksums = self._get_weights_checksum(base_url) + + text_encoder_modules = sorted( + name + for name in perturbed_checksums + if _TEXT_ENCODER_MODULE_PREFIX in name + and perturbed_checksums.get(name) != "not_found" + and base_checksums.get(name) != "not_found" + ) + assert ( + text_encoder_modules + ), "Expected at least one text encoder module checksum" + + # perturbed → corrupted (should fail and rollback) + rollback_targets = [_TRANSFORMER_MODULE, _VAE_MODULE] + result, status_code = self._update_weights( + base_url, + corrupted_vae_model_dir, + target_modules=rollback_targets, + ) + assert ( + status_code == 400 + ), f"Expected 400 on corrupted weights, got {status_code}" + assert not result.get("success", True) + message = result.get("message", "") + assert "rolled back" in message.lower() + # The 
updater reports the first failing module in the error message. + # With ordered target_modules=[transformer, vae], this makes the + # failure point explicit: transformer is processed first, then vae fails. + assert ( + "Failed to update module 'vae'" in message + ), f"Expected vae to be the explicit failure point, got: {message}" + rolled_back_checksums = self._get_weights_checksum(base_url) + + # 1) transformer: server == perturbed != base + transformer_base = base_checksums.get(_TRANSFORMER_MODULE) + transformer_perturbed = perturbed_checksums.get(_TRANSFORMER_MODULE) + transformer_rolled_back = rolled_back_checksums.get(_TRANSFORMER_MODULE) + assert transformer_rolled_back == transformer_perturbed + assert transformer_rolled_back != transformer_base + + # 2) vae: server == perturbed != base + vae_base = base_checksums.get(_VAE_MODULE) + vae_perturbed = perturbed_checksums.get(_VAE_MODULE) + vae_rolled_back = rolled_back_checksums.get(_VAE_MODULE) + assert vae_rolled_back == vae_perturbed + assert vae_rolled_back != vae_base + + # 3) text encoder(s): server == base == perturbed + for name in text_encoder_modules: + assert rolled_back_checksums.get(name) == perturbed_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to perturbed" + assert rolled_back_checksums.get(name) == base_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to base" + + +class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): + """Test update_weights_from_disk with layerwise offload enabled.""" + + @pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS) + def diffusion_server_with_offload(self, request): + default_model, source_model = request.param + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + local_source = maybe_download_model(source_model) + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + + clone_thread = 
threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _VAE_MODULE, + _perturb_safetensor, + ), + ) + clone_thread.start() + + manager = ServerManager( + model=default_model, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1 --dit-layerwise-offload true", + ) + + ctx = manager.start() + clone_thread.join() + + try: + yield ctx, default_model, perturbed_vae_model_dir + finally: + ctx.cleanup() + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) + + def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): + ctx, _, perturbed_model_dir = diffusion_server_with_offload + base_url = f"http://localhost:{ctx.port}" + + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200, f"Expected 200, got {status_code}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + message = result.get("message", "") + assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" + + self._assert_server_matches_model(base_url, perturbed_model_dir) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index 9379dcc9d5ce..f66b8801cfa0 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -28,6 +28,8 @@ from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord +DEFAULT_SMALL_MODEL = "Tongyi-MAI/Z-Image-Turbo" + @dataclass class ToleranceConfig: @@ -339,8 +341,6 @@ def from_req_perf_record( fps=4, ) -DEFAULT_SMALL_MODEL = "Tongyi-MAI/Z-Image-Turbo" - # All test cases with clean default values # To test different models, simply add more DiffusionCase entries 
ONE_GPU_CASES_A: list[DiffusionTestCase] = [