From c93c10380ddab526597d0e7232ed28e4974c65ae Mon Sep 17 00:00:00 2001 From: dreamyang-liu Date: Thu, 5 Feb 2026 09:42:13 +0000 Subject: [PATCH 01/30] [Feature] Implement update_weights_from_disk for SGLang-D (Diffusion Engine) This PR implements the update_weights_from_disk interface for the SGLang-D diffusion engine, enabling dynamic weight updates for RL workflows and iterative model fine-tuning without restarting the server. - Add `UpdateWeightsFromDiskReq` dataclass in `io_struct.py` for request handling - Implement `GPUWorker.update_weights_from_disk()` method with: - Support for all nn.Module components by default (transformer, vae, text_encoder, etc.) - Layerwise offload handling: disable before update, re-enable after with synced weights - DTensor support for tensor parallel parameters - Atomic updates with rollback: if any module fails, rollback all updated modules - TeaCache state reset after weight updates - Add scheduler handler for update_weights requests via ZMQ - Add `/update_weights_from_disk` HTTP endpoint - Add `/get_model_info` endpoint to query current model path - Add test suite in `test_update_weights_from_disk.py` - Basic API tests (same model, flush_cache, specific modules) - Layerwise offload integration tests - End-to-end tests verifying generation after weight update The implementation mirrors the LLM engine's update_weights_from_disk functionality, using safetensors_weights_iterator for weight loading and supporting HuggingFace model paths. 
Closes sgl-project#18078 --- .../runtime/entrypoints/http_server.py | 85 +++++ .../runtime/managers/gpu_worker.py | 313 ++++++++++++++++ .../runtime/managers/io_struct.py | 16 + .../runtime/managers/scheduler.py | 15 + .../server/test_update_weights_from_disk.py | 340 ++++++++++++++++++ 5 files changed, 769 insertions(+) create mode 100644 python/sglang/multimodal_gen/runtime/managers/io_struct.py create mode 100644 python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index 20d2c3df9bff..122e80b943e7 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -92,6 +92,90 @@ async def health_generate(): return {"status": "ok"} +@health_router.get("/get_model_info") +async def get_model_info(request: Request): + """Get information about the current model.""" + server_args: ServerArgs = request.app.state.server_args + return { + "model_path": server_args.model_path, + } + + +# Weight update router for RL workflows +weight_update_router = APIRouter() + + +@weight_update_router.post("/update_weights_from_disk") +async def update_weights_from_disk(request: Request): + """ + Update model weights from disk without restarting the server. + + This endpoint enables dynamic weight updates for RL workflows and iterative + model fine-tuning scenarios. + + Request body: + - model_path (str): Path to the new model weights (HuggingFace model path or local directory) + - load_format (str, optional): Format of the weights to load (default: "auto") + - flush_cache (bool, optional): Whether to flush cache after update (default: True) + - target_modules (list[str], optional): List of module names to update. + Default: updates ALL nn.Module components (transformer, vae, text_encoder, etc.) 
+ Examples: ["transformer"] to update only transformer + + Returns: + - success (bool): Whether the update was successful + - message (str): Status message + """ + from sglang.multimodal_gen.runtime.managers.io_struct import ( + UpdateWeightsFromDiskReq, + ) + + try: + body = await request.json() + model_path = body.get("model_path") + if not model_path: + return ORJSONResponse( + {"success": False, "message": "model_path is required"}, + status_code=400, + ) + + # Create the request object with diffusion-specific fields + req = UpdateWeightsFromDiskReq( + model_path=model_path, + load_format=body.get("load_format", "auto"), + flush_cache=body.get("flush_cache", True), + target_modules=body.get("target_modules"), + ) + + response = await async_scheduler_client.forward(req) + + # Handle response + if hasattr(response, "output") and response.output: + result = response.output + return ORJSONResponse( + { + "success": result.get("success", False), + "message": result.get("message", "Unknown status"), + }, + status_code=200 if result.get("success") else 400, + ) + elif hasattr(response, "error") and response.error: + return ORJSONResponse( + {"success": False, "message": response.error}, + status_code=400, + ) + else: + return ORJSONResponse( + {"success": False, "message": "Unknown response format"}, + status_code=500, + ) + + except Exception as e: + return ORJSONResponse( + {"success": False, "message": f"Error: {str(e)}"}, + status_code=500, + ) + + def make_serializable(obj): """Recursively converts Tensors to None for JSON serialization.""" if isinstance(obj, torch.Tensor): @@ -211,6 +295,7 @@ def create_app(server_args: ServerArgs): app = FastAPI(lifespan=lifespan) app.include_router(health_router) + app.include_router(weight_update_router) app.include_router(vertex_router) from sglang.multimodal_gen.runtime.entrypoints.openai import common_api diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py 
b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 4003c4d3de9f..26a5ee2b8de8 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -342,6 +342,319 @@ def list_loras(self) -> OutputBatch: status = self.pipeline.get_lora_status() return OutputBatch(output=status) + # Module name to weight directory mapping for different model architectures + _MODULE_WEIGHT_DIR_MAPPING = { + "transformer": ["transformer", "dit", "model"], + "transformer_2": ["transformer_2"], + "video_dit": ["video_dit", "transformer", "dit", "model"], + "video_dit_2": ["video_dit_2"], + "audio_dit": ["audio_dit"], + } + + # Default modules to update for RL workflows (typically only transformer is trained) + _DEFAULT_TARGET_MODULES = [ + "transformer", + "transformer_2", + "video_dit", + "video_dit_2", + "audio_dit", + ] + + def update_weights_from_disk( + self, + model_path: str, + load_format: str = "auto", + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """ + Update model weights from disk in-place without restarting the server. + + This method enables dynamic weight updates for RL workflows and iterative + model fine-tuning scenarios. Includes rollback mechanism to restore original + weights if loading fails. + + By default, updates ALL nn.Module components in the pipeline (transformer, vae, + text_encoder, etc.). Use target_modules to specify a subset if needed. + + Args: + model_path: Path to the new model weights (HuggingFace model path or local directory). + load_format: Format of the weights to load (default: "auto"). + flush_cache: Whether to reset cache state after updating weights (default: True). + target_modules: List of module names to update. If None or ["all"], updates all + nn.Module components. Specify a list like ["transformer"] to update + only specific modules. + + Returns: + Tuple of (success: bool, message: str). 
+ """ + import gc + import os + + from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files + from sglang.multimodal_gen.runtime.loader.weight_utils import ( + safetensors_weights_iterator, + ) + from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model, + ) + + logger.info(f"Updating weights from disk: {model_path}") + + # Store original model path for potential rollback + original_model_path = self.server_args.model_path + + if not self.pipeline: + return False, "Pipeline is not initialized" + + available_modules: list[str] = [] + if hasattr(self.pipeline, "modules"): + available_modules = list(self.pipeline.modules.keys()) + + # Determine which modules to update + if target_modules is None or target_modules == ["all"]: + # Default: update all nn.Module components in the pipeline + module_names = [ + name + for name in available_modules + if isinstance(self.pipeline.get_module(name), torch.nn.Module) + ] + else: + module_names = target_modules + + # Collect all modules that need to be updated + modules_to_update: list[tuple[str, torch.nn.Module]] = [] + + for name in module_names: + module = self.pipeline.get_module(name) + if module is not None and isinstance(module, torch.nn.Module): + modules_to_update.append((name, module)) + + # For DiffusersPipeline, also check diffusers_pipe attributes + diffusers_pipe = self.pipeline.get_module("diffusers_pipeline") + if diffusers_pipe is not None and not modules_to_update: + for name in module_names: + if hasattr(diffusers_pipe, name): + module = getattr(diffusers_pipe, name) + if module is not None and isinstance(module, torch.nn.Module): + modules_to_update.append((name, module)) + + if not modules_to_update: + # Provide detailed error message + error_msg = ( + f"No matching modules found for update. " + f"Requested: {module_names}. 
" + f"Available in pipeline: {available_modules}" + ) + logger.error(error_msg) + return False, error_msg + + # Helper function to find weights directory for a module + def find_weights_dir(local_path: str, module_name: str) -> str | None: + possible_dirs = self._MODULE_WEIGHT_DIR_MAPPING.get( + module_name, [module_name] + ) + for dir_name in possible_dirs: + dir_path = os.path.join(local_path, dir_name) + if os.path.exists(dir_path): + return dir_path + # Fallback: check if weights are in root directory (for single-module models) + if _list_safetensors_files(local_path): + return local_path + return None + + # Helper function to get weights iterator from a directory + def get_weights_iter(weights_dir: str): + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {weights_dir}") + return safetensors_weights_iterator(safetensors_files), len( + safetensors_files + ) + + # Helper function to load weights into model + def load_weights_into_model( + weights_iter, model_params: dict + ) -> tuple[int, int]: + try: + from torch.distributed.tensor import DTensor, distribute_tensor + except ImportError: + DTensor = None + distribute_tensor = None + + updated = 0 + skipped = 0 + for name, loaded_weight in weights_iter: + if name in model_params: + param = model_params[name] + if param.shape == loaded_weight.shape: + if DTensor is not None and isinstance(param, DTensor): + # For DTensor, distribute the loaded weight first then copy + distributed_weight = distribute_tensor( + loaded_weight.to(param.device, param.dtype), + param.device_mesh, + param.placements, + ) + param._local_tensor.copy_(distributed_weight._local_tensor) + else: + param.data.copy_( + loaded_weight.to(param.device, param.dtype) + ) + updated += 1 + else: + logger.warning( + f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" + ) + skipped += 1 + else: + skipped += 1 + return updated, skipped 
+ + # Download model if it's a HuggingFace path + try: + local_model_path = maybe_download_model(model_path) + except Exception as e: + return False, f"Failed to download model: {e}" + + # Phase 1: Validate ALL modules have their weight directories before any update + # This ensures we don't do partial updates + module_weights_map: dict[str, str] = {} # module_name -> weights_dir + missing_modules: list[str] = [] + + for module_name, module in modules_to_update: + weights_dir = find_weights_dir(local_model_path, module_name) + if weights_dir is None: + missing_modules.append(module_name) + else: + # Also validate that we can get weights iterator + try: + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + missing_modules.append(module_name) + else: + module_weights_map[module_name] = weights_dir + except Exception: + missing_modules.append(module_name) + + # Fail if any module is missing weights - no partial updates allowed + if missing_modules: + error_message = ( + f"Cannot update weights: missing weight files for modules: {missing_modules}. " + f"All modules must have corresponding weights. No partial updates allowed." 
+ ) + logger.error(error_message) + return False, error_message + + # Log which modules will be updated from which directories + logger.info( + f"Updating {len(module_weights_map)} modules: " + + ", ".join( + f"{name} <- {path}" for name, path in module_weights_map.items() + ) + ) + + # Phase 2: Update all modules + # First, disable layerwise offload for all modules (load weights from CPU to GPU) + offload_disabled_modules: list[torch.nn.Module] = [] + for module_name, module in modules_to_update: + if ( + hasattr(module, "layerwise_offload_managers") + and module.layerwise_offload_managers + ): + module.disable_offload() + offload_disabled_modules.append(module) + + total_updated = 0 + total_skipped = 0 + updated_modules: list[str] = [] + + for module_name, module in modules_to_update: + weights_dir = module_weights_map[module_name] + model_state_dict = dict(module.named_parameters()) + + try: + weights_iter, _ = get_weights_iter(weights_dir) + updated, skipped = load_weights_into_model( + weights_iter, model_state_dict + ) + total_updated += updated + total_skipped += skipped + updated_modules.append(module_name) + except Exception as e: + # Rollback ALL modules (including the ones already updated) + error_message = ( + f"Failed to update {module_name}: {e}. Rolling back all modules." 
+ ) + logger.error(error_message, exc_info=True) + + if updated_modules: + try: + original_local_path = maybe_download_model(original_model_path) + for rollback_name in updated_modules: + rollback_module = self.pipeline.get_module(rollback_name) + if rollback_module is None: + continue + rollback_weights_dir = find_weights_dir( + original_local_path, rollback_name + ) + if rollback_weights_dir is None: + continue + rollback_iter, _ = get_weights_iter(rollback_weights_dir) + rollback_params = dict(rollback_module.named_parameters()) + load_weights_into_model(rollback_iter, rollback_params) + except Exception as rollback_error: + logger.error(f"Rollback failed: {rollback_error}") + # Re-enable offload before returning + for m in offload_disabled_modules: + m.enable_offload() + return ( + False, + f"{error_message} Rollback also failed: {rollback_error}", + ) + + gc.collect() + torch.cuda.empty_cache() + # Re-enable offload before returning + for m in offload_disabled_modules: + m.enable_offload() + return False, error_message + + # Clean up GPU memory + gc.collect() + torch.cuda.empty_cache() + + # Reset cache state for all updated modules + if flush_cache: + for module_name, module in modules_to_update: + if module_name in updated_modules: + self._reset_cache_state_after_weight_update(module) + + # Re-enable layerwise offload (sync new weights to CPU) + for module in offload_disabled_modules: + module.enable_offload() + + # Update the model path in server_args + self.server_args.model_path = model_path + + message = f"Successfully updated {len(updated_modules)} modules ({', '.join(updated_modules)}): {total_updated} params updated" + logger.info(message) + return True, message + + def _reset_cache_state_after_weight_update(self, module: torch.nn.Module) -> None: + """ + Reset cache state for a single module after weight updates. + + This resets TeaCache state. 
Cache-DiT context is automatically refreshed + at the start of each inference request with the correct num_inference_steps, + so we don't need to manually reset it here. + + Args: + module: The module whose cache state should be reset. + """ + # Reset TeaCache state if the module has it + if hasattr(module, "reset_teacache_state"): + module.reset_teacache_state() + OOM_MSG = f""" OOM detected. Possible solutions: diff --git a/python/sglang/multimodal_gen/runtime/managers/io_struct.py b/python/sglang/multimodal_gen/runtime/managers/io_struct.py new file mode 100644 index 000000000000..d4b29223e027 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/io_struct.py @@ -0,0 +1,16 @@ +""" +I/O data structures for diffusion engine scheduler. +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class UpdateWeightsFromDiskReq: + """Request to update model weights from disk for diffusion models.""" + + model_path: str + load_format: str = "auto" + flush_cache: bool = True + target_modules: Optional[List[str]] = None diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 041f6b7fec48..7978f8c2ae4a 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -21,6 +21,7 @@ save_image_to_path, ) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker +from sglang.multimodal_gen.runtime.managers.io_struct import UpdateWeightsFromDiskReq from sglang.multimodal_gen.runtime.pipelines_core import Req from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch from sglang.multimodal_gen.runtime.server_args import ( @@ -89,6 +90,7 @@ def __init__( List[Req]: self._handle_generation, ListLorasReq: self._handle_list_loras, ShutdownReq: self._handle_shutdown, + UpdateWeightsFromDiskReq: self._handle_update_weights_from_disk, } # FIFO, 
new reqs are appended @@ -128,6 +130,19 @@ def _handle_list_loras(self, _reqs: List[Any]) -> OutputBatch: def _handle_shutdown(self, _reqs: List[Any]) -> OutputBatch: self._running = False return OutputBatch() + def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: + """Handle update_weights_from_disk request for RL workflows.""" + req = reqs[0] + success, message = self.worker.update_weights_from_disk( + model_path=req.model_path, + load_format=req.load_format or "auto", + flush_cache=req.flush_cache, + target_modules=req.target_modules, + ) + return OutputBatch( + output={"success": success, "message": message}, + error=None if success else message, + ) def _handle_generation(self, reqs: List[Req]): warmup_reqs = [req for req in reqs if req.is_warmup] diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py new file mode 100644 index 000000000000..4303e823faf0 --- /dev/null +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -0,0 +1,340 @@ +""" +Tests for update_weights_from_disk API in SGLang-D (diffusion engine). + +This tests the ability to dynamically update model weights without restarting the server, +which is critical for RL workflows and iterative fine-tuning scenarios. 
+""" + +from __future__ import annotations + +import os + +import pytest +import requests + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_utils import ( + ServerContext, + ServerManager, +) +from sglang.multimodal_gen.test.test_utils import get_dynamic_server_port + +logger = init_logger(__name__) + +# Default model for testing - use a small/fast model, need to be an image diffusion model +DEFAULT_DIFFUSION_MODEL = os.environ.get( + "SGLANG_TEST_DIFFUSION_MODEL", "black-forest-labs/FLUX.2-klein-4B" +) + + +@pytest.fixture(scope="class") +def diffusion_server_for_weight_update(): + """Start a diffusion server for weight update tests.""" + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=DEFAULT_DIFFUSION_MODEL, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + ctx = manager.start() + + try: + yield ctx + finally: + ctx.cleanup() + + +class TestUpdateWeightsFromDisk: + """Test suite for update_weights_from_disk API.""" + + def _get_base_url(self, ctx: ServerContext) -> str: + return f"http://localhost:{ctx.port}" + + def _get_model_info(self, base_url: str) -> dict: + """Get current model info from server.""" + response = requests.get(f"{base_url}/get_model_info", timeout=30) + assert response.status_code == 200, f"get_model_info failed: {response.text}" + return response.json() + + def _update_weights( + self, + base_url: str, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + timeout: int = 300, + ) -> dict: + """Call update_weights_from_disk API.""" + payload = { + "model_path": model_path, + "flush_cache": flush_cache, + } + if target_modules is not None: + payload["target_modules"] = target_modules + + response = requests.post( + f"{base_url}/update_weights_from_disk", + json=payload, + timeout=timeout, + ) + return 
response.json(), response.status_code + + def test_get_model_info(self, diffusion_server_for_weight_update: ServerContext): + """Test that we can get model info from the server.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + model_info = self._get_model_info(base_url) + + assert "model_path" in model_info, "model_path not in response" + logger.info(f"Model info: {model_info}") + + def test_update_weights_same_model( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test updating weights with the same model (should succeed).""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + + # Get current model path + model_info = self._get_model_info(base_url) + current_model_path = model_info["model_path"] + logger.info(f"Current model path: {current_model_path}") + + # Update with same model + result, status_code = self._update_weights(base_url, current_model_path) + logger.info(f"Update result: {result}") + + assert status_code == 200, f"Expected 200, got {status_code}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + def test_update_weights_with_flush_cache( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test updating weights with flush_cache=True.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + model_info = self._get_model_info(base_url) + current_model_path = model_info["model_path"] + + result, status_code = self._update_weights( + base_url, + current_model_path, + flush_cache=True, + ) + + assert status_code == 200 + assert result.get("success", False), f"Update failed: {result.get('message')}" + + def test_update_weights_without_flush_cache( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test updating weights with flush_cache=False.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + model_info = self._get_model_info(base_url) + current_model_path = model_info["model_path"] + + 
result, status_code = self._update_weights( + base_url, + current_model_path, + flush_cache=False, + ) + + assert status_code == 200 + assert result.get("success", False), f"Update failed: {result.get('message')}" + + def test_update_weights_nonexistent_model( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test that updating with non-existent model fails gracefully.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + + result, status_code = self._update_weights( + base_url, + "/nonexistent/path/to/model", + timeout=60, + ) + logger.info(f"Update result for nonexistent model: {result}") + + # Should fail gracefully + assert not result.get("success", True), "Should fail for nonexistent model" + + def test_update_weights_missing_model_path( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test that request without model_path returns 400.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + + response = requests.post( + f"{base_url}/update_weights_from_disk", + json={}, + timeout=30, + ) + + # Should return 400 Bad Request + assert response.status_code == 400, f"Expected 400, got {response.status_code}" + + def test_update_weights_specific_modules( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test updating only specific modules (e.g., transformer only).""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + model_info = self._get_model_info(base_url) + current_model_path = model_info["model_path"] + + # Try to update only transformer module + result, status_code = self._update_weights( + base_url, + current_model_path, + target_modules=["transformer"], + ) + logger.info(f"Update specific modules result: {result}") + + # This might fail if the model doesn't have a transformer module + # or if weights for only transformer aren't available + # The test verifies the API handles target_modules parameter + assert status_code == 200 + + +class 
TestUpdateWeightsFromDiskWithOffload: + """Test update_weights_from_disk with layerwise offload enabled.""" + + @pytest.fixture(scope="class") + def diffusion_server_with_offload(self): + """Start a diffusion server with layerwise offload enabled.""" + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=DEFAULT_DIFFUSION_MODEL, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1 --dit-layerwise-offload true", + ) + + ctx = manager.start() + + try: + yield ctx + finally: + ctx.cleanup() + + def _get_base_url(self, ctx: ServerContext) -> str: + return f"http://localhost:{ctx.port}" + + def _get_model_info(self, base_url: str) -> dict: + response = requests.get(f"{base_url}/get_model_info", timeout=30) + return response.json() + + def _update_weights( + self, base_url: str, model_path: str, **kwargs + ) -> tuple[dict, int]: + payload = {"model_path": model_path, **kwargs} + response = requests.post( + f"{base_url}/update_weights_from_disk", + json=payload, + timeout=kwargs.get("timeout", 300), + ) + return response.json(), response.status_code + + def test_update_weights_with_offload_enabled( + self, diffusion_server_with_offload: ServerContext + ): + """Test that weight update works correctly when layerwise offload is enabled. + + This tests the fix for the shape mismatch issue where offloaded weights + have placeholder size [1] tensors on GPU. 
+ """ + base_url = self._get_base_url(diffusion_server_with_offload) + model_info = self._get_model_info(base_url) + current_model_path = model_info["model_path"] + + logger.info( + f"Testing weight update with offload enabled, model: {current_model_path}" + ) + + result, status_code = self._update_weights(base_url, current_model_path) + logger.info(f"Update result: {result}") + + assert status_code == 200, f"Expected 200, got {status_code}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + # Verify no shape mismatch warnings in the message + message = result.get("message", "") + assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" + + +class TestUpdateWeightsEndToEnd: + """End-to-end tests: verify generation works after weight update.""" + + @pytest.fixture(scope="class") + def diffusion_server_e2e(self): + """Start a diffusion server for E2E tests.""" + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=DEFAULT_DIFFUSION_MODEL, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + ctx = manager.start() + + try: + yield ctx + finally: + ctx.cleanup() + + def _get_base_url(self, ctx: ServerContext) -> str: + return f"http://localhost:{ctx.port}" + + def _generate_image(self, base_url: str, prompt: str = "a cat") -> dict: + """Generate an image using the OpenAI-compatible API.""" + from openai import OpenAI + + client = OpenAI( + api_key="sglang-test", + base_url=f"{base_url}/v1", + ) + + response = client.images.generate( + model="default", + prompt=prompt, + n=1, + size="512x512", + response_format="b64_json", # Avoid needing cloud storage + ) + + return response + + def test_generation_after_weight_update(self, diffusion_server_e2e: ServerContext): + """Test that generation still works after updating weights.""" + base_url = self._get_base_url(diffusion_server_e2e) + + # Generate before 
update + logger.info("Generating image before weight update...") + response_before = self._generate_image(base_url, "a beautiful sunset") + assert response_before.data, "Generation before update failed" + logger.info("Generation before update succeeded") + + # Update weights + model_info = requests.get(f"{base_url}/get_model_info", timeout=30).json() + update_response = requests.post( + f"{base_url}/update_weights_from_disk", + json={"model_path": model_info["model_path"], "flush_cache": True}, + timeout=300, + ) + assert update_response.json().get("success"), "Weight update failed" + logger.info("Weight update succeeded") + + # Generate after update + logger.info("Generating image after weight update...") + response_after = self._generate_image(base_url, "a beautiful sunrise") + assert response_after.data, "Generation after update failed" + logger.info("Generation after update succeeded") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 7703c69c5bc991e92a7d9bd09fb2ba918827645e Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Fri, 6 Feb 2026 09:01:01 +0000 Subject: [PATCH 02/30] [diffusion] refactor: extract WeightsUpdater for update_weights_from_disk - Extract weight update logic from GPUWorker into dedicated WeightsUpdater class in loader/weights_updater.py - Simplify http_server.py: merge endpoint into health_router, remove /get_model_info, narrow exception handling - Align with LLM engine: propagate rollback failures, raise on shape mismatch, no skip/update counting - Remove unused load_format parameter from the full call chain - Add DTensor and layerwise offload support with docstrings - Document diffusion weight update API in sglang_for_rl.md - Add test to per-commit 1-gpu CI suite --- docs/advanced_features/sglang_for_rl.md | 17 + .../runtime/entrypoints/http_server.py | 99 ++---- .../runtime/loader/weights_updater.py | 335 ++++++++++++++++++ .../runtime/managers/gpu_worker.py | 315 +--------------- 
.../runtime/managers/io_struct.py | 1 - .../runtime/managers/scheduler.py | 1 - .../sglang/multimodal_gen/test/run_suite.py | 1 + .../server/test_update_weights_from_disk.py | 55 +-- 8 files changed, 398 insertions(+), 426 deletions(-) create mode 100644 python/sglang/multimodal_gen/runtime/loader/weights_updater.py diff --git a/docs/advanced_features/sglang_for_rl.md b/docs/advanced_features/sglang_for_rl.md index 2fd84c90de69..30984dd32150 100644 --- a/docs/advanced_features/sglang_for_rl.md +++ b/docs/advanced_features/sglang_for_rl.md @@ -106,6 +106,23 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates **Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)` +**Diffusion engine (SGLang-D):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint. The update is all-or-nothing with automatic rollback on failure. When layerwise offload (`--dit-layerwise-offload`) is enabled, offload is temporarily disabled during the update, which causes a temporary GPU memory peak before returning to normal offloaded memory usage. + +**Request body:** + +| Field | Description | Defaults | Options | +| --- | --- | --- | --- | +| `model_path` | The model path with the new weights. | Required | Type: str | +| `flush_cache` | Flush TeaCache state after update. | `True` | Type: bool | +| `target_modules` | List of module names to update (e.g. `["transformer"]`). If omitted, all `nn.Module` components are updated. | `None` | Type: list[str] | + +**Response body:** + +| Field | Description | Defaults | Options | +| --- | --- | --- | --- | +| `success` | Whether the update succeeded. | - | Type: bool | +| `message` | Status / error message. 
| - | Type: str | + ### Update Weights from Tensor **When to use:** diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index 122e80b943e7..c1e1e0678e35 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -19,6 +19,7 @@ prepare_request, save_outputs, ) +from sglang.multimodal_gen.runtime.managers.io_struct import UpdateWeightsFromDiskReq from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args @@ -91,90 +92,39 @@ async def health_generate(): # TODO : health generate endpoint return {"status": "ok"} - -@health_router.get("/get_model_info") -async def get_model_info(request: Request): - """Get information about the current model.""" - server_args: ServerArgs = request.app.state.server_args - return { - "model_path": server_args.model_path, - } - - -# Weight update router for RL workflows -weight_update_router = APIRouter() - - -@weight_update_router.post("/update_weights_from_disk") +@health_router.post("/update_weights_from_disk") async def update_weights_from_disk(request: Request): - """ - Update model weights from disk without restarting the server. - - This endpoint enables dynamic weight updates for RL workflows and iterative - model fine-tuning scenarios. - - Request body: - - model_path (str): Path to the new model weights (HuggingFace model path or local directory) - - load_format (str, optional): Format of the weights to load (default: "auto") - - flush_cache (bool, optional): Whether to flush cache after update (default: True) - - target_modules (list[str], optional): List of module names to update. - Default: updates ALL nn.Module components (transformer, vae, text_encoder, etc.) 
- Examples: ["transformer"] to update only transformer - - Returns: - - success (bool): Whether the update was successful - - message (str): Status message - """ - from sglang.multimodal_gen.runtime.managers.io_struct import ( - UpdateWeightsFromDiskReq, + """Update model weights from disk inplace without restarting the server.""" + body = await request.json() + model_path = body.get("model_path") + if not model_path: + return ORJSONResponse( + {"success": False, "message": "model_path is required"}, + status_code=400, + ) + + req = UpdateWeightsFromDiskReq( + model_path=model_path, + flush_cache=body.get("flush_cache", True), + target_modules=body.get("target_modules"), ) try: - body = await request.json() - model_path = body.get("model_path") - if not model_path: - return ORJSONResponse( - {"success": False, "message": "model_path is required"}, - status_code=400, - ) - - # Create the request object with diffusion-specific fields - req = UpdateWeightsFromDiskReq( - model_path=model_path, - load_format=body.get("load_format", "auto"), - flush_cache=body.get("flush_cache", True), - target_modules=body.get("target_modules"), - ) - response = await async_scheduler_client.forward(req) - - # Handle response - if hasattr(response, "output") and response.output: - result = response.output - return ORJSONResponse( - { - "success": result.get("success", False), - "message": result.get("message", "Unknown status"), - }, - status_code=200 if result.get("success") else 400, - ) - elif hasattr(response, "error") and response.error: - return ORJSONResponse( - {"success": False, "message": response.error}, - status_code=400, - ) - else: - return ORJSONResponse( - {"success": False, "message": "Unknown response format"}, - status_code=500, - ) - except Exception as e: return ORJSONResponse( - {"success": False, "message": f"Error: {str(e)}"}, + {"success": False, "message": str(e)}, status_code=500, ) + result = response.output + success = result.get("success", False) + message = 
result.get("message", "Unknown status") + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else 400, + ) + def make_serializable(obj): """Recursively converts Tensors to None for JSON serialization.""" @@ -295,7 +245,6 @@ def create_app(server_args: ServerArgs): app = FastAPI(lifespan=lifespan) app.include_router(health_router) - app.include_router(weight_update_router) app.include_router(vertex_router) from sglang.multimodal_gen.runtime.entrypoints.openai import common_api diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py new file mode 100644 index 000000000000..540b247067f6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -0,0 +1,335 @@ +""" +In-place weight updates for diffusion pipeline modules. + +This module provides ``WeightsUpdater``, which swaps model weights at runtime +without restarting the server. It is the diffusion-engine counterpart of the +LLM engine's ``ModelRunner.update_weights_from_disk``. + +Typical usage (from ``GPUWorker``): + + updater = WeightsUpdater(self.pipeline) + success, message = updater.update_weights_from_disk( + model_path, + original_model_path=self.server_args.model_path, + ) + +Key design decisions +-------------------- +* **All-or-nothing**: if any module fails to load, all previously updated + modules are rolled back to the original weights. +* **Rollback failures propagate**: if rollback itself fails, the exception is + *not* caught so the caller knows the model is in an inconsistent state. + This matches the LLM engine behaviour. +* **Offload-aware**: layerwise offload is temporarily disabled during the + update so that weight tensors are fully materialised on the target device. + This is necessary because the offload manager replaces parameters with + small placeholders; the real tensors must be restored before copying. 
+* **DTensor-aware**: parameters that have been distributed via + ``torch.distributed.tensor`` are updated through ``distribute_tensor`` + so that each shard is correctly placed. +""" + +from __future__ import annotations + +import gc +import os +import time + +import torch + +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +try: + from torch.distributed.tensor import DTensor, distribute_tensor +except ImportError: + DTensor = None + distribute_tensor = None + +logger = init_logger(__name__) + + +class WeightsUpdater: + """In-place weight updates for diffusion pipeline modules. + + Args: + pipeline: A ``ComposedPipelineBase`` (or ``DiffusersPipeline``) instance + whose modules will be updated. + """ + + def __init__(self, pipeline): + self.pipeline = pipeline + + def update_weights_from_disk( + self, + model_path: str, + original_model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update model weights from disk without restarting the server. + + Args: + model_path: HF repo id or local path to the new weights. + original_model_path: Path to the currently loaded weights (used + for rollback on failure). + flush_cache: If ``True``, reset TeaCache state after a successful + update so that stale cached residuals are not reused. + target_modules: Explicit list of module names to update. ``None`` + or ``["all"]`` updates every ``nn.Module`` in the pipeline. + + Returns: + ``(success, message)`` tuple. 
+ """ + tic = time.perf_counter() + self._original_model_path = original_model_path + logger.info(f"Updating weights from disk: {model_path}") + + modules_to_update = self._collect_modules(target_modules) + if not modules_to_update: + available = list(self.pipeline.modules.keys()) + error_msg = ( + f"No matching modules found for update. " + f"Requested: {target_modules}. Available in pipeline: {available}" + ) + logger.error(error_msg) + return False, error_msg + + try: + local_model_path = maybe_download_model(model_path) + except Exception as e: + return False, f"Failed to download model: {e}" + + weights_map, missing = _validate_weight_files( + local_model_path, modules_to_update + ) + if missing: + error_msg = ( + f"Cannot update weights: missing weight files for modules: {missing}. " + f"No partial updates allowed." + ) + logger.error(error_msg) + return False, error_msg + + logger.info( + f"Updating {len(weights_map)} modules: " + + ", ".join(f"{n} <- {p}" for n, p in weights_map.items()) + ) + + offload_disabled = _disable_offload(modules_to_update) + + success, message = self._apply_weights(modules_to_update, weights_map) + + gc.collect() + torch.cuda.empty_cache() + + if success and flush_cache: + for _, module in modules_to_update: + _reset_cache_state(module) + + for m in offload_disabled: + m.enable_offload() + + elapsed = time.perf_counter() - tic + message = f"{message} elapsed={elapsed:.2f}s" + logger.info(message) + return success, message + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _collect_modules( + self, target_modules: list[str] | None + ) -> list[tuple[str, torch.nn.Module]]: + """Resolve *target_modules* to ``(name, module)`` pairs. + + For ``ComposedPipelineBase`` pipelines, modules are looked up via + ``pipeline.modules``. 
For ``DiffusersPipeline`` (where + ``pipeline.modules`` is empty), we fall back to + ``diffusers_pipe.components``. + """ + available = self.pipeline.modules.keys() + if target_modules is None or target_modules == ["all"]: + names = [ + n for n in available + if isinstance(self.pipeline.get_module(n), torch.nn.Module) + ] + else: + names = target_modules + + result: list[tuple[str, torch.nn.Module]] = [] + for name in names: + module = self.pipeline.get_module(name) + if module is not None and isinstance(module, torch.nn.Module): + result.append((name, module)) + + # Fallback for DiffusersPipeline: modules live on the diffusers pipe, + # not in self.pipeline.modules. + if not result: + diffusers_pipe = self.pipeline.get_module("diffusers_pipeline") + if diffusers_pipe is not None and hasattr(diffusers_pipe, "components"): + components = diffusers_pipe.components + if target_modules is None or target_modules == ["all"]: + names = list(components.keys()) + for name in names: + module = components.get(name) + if module is not None and isinstance(module, torch.nn.Module): + result.append((name, module)) + return result + + def _apply_weights( + self, + modules_to_update: list[tuple[str, torch.nn.Module]], + weights_map: dict[str, str], + ) -> tuple[bool, str]: + """Load weights into each module; rollback on first failure.""" + updated_modules: list[str] = [] + + for module_name, module in modules_to_update: + params = dict(module.named_parameters()) + try: + weights_iter = _get_weights_iter(weights_map[module_name]) + load_weights_into_model(weights_iter, params) + updated_modules.append(module_name) + except Exception as e: + error_msg = f"Failed to update {module_name}: {e}. Rolling back." + logger.error(error_msg, exc_info=True) + self._rollback(updated_modules) + return False, error_msg + + names = ", ".join(updated_modules) + return True, f"Updated {len(updated_modules)} modules ({names})." 
+ + def _rollback(self, updated_modules: list[str]) -> None: + """Restore *updated_modules* to original weights. + + If rollback itself fails the exception propagates so the caller + knows the model is in an inconsistent state. + """ + if not updated_modules: + return + original_path = maybe_download_model(self._original_model_path) + for name in updated_modules: + module = self.pipeline.get_module(name) + if module is None: + continue + weights_dir = find_weights_dir(original_path, name) + if weights_dir is None: + continue + weights_iter = _get_weights_iter(weights_dir) + load_weights_into_model( + weights_iter, dict(module.named_parameters()) + ) + + +# --------------------------------------------------------------------------- +# Module-level utility functions +# --------------------------------------------------------------------------- + + +def find_weights_dir(local_path: str, module_name: str) -> str | None: + """Locate the safetensors directory for *module_name* under *local_path*. + + Tries ``//`` first, then falls back to + *local_path* itself if it directly contains safetensors files (common + for RL checkpoints that save weights in a flat directory). + """ + dir_path = os.path.join(local_path, module_name) + if os.path.exists(dir_path): + return dir_path + if _list_safetensors_files(local_path): + return local_path + return None + + +def _get_weights_iter(weights_dir: str): + """Return a ``(name, tensor)`` iterator over safetensors in *weights_dir*.""" + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {weights_dir}") + return safetensors_weights_iterator(safetensors_files) + + +def _validate_weight_files( + local_model_path: str, + modules_to_update: list[tuple[str, torch.nn.Module]], +) -> tuple[dict[str, str], list[str]]: + """Check that every module has a weights directory with safetensors files. 
+ + Returns: + ``(weights_map, missing)`` where *weights_map* maps module name to its + weights directory and *missing* lists modules without weight files. + """ + weights_map: dict[str, str] = {} + missing: list[str] = [] + for module_name, _ in modules_to_update: + weights_dir = find_weights_dir(local_model_path, module_name) + if weights_dir and _list_safetensors_files(weights_dir): + weights_map[module_name] = weights_dir + else: + missing.append(module_name) + return weights_map, missing + + +def _disable_offload( + modules_to_update: list[tuple[str, torch.nn.Module]], +) -> list[torch.nn.Module]: + """Temporarily disable layerwise offload so weights are materialised. + + The offload manager replaces parameters with small placeholders when + layers are offloaded. Disabling offload restores the real tensors so + that ``load_weights_into_model`` can copy into them. + """ + disabled: list[torch.nn.Module] = [] + for _, module in modules_to_update: + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + module.disable_offload() + disabled.append(module) + return disabled + + +def load_weights_into_model( + weights_iter, model_params: dict +) -> None: + """Copy weights from *weights_iter* into *model_params* in-place. + + Handles ``DTensor`` parameters by re-distributing the loaded weight + according to the existing device mesh and placements. + + Raises: + ValueError: On shape mismatch between model parameter and loaded weight. 
+ """ + for name, loaded_weight in weights_iter: + if name not in model_params: + continue + param = model_params[name] + if param.shape != loaded_weight.shape: + raise ValueError( + f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" + ) + if DTensor is not None and isinstance(param, DTensor): + distributed_weight = distribute_tensor( + loaded_weight.to(param.device, param.dtype), + param.device_mesh, + param.placements, + ) + param._local_tensor.copy_(distributed_weight._local_tensor) + else: + param.data.copy_(loaded_weight.to(param.device, param.dtype)) + + +def _reset_cache_state(module: torch.nn.Module) -> None: + """Reset Cache state after weight updates. + + After weights change, any cached residuals from previous denoising steps + are invalid and must be cleared. + """ + if isinstance(module, TeaCacheMixin): + module.reset_teacache_state() diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 26a5ee2b8de8..5ddc3d26f991 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -35,6 +35,9 @@ Req, build_pipeline, ) +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + WeightsUpdater, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs @@ -342,318 +345,26 @@ def list_loras(self) -> OutputBatch: status = self.pipeline.get_lora_status() return OutputBatch(output=status) - # Module name to weight directory mapping for different model architectures - _MODULE_WEIGHT_DIR_MAPPING = { - "transformer": ["transformer", "dit", "model"], - "transformer_2": ["transformer_2"], - "video_dit": ["video_dit", "transformer", "dit", "model"], - "video_dit_2": ["video_dit_2"], - "audio_dit": 
["audio_dit"], - } - - # Default modules to update for RL workflows (typically only transformer is trained) - _DEFAULT_TARGET_MODULES = [ - "transformer", - "transformer_2", - "video_dit", - "video_dit_2", - "audio_dit", - ] - def update_weights_from_disk( self, model_path: str, - load_format: str = "auto", flush_cache: bool = True, target_modules: list[str] | None = None, ) -> tuple[bool, str]: - """ - Update model weights from disk in-place without restarting the server. - - This method enables dynamic weight updates for RL workflows and iterative - model fine-tuning scenarios. Includes rollback mechanism to restore original - weights if loading fails. - - By default, updates ALL nn.Module components in the pipeline (transformer, vae, - text_encoder, etc.). Use target_modules to specify a subset if needed. - - Args: - model_path: Path to the new model weights (HuggingFace model path or local directory). - load_format: Format of the weights to load (default: "auto"). - flush_cache: Whether to reset cache state after updating weights (default: True). - target_modules: List of module names to update. If None or ["all"], updates all - nn.Module components. Specify a list like ["transformer"] to update - only specific modules. - - Returns: - Tuple of (success: bool, message: str). 
- """ - import gc - import os - - from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files - from sglang.multimodal_gen.runtime.loader.weight_utils import ( - safetensors_weights_iterator, - ) - from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( - maybe_download_model, - ) - - logger.info(f"Updating weights from disk: {model_path}") - - # Store original model path for potential rollback - original_model_path = self.server_args.model_path - + """Update model weights from disk inplace without restarting the server.""" if not self.pipeline: return False, "Pipeline is not initialized" - available_modules: list[str] = [] - if hasattr(self.pipeline, "modules"): - available_modules = list(self.pipeline.modules.keys()) - - # Determine which modules to update - if target_modules is None or target_modules == ["all"]: - # Default: update all nn.Module components in the pipeline - module_names = [ - name - for name in available_modules - if isinstance(self.pipeline.get_module(name), torch.nn.Module) - ] - else: - module_names = target_modules - - # Collect all modules that need to be updated - modules_to_update: list[tuple[str, torch.nn.Module]] = [] - - for name in module_names: - module = self.pipeline.get_module(name) - if module is not None and isinstance(module, torch.nn.Module): - modules_to_update.append((name, module)) - - # For DiffusersPipeline, also check diffusers_pipe attributes - diffusers_pipe = self.pipeline.get_module("diffusers_pipeline") - if diffusers_pipe is not None and not modules_to_update: - for name in module_names: - if hasattr(diffusers_pipe, name): - module = getattr(diffusers_pipe, name) - if module is not None and isinstance(module, torch.nn.Module): - modules_to_update.append((name, module)) - - if not modules_to_update: - # Provide detailed error message - error_msg = ( - f"No matching modules found for update. " - f"Requested: {module_names}. 
" - f"Available in pipeline: {available_modules}" - ) - logger.error(error_msg) - return False, error_msg - - # Helper function to find weights directory for a module - def find_weights_dir(local_path: str, module_name: str) -> str | None: - possible_dirs = self._MODULE_WEIGHT_DIR_MAPPING.get( - module_name, [module_name] - ) - for dir_name in possible_dirs: - dir_path = os.path.join(local_path, dir_name) - if os.path.exists(dir_path): - return dir_path - # Fallback: check if weights are in root directory (for single-module models) - if _list_safetensors_files(local_path): - return local_path - return None - - # Helper function to get weights iterator from a directory - def get_weights_iter(weights_dir: str): - safetensors_files = _list_safetensors_files(weights_dir) - if not safetensors_files: - raise FileNotFoundError(f"No safetensors files found in {weights_dir}") - return safetensors_weights_iterator(safetensors_files), len( - safetensors_files - ) - - # Helper function to load weights into model - def load_weights_into_model( - weights_iter, model_params: dict - ) -> tuple[int, int]: - try: - from torch.distributed.tensor import DTensor, distribute_tensor - except ImportError: - DTensor = None - distribute_tensor = None - - updated = 0 - skipped = 0 - for name, loaded_weight in weights_iter: - if name in model_params: - param = model_params[name] - if param.shape == loaded_weight.shape: - if DTensor is not None and isinstance(param, DTensor): - # For DTensor, distribute the loaded weight first then copy - distributed_weight = distribute_tensor( - loaded_weight.to(param.device, param.dtype), - param.device_mesh, - param.placements, - ) - param._local_tensor.copy_(distributed_weight._local_tensor) - else: - param.data.copy_( - loaded_weight.to(param.device, param.dtype) - ) - updated += 1 - else: - logger.warning( - f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" - ) - skipped += 1 - else: - skipped += 1 - return updated, skipped 
- - # Download model if it's a HuggingFace path - try: - local_model_path = maybe_download_model(model_path) - except Exception as e: - return False, f"Failed to download model: {e}" - - # Phase 1: Validate ALL modules have their weight directories before any update - # This ensures we don't do partial updates - module_weights_map: dict[str, str] = {} # module_name -> weights_dir - missing_modules: list[str] = [] - - for module_name, module in modules_to_update: - weights_dir = find_weights_dir(local_model_path, module_name) - if weights_dir is None: - missing_modules.append(module_name) - else: - # Also validate that we can get weights iterator - try: - safetensors_files = _list_safetensors_files(weights_dir) - if not safetensors_files: - missing_modules.append(module_name) - else: - module_weights_map[module_name] = weights_dir - except Exception: - missing_modules.append(module_name) - - # Fail if any module is missing weights - no partial updates allowed - if missing_modules: - error_message = ( - f"Cannot update weights: missing weight files for modules: {missing_modules}. " - f"All modules must have corresponding weights. No partial updates allowed." 
- ) - logger.error(error_message) - return False, error_message - - # Log which modules will be updated from which directories - logger.info( - f"Updating {len(module_weights_map)} modules: " - + ", ".join( - f"{name} <- {path}" for name, path in module_weights_map.items() - ) + updater = WeightsUpdater(self.pipeline) + success, message = updater.update_weights_from_disk( + model_path, + original_model_path=self.server_args.model_path, + flush_cache=flush_cache, + target_modules=target_modules, ) - - # Phase 2: Update all modules - # First, disable layerwise offload for all modules (load weights from CPU to GPU) - offload_disabled_modules: list[torch.nn.Module] = [] - for module_name, module in modules_to_update: - if ( - hasattr(module, "layerwise_offload_managers") - and module.layerwise_offload_managers - ): - module.disable_offload() - offload_disabled_modules.append(module) - - total_updated = 0 - total_skipped = 0 - updated_modules: list[str] = [] - - for module_name, module in modules_to_update: - weights_dir = module_weights_map[module_name] - model_state_dict = dict(module.named_parameters()) - - try: - weights_iter, _ = get_weights_iter(weights_dir) - updated, skipped = load_weights_into_model( - weights_iter, model_state_dict - ) - total_updated += updated - total_skipped += skipped - updated_modules.append(module_name) - except Exception as e: - # Rollback ALL modules (including the ones already updated) - error_message = ( - f"Failed to update {module_name}: {e}. Rolling back all modules." 
- ) - logger.error(error_message, exc_info=True) - - if updated_modules: - try: - original_local_path = maybe_download_model(original_model_path) - for rollback_name in updated_modules: - rollback_module = self.pipeline.get_module(rollback_name) - if rollback_module is None: - continue - rollback_weights_dir = find_weights_dir( - original_local_path, rollback_name - ) - if rollback_weights_dir is None: - continue - rollback_iter, _ = get_weights_iter(rollback_weights_dir) - rollback_params = dict(rollback_module.named_parameters()) - load_weights_into_model(rollback_iter, rollback_params) - except Exception as rollback_error: - logger.error(f"Rollback failed: {rollback_error}") - # Re-enable offload before returning - for m in offload_disabled_modules: - m.enable_offload() - return ( - False, - f"{error_message} Rollback also failed: {rollback_error}", - ) - - gc.collect() - torch.cuda.empty_cache() - # Re-enable offload before returning - for m in offload_disabled_modules: - m.enable_offload() - return False, error_message - - # Clean up GPU memory - gc.collect() - torch.cuda.empty_cache() - - # Reset cache state for all updated modules - if flush_cache: - for module_name, module in modules_to_update: - if module_name in updated_modules: - self._reset_cache_state_after_weight_update(module) - - # Re-enable layerwise offload (sync new weights to CPU) - for module in offload_disabled_modules: - module.enable_offload() - - # Update the model path in server_args - self.server_args.model_path = model_path - - message = f"Successfully updated {len(updated_modules)} modules ({', '.join(updated_modules)}): {total_updated} params updated" - logger.info(message) - return True, message - - def _reset_cache_state_after_weight_update(self, module: torch.nn.Module) -> None: - """ - Reset cache state for a single module after weight updates. - - This resets TeaCache state. 
Cache-DiT context is automatically refreshed - at the start of each inference request with the correct num_inference_steps, - so we don't need to manually reset it here. - - Args: - module: The module whose cache state should be reset. - """ - # Reset TeaCache state if the module has it - if hasattr(module, "reset_teacache_state"): - module.reset_teacache_state() + if success: + self.server_args.model_path = model_path + return success, message OOM_MSG = f""" diff --git a/python/sglang/multimodal_gen/runtime/managers/io_struct.py b/python/sglang/multimodal_gen/runtime/managers/io_struct.py index d4b29223e027..4bfb3b394441 100644 --- a/python/sglang/multimodal_gen/runtime/managers/io_struct.py +++ b/python/sglang/multimodal_gen/runtime/managers/io_struct.py @@ -11,6 +11,5 @@ class UpdateWeightsFromDiskReq: """Request to update model weights from disk for diffusion models.""" model_path: str - load_format: str = "auto" flush_cache: bool = True target_modules: Optional[List[str]] = None diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 7978f8c2ae4a..9a70b0fc4f91 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -135,7 +135,6 @@ def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: req = reqs[0] success, message = self.worker.update_weights_from_disk( model_path=req.model_path, - load_format=req.load_format or "auto", flush_cache=req.flush_cache, target_modules=req.target_modules, ) diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index 6610a4cfbe96..c9b34ca0b0fe 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -29,6 +29,7 @@ "../cli/test_generate_t2i_perf.py", # unit tests (no server needed) "../test_sampling_params_validate.py", + 
"test_update_weights_from_disk.py", # add new 1-gpu test files here ], "2-gpu": [ diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 4303e823faf0..cfed4e4f07d9 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -54,12 +54,6 @@ class TestUpdateWeightsFromDisk: def _get_base_url(self, ctx: ServerContext) -> str: return f"http://localhost:{ctx.port}" - def _get_model_info(self, base_url: str) -> dict: - """Get current model info from server.""" - response = requests.get(f"{base_url}/get_model_info", timeout=30) - assert response.status_code == 200, f"get_model_info failed: {response.text}" - return response.json() - def _update_weights( self, base_url: str, @@ -83,27 +77,13 @@ def _update_weights( ) return response.json(), response.status_code - def test_get_model_info(self, diffusion_server_for_weight_update: ServerContext): - """Test that we can get model info from the server.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) - model_info = self._get_model_info(base_url) - - assert "model_path" in model_info, "model_path not in response" - logger.info(f"Model info: {model_info}") - def test_update_weights_same_model( self, diffusion_server_for_weight_update: ServerContext ): """Test updating weights with the same model (should succeed).""" base_url = self._get_base_url(diffusion_server_for_weight_update) - # Get current model path - model_info = self._get_model_info(base_url) - current_model_path = model_info["model_path"] - logger.info(f"Current model path: {current_model_path}") - - # Update with same model - result, status_code = self._update_weights(base_url, current_model_path) + result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) logger.info(f"Update result: {result}") assert 
status_code == 200, f"Expected 200, got {status_code}" @@ -114,12 +94,10 @@ def test_update_weights_with_flush_cache( ): """Test updating weights with flush_cache=True.""" base_url = self._get_base_url(diffusion_server_for_weight_update) - model_info = self._get_model_info(base_url) - current_model_path = model_info["model_path"] result, status_code = self._update_weights( base_url, - current_model_path, + DEFAULT_DIFFUSION_MODEL, flush_cache=True, ) @@ -131,12 +109,10 @@ def test_update_weights_without_flush_cache( ): """Test updating weights with flush_cache=False.""" base_url = self._get_base_url(diffusion_server_for_weight_update) - model_info = self._get_model_info(base_url) - current_model_path = model_info["model_path"] result, status_code = self._update_weights( base_url, - current_model_path, + DEFAULT_DIFFUSION_MODEL, flush_cache=False, ) @@ -179,13 +155,11 @@ def test_update_weights_specific_modules( ): """Test updating only specific modules (e.g., transformer only).""" base_url = self._get_base_url(diffusion_server_for_weight_update) - model_info = self._get_model_info(base_url) - current_model_path = model_info["model_path"] # Try to update only transformer module result, status_code = self._update_weights( base_url, - current_model_path, + DEFAULT_DIFFUSION_MODEL, target_modules=["transformer"], ) logger.info(f"Update specific modules result: {result}") @@ -222,10 +196,6 @@ def diffusion_server_with_offload(self): def _get_base_url(self, ctx: ServerContext) -> str: return f"http://localhost:{ctx.port}" - def _get_model_info(self, base_url: str) -> dict: - response = requests.get(f"{base_url}/get_model_info", timeout=30) - return response.json() - def _update_weights( self, base_url: str, model_path: str, **kwargs ) -> tuple[dict, int]: @@ -240,20 +210,12 @@ def _update_weights( def test_update_weights_with_offload_enabled( self, diffusion_server_with_offload: ServerContext ): - """Test that weight update works correctly when layerwise offload is 
enabled. - - This tests the fix for the shape mismatch issue where offloaded weights - have placeholder size [1] tensors on GPU. - """ + """Test that weight update works correctly when layerwise offload is enabled.""" base_url = self._get_base_url(diffusion_server_with_offload) - model_info = self._get_model_info(base_url) - current_model_path = model_info["model_path"] - logger.info( - f"Testing weight update with offload enabled, model: {current_model_path}" - ) + logger.info("Testing weight update with offload enabled") - result, status_code = self._update_weights(base_url, current_model_path) + result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) logger.info(f"Update result: {result}") assert status_code == 200, f"Expected 200, got {status_code}" @@ -320,10 +282,9 @@ def test_generation_after_weight_update(self, diffusion_server_e2e: ServerContex logger.info("Generation before update succeeded") # Update weights - model_info = requests.get(f"{base_url}/get_model_info", timeout=30).json() update_response = requests.post( f"{base_url}/update_weights_from_disk", - json={"model_path": model_info["model_path"], "flush_cache": True}, + json={"model_path": DEFAULT_DIFFUSION_MODEL, "flush_cache": True}, timeout=300, ) assert update_response.json().get("success"), "Weight update failed" From 886a6bc321085d866aaa88991db61c690e069695 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Fri, 6 Feb 2026 09:07:35 +0000 Subject: [PATCH 03/30] chore: isort lint --- .../runtime/entrypoints/http_server.py | 1 + .../runtime/loader/weights_updater.py | 16 ++++++++-------- .../runtime/managers/gpu_worker.py | 4 +--- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index c1e1e0678e35..9fed777f51c7 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ 
b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -92,6 +92,7 @@ async def health_generate(): # TODO : health generate endpoint return {"status": "ok"} + @health_router.post("/update_weights_from_disk") async def update_weights_from_disk(request: Request): """Update model weights from disk inplace without restarting the server.""" diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 540b247067f6..5143d0fe19c8 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -158,7 +158,8 @@ def _collect_modules( available = self.pipeline.modules.keys() if target_modules is None or target_modules == ["all"]: names = [ - n for n in available + n + for n in available if isinstance(self.pipeline.get_module(n), torch.nn.Module) ] else: @@ -224,9 +225,7 @@ def _rollback(self, updated_modules: list[str]) -> None: if weights_dir is None: continue weights_iter = _get_weights_iter(weights_dir) - load_weights_into_model( - weights_iter, dict(module.named_parameters()) - ) + load_weights_into_model(weights_iter, dict(module.named_parameters())) # --------------------------------------------------------------------------- @@ -289,15 +288,16 @@ def _disable_offload( """ disabled: list[torch.nn.Module] = [] for _, module in modules_to_update: - if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + if ( + isinstance(module, OffloadableDiTMixin) + and module.layerwise_offload_managers + ): module.disable_offload() disabled.append(module) return disabled -def load_weights_into_model( - weights_iter, model_params: dict -) -> None: +def load_weights_into_model(weights_iter, model_params: dict) -> None: """Copy weights from *weights_iter* into *model_params* in-place. 
Handles ``DTensor`` parameters by re-distributing the loaded weight diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 5ddc3d26f991..55704dcd546a 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -29,15 +29,13 @@ get_ulysses_parallel_world_size, ) from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs +from sglang.multimodal_gen.runtime.loader.weights_updater import WeightsUpdater from sglang.multimodal_gen.runtime.pipelines_core import ( ComposedPipelineBase, LoRAPipeline, Req, build_pipeline, ) -from sglang.multimodal_gen.runtime.loader.weights_updater import ( - WeightsUpdater, -) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs From 3d19d61e8374f660e12c8a046906c7a488a3ffea Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Sat, 7 Feb 2026 07:49:21 +0000 Subject: [PATCH 04/30] [diffusion] offload-aware weight updates and cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace _disable_offload (which loaded all layers to GPU) with direct CPU buffer writes via LayerwiseOffloadManager.update_cpu_weights(), eliminating the temporary GPU memory spike during weight updates. - Add update_cpu_weights() to LayerwiseOffloadManager: updates consolidated CPU buffers and live GPU params for prefetched layers. - Optimise .to(dtype, device) → .to(dtype) in load paths to avoid allocating temporary GPU tensors; copy_() handles H2D internally. - Move UpdateWeightsFromDiskReq from io_struct.py to weights_updater.py and delete the now-empty io_struct module. 
- Expand diffusion engine section in sglang_for_rl.md with details on all-or-nothing rollback, offload-aware, and DTensor-aware behavior. --- docs/advanced_features/sglang_for_rl.md | 6 +- .../runtime/entrypoints/http_server.py | 4 +- .../runtime/loader/weights_updater.py | 145 ++++++++++-------- .../runtime/managers/io_struct.py | 15 -- .../runtime/managers/scheduler.py | 4 +- .../runtime/utils/layerwise_offload.py | 54 +++++++ 6 files changed, 148 insertions(+), 80 deletions(-) delete mode 100644 python/sglang/multimodal_gen/runtime/managers/io_struct.py diff --git a/docs/advanced_features/sglang_for_rl.md b/docs/advanced_features/sglang_for_rl.md index 30984dd32150..d72cbaa7bc8e 100644 --- a/docs/advanced_features/sglang_for_rl.md +++ b/docs/advanced_features/sglang_for_rl.md @@ -106,7 +106,11 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates **Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)` -**Diffusion engine (SGLang-D):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint. The update is all-or-nothing with automatic rollback on failure. When layerwise offload (`--dit-layerwise-offload`) is enabled, offload is temporarily disabled during the update, which causes a temporary GPU memory peak before returning to normal offloaded memory usage. +**Diffusion engine (SGLang-D):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior: + +- **All-or-nothing with rollback:** if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from the original model path. No partial updates are left behind. If rollback itself fails, the exception propagates so the caller knows the model is in an inconsistent state. 
+- **Offload-aware:** when layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state. +- **DTensor-aware:** parameters distributed via `torch.distributed.tensor` (tensor parallelism) are updated through `distribute_tensor` so that each shard is correctly placed on the right device mesh. **Request body:** diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index 9fed777f51c7..7ccfec057dcb 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -19,7 +19,9 @@ prepare_request, save_outputs, ) -from sglang.multimodal_gen.runtime.managers.io_struct import UpdateWeightsFromDiskReq +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + UpdateWeightsFromDiskReq, +) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 5143d0fe19c8..0c308c8839ed 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -1,11 +1,11 @@ """ In-place weight updates 
for diffusion pipeline modules. -This module provides ``WeightsUpdater``, which swaps model weights at runtime +This module provides WeightsUpdater, which swaps model weights at runtime without restarting the server. It is the diffusion-engine counterpart of the -LLM engine's ``ModelRunner.update_weights_from_disk``. +LLM engine's ModelRunner.update_weights_from_disk. -Typical usage (from ``GPUWorker``): +Typical usage (from GPUWorker): updater = WeightsUpdater(self.pipeline) success, message = updater.update_weights_from_disk( @@ -13,20 +13,29 @@ original_model_path=self.server_args.model_path, ) -Key design decisions --------------------- -* **All-or-nothing**: if any module fails to load, all previously updated - modules are rolled back to the original weights. -* **Rollback failures propagate**: if rollback itself fails, the exception is - *not* caught so the caller knows the model is in an inconsistent state. +Key design decisions: + +- All-or-nothing: if any module fails to load, all previously updated + modules are rolled back to the original weights by reloading from + original_model_path. No partial updates are left behind. + +- Rollback failures propagate: if rollback itself fails, the exception is + not caught so the caller knows the model is in an inconsistent state. This matches the LLM engine behaviour. -* **Offload-aware**: layerwise offload is temporarily disabled during the - update so that weight tensors are fully materialised on the target device. - This is necessary because the offload manager replaces parameters with - small placeholders; the real tensors must be restored before copying. -* **DTensor-aware**: parameters that have been distributed via - ``torch.distributed.tensor`` are updated through ``distribute_tensor`` - so that each shard is correctly placed. + +- Offload-aware: the diffusion LayerwiseOffloadManager replaces GPU + parameters with torch.empty((1,)) placeholders while real weights live + in consolidated pinned CPU buffers. 
A naive param.data.copy_() would + fail with a shape mismatch. Instead, the updater dynamically detects + active offload managers and writes new weights directly into their CPU + buffers via update_cpu_weights(), bypassing the placeholders entirely. + For any layer that happens to be prefetched on GPU at update time, the + live GPU tensor is also updated so the change takes effect immediately. + This requires no extra GPU memory and does not disturb the offload state. + +- DTensor-aware: parameters that have been distributed via + torch.distributed.tensor are updated through distribute_tensor + so that each shard is correctly placed on the right device mesh. """ from __future__ import annotations @@ -34,6 +43,7 @@ import gc import os import time +from dataclasses import dataclass import torch @@ -55,11 +65,20 @@ logger = init_logger(__name__) +@dataclass +class UpdateWeightsFromDiskReq: + """Request to update model weights from disk for diffusion models.""" + + model_path: str + flush_cache: bool = True + target_modules: list[str] | None = None + + class WeightsUpdater: """In-place weight updates for diffusion pipeline modules. Args: - pipeline: A ``ComposedPipelineBase`` (or ``DiffusersPipeline``) instance + pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance whose modules will be updated. """ @@ -79,13 +98,13 @@ def update_weights_from_disk( model_path: HF repo id or local path to the new weights. original_model_path: Path to the currently loaded weights (used for rollback on failure). - flush_cache: If ``True``, reset TeaCache state after a successful + flush_cache: If True, reset TeaCache state after a successful update so that stale cached residuals are not reused. - target_modules: Explicit list of module names to update. ``None`` - or ``["all"]`` updates every ``nn.Module`` in the pipeline. + target_modules: Explicit list of module names to update. None + or ["all"] updates every nn.Module in the pipeline. 
Returns: - ``(success, message)`` tuple. + (success, message) tuple. """ tic = time.perf_counter() self._original_model_path = original_model_path @@ -122,8 +141,6 @@ def update_weights_from_disk( + ", ".join(f"{n} <- {p}" for n, p in weights_map.items()) ) - offload_disabled = _disable_offload(modules_to_update) - success, message = self._apply_weights(modules_to_update, weights_map) gc.collect() @@ -133,9 +150,6 @@ def update_weights_from_disk( for _, module in modules_to_update: _reset_cache_state(module) - for m in offload_disabled: - m.enable_offload() - elapsed = time.perf_counter() - tic message = f"{message} elapsed={elapsed:.2f}s" logger.info(message) @@ -148,12 +162,11 @@ def update_weights_from_disk( def _collect_modules( self, target_modules: list[str] | None ) -> list[tuple[str, torch.nn.Module]]: - """Resolve *target_modules* to ``(name, module)`` pairs. + """Resolve target_modules to (name, module) pairs. - For ``ComposedPipelineBase`` pipelines, modules are looked up via - ``pipeline.modules``. For ``DiffusersPipeline`` (where - ``pipeline.modules`` is empty), we fall back to - ``diffusers_pipe.components``. + For ComposedPipelineBase pipelines, modules are looked up via + pipeline.modules. For DiffusersPipeline (where pipeline.modules + is empty), we fall back to diffusers_pipe.components. """ available = self.pipeline.modules.keys() if target_modules is None or target_modules == ["all"]: @@ -194,10 +207,9 @@ def _apply_weights( updated_modules: list[str] = [] for module_name, module in modules_to_update: - params = dict(module.named_parameters()) try: weights_iter = _get_weights_iter(weights_map[module_name]) - load_weights_into_model(weights_iter, params) + _load_weights_into_module(module, weights_iter) updated_modules.append(module_name) except Exception as e: error_msg = f"Failed to update {module_name}: {e}. Rolling back." @@ -209,7 +221,7 @@ def _apply_weights( return True, f"Updated {len(updated_modules)} modules ({names})." 
def _rollback(self, updated_modules: list[str]) -> None: - """Restore *updated_modules* to original weights. + """Restore updated_modules to original weights. If rollback itself fails the exception propagates so the caller knows the model is in an inconsistent state. @@ -225,7 +237,7 @@ def _rollback(self, updated_modules: list[str]) -> None: if weights_dir is None: continue weights_iter = _get_weights_iter(weights_dir) - load_weights_into_model(weights_iter, dict(module.named_parameters())) + _load_weights_into_module(module, weights_iter) # --------------------------------------------------------------------------- @@ -234,11 +246,13 @@ def _rollback(self, updated_modules: list[str]) -> None: def find_weights_dir(local_path: str, module_name: str) -> str | None: - """Locate the safetensors directory for *module_name* under *local_path*. + """Locate the safetensors directory for module_name under local_path. - Tries ``//`` first, then falls back to - *local_path* itself if it directly contains safetensors files (common - for RL checkpoints that save weights in a flat directory). + Diffusion models store weights in per-module subdirectories (e.g. + transformer/, vae/, text_encoder/). This function tries + // first, then falls back to local_path + itself if it directly contains safetensors files (common for RL + checkpoints that save weights in a flat directory). 
""" dir_path = os.path.join(local_path, module_name) if os.path.exists(dir_path): @@ -249,7 +263,7 @@ def find_weights_dir(local_path: str, module_name: str) -> str | None: def _get_weights_iter(weights_dir: str): - """Return a ``(name, tensor)`` iterator over safetensors in *weights_dir*.""" + """Return a (name, tensor) iterator over safetensors in weights_dir.""" safetensors_files = _list_safetensors_files(weights_dir) if not safetensors_files: raise FileNotFoundError(f"No safetensors files found in {weights_dir}") @@ -263,8 +277,8 @@ def _validate_weight_files( """Check that every module has a weights directory with safetensors files. Returns: - ``(weights_map, missing)`` where *weights_map* maps module name to its - weights directory and *missing* lists modules without weight files. + (weights_map, missing) where weights_map maps module name to its + weights directory and missing lists modules without weight files. """ weights_map: dict[str, str] = {} missing: list[str] = [] @@ -277,30 +291,37 @@ def _validate_weight_files( return weights_map, missing -def _disable_offload( - modules_to_update: list[tuple[str, torch.nn.Module]], -) -> list[torch.nn.Module]: - """Temporarily disable layerwise offload so weights are materialised. +def _get_offload_managers(module: torch.nn.Module) -> list: + """Return active offload managers for the given module, if any.""" + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + return [m for m in module.layerwise_offload_managers if m.enabled] + return [] + + +def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: + """Load weights into a module, handling offload-managed parameters. - The offload manager replaces parameters with small placeholders when - layers are offloaded. Disabling offload restores the real tensors so - that ``load_weights_into_model`` can copy into them. 
+ When layerwise offload is active, block-layer parameters are stored as + small placeholders on GPU while the real weights live in consolidated + CPU buffers. This function updates those CPU buffers directly and + falls back to the normal in-place copy for non-offloaded parameters. """ - disabled: list[torch.nn.Module] = [] - for _, module in modules_to_update: - if ( - isinstance(module, OffloadableDiTMixin) - and module.layerwise_offload_managers - ): - module.disable_offload() - disabled.append(module) - return disabled + offload_managers = _get_offload_managers(module) + if offload_managers: + weight_dict = dict(weights_iter) + offloaded_names: set[str] = set() + for manager in offload_managers: + offloaded_names |= manager.update_cpu_weights(weight_dict) + remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) + load_weights_into_model(remaining, dict(module.named_parameters())) + else: + load_weights_into_model(weights_iter, dict(module.named_parameters())) def load_weights_into_model(weights_iter, model_params: dict) -> None: - """Copy weights from *weights_iter* into *model_params* in-place. + """Copy weights from weights_iter into model_params in-place. - Handles ``DTensor`` parameters by re-distributing the loaded weight + Handles DTensor parameters by re-distributing the loaded weight according to the existing device mesh and placements. 
Raises: @@ -316,13 +337,13 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None: ) if DTensor is not None and isinstance(param, DTensor): distributed_weight = distribute_tensor( - loaded_weight.to(param.device, param.dtype), + loaded_weight.to(param.dtype), param.device_mesh, param.placements, ) param._local_tensor.copy_(distributed_weight._local_tensor) else: - param.data.copy_(loaded_weight.to(param.device, param.dtype)) + param.data.copy_(loaded_weight.to(param.dtype)) def _reset_cache_state(module: torch.nn.Module) -> None: diff --git a/python/sglang/multimodal_gen/runtime/managers/io_struct.py b/python/sglang/multimodal_gen/runtime/managers/io_struct.py deleted file mode 100644 index 4bfb3b394441..000000000000 --- a/python/sglang/multimodal_gen/runtime/managers/io_struct.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -I/O data structures for diffusion engine scheduler. -""" - -from dataclasses import dataclass -from typing import List, Optional - - -@dataclass -class UpdateWeightsFromDiskReq: - """Request to update model weights from disk for diffusion models.""" - - model_path: str - flush_cache: bool = True - target_modules: Optional[List[str]] = None diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 9a70b0fc4f91..a8f3011eda84 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -20,8 +20,10 @@ _parse_size, save_image_to_path, ) +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + UpdateWeightsFromDiskReq, +) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker -from sglang.multimodal_gen.runtime.managers.io_struct import UpdateWeightsFromDiskReq from sglang.multimodal_gen.runtime.pipelines_core import Req from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch from sglang.multimodal_gen.runtime.server_args 
import ( diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 8af6ad1a69ed..85b5d0aaf7de 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -276,6 +276,60 @@ def sync_all_layers_to_cpu(self) -> None: for layer_idx in list(self._gpu_layers): self.sync_layer_to_cpu(layer_idx) + @torch.compiler.disable + def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]: + """Update consolidated CPU buffers with new weights. + + For layers currently on GPU, the live GPU parameter is also updated + so the change takes effect immediately. + + Args: + weight_dict: Mapping of parameter name to new weight tensor. + + Returns: + Set of parameter names that were successfully updated. + + Raises: + ValueError: If a weight's shape does not match the recorded + metadata (i.e. the real shape, not the placeholder shape). + """ + if not self.enabled: + return set() + + updated_names: Set[str] = set() + for name, loaded_weight in weight_dict.items(): + layer_idx = self._match_layer_idx(name) + if layer_idx is None: + continue + meta_layer = self._weight_metadata.get(layer_idx) + if meta_layer is None or name not in meta_layer: + continue + + meta = meta_layer[name] + if tuple(meta["shape"]) != tuple(loaded_weight.shape): + raise ValueError( + f"Shape mismatch for {name}: " + f"expected={tuple(meta['shape'])}, " + f"loaded={tuple(loaded_weight.shape)}" + ) + + dtype = meta["dtype"] + offset = meta["offset"] + numel = meta["numel"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + cpu_buffer[offset : offset + numel].copy_( + loaded_weight.to(dtype=dtype).flatten() + ) + + # If this layer is currently on GPU, update the live parameter. 
+ if layer_idx in self._gpu_layers: + target = self.get_target_with_name(name) + target.data.copy_(loaded_weight.to(dtype=target.dtype)) + + updated_names.add(name) + + return updated_names + def register_forward_hooks(self) -> None: if not self.enabled: return From d0e2fecd7b3a2e9d05e2216073b08e91caebdd37 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Sun, 8 Feb 2026 08:04:39 +0000 Subject: [PATCH 05/30] [diffusion] refactor: move post-training API to dedicated package and align naming MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract /update_weights_from_disk endpoint from http_server.py to entrypoints/post_training/weights_api.py - Move UpdateWeightsFromDiskReq to entrypoints/post_training/utils.py and rename to UpdateWeightFromDiskReqInput (matches LLM convention) - Fix doc nit: SGLang-D → SGLang-Diffusion in sglang_for_rl.md --- docs/advanced_features/sglang_for_rl.md | 2 +- .../runtime/entrypoints/http_server.py | 39 +--------------- .../entrypoints/post_training/__init__.py | 0 .../entrypoints/post_training/utils.py | 12 +++++ .../entrypoints/post_training/weights_api.py | 45 +++++++++++++++++++ .../runtime/loader/weights_updater.py | 10 ----- .../runtime/managers/scheduler.py | 6 +-- 7 files changed, 63 insertions(+), 51 deletions(-) create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py diff --git a/docs/advanced_features/sglang_for_rl.md b/docs/advanced_features/sglang_for_rl.md index d72cbaa7bc8e..57b3861d781a 100644 --- a/docs/advanced_features/sglang_for_rl.md +++ b/docs/advanced_features/sglang_for_rl.md @@ -106,7 +106,7 @@ This path trades some I/O overhead for simplicity and flexibility. 
It integrates **Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)` -**Diffusion engine (SGLang-D):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior: +**Diffusion engine (SGLang-Diffusion):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior: - **All-or-nothing with rollback:** if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from the original model path. No partial updates are left behind. If rollback itself fails, the exception propagates so the caller knows the model is in an inconsistent state. - **Offload-aware:** when layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state. 
diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index 7ccfec057dcb..9283a3e5b0ed 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -15,13 +15,11 @@ from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( VertexGenerateReqInput, ) +from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api from sglang.multimodal_gen.runtime.entrypoints.utils import ( prepare_request, save_outputs, ) -from sglang.multimodal_gen.runtime.loader.weights_updater import ( - UpdateWeightsFromDiskReq, -) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args @@ -95,40 +93,6 @@ async def health_generate(): return {"status": "ok"} -@health_router.post("/update_weights_from_disk") -async def update_weights_from_disk(request: Request): - """Update model weights from disk inplace without restarting the server.""" - body = await request.json() - model_path = body.get("model_path") - if not model_path: - return ORJSONResponse( - {"success": False, "message": "model_path is required"}, - status_code=400, - ) - - req = UpdateWeightsFromDiskReq( - model_path=model_path, - flush_cache=body.get("flush_cache", True), - target_modules=body.get("target_modules"), - ) - - try: - response = await async_scheduler_client.forward(req) - except Exception as e: - return ORJSONResponse( - {"success": False, "message": str(e)}, - status_code=500, - ) - - result = response.output - success = result.get("success", False) - message = result.get("message", "Unknown status") - return ORJSONResponse( - {"success": success, "message": message}, - status_code=200 if success else 400, - ) - - def make_serializable(obj): """Recursively converts Tensors to None for JSON 
serialization.""" if isinstance(obj, torch.Tensor): @@ -255,6 +219,7 @@ def create_app(server_args: ServerArgs): app.include_router(common_api.router) app.include_router(image_api.router) app.include_router(video_api.router) + app.include_router(weights_api.router) app.state.server_args = server_args return app diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py new file mode 100644 index 000000000000..cad7c7ce927e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py @@ -0,0 +1,12 @@ +"""Request/response data structures for post-training APIs.""" + +from dataclasses import dataclass + + +@dataclass +class UpdateWeightFromDiskReqInput: + """Request to update model weights from disk for diffusion models.""" + + model_path: str + flush_cache: bool = True + target_modules: list[str] | None = None diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py new file mode 100644 index 000000000000..a26f8b65116c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -0,0 +1,45 @@ +"""Post-training APIs: weight updates and related operations.""" + +from fastapi import APIRouter, Request +from fastapi.responses import ORJSONResponse + +from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import ( + UpdateWeightFromDiskReqInput, +) +from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client + +router = APIRouter() + + +@router.post("/update_weights_from_disk") +async def update_weights_from_disk(request: Request): + """Update 
model weights from disk inplace without restarting the server.""" + body = await request.json() + model_path = body.get("model_path") + if not model_path: + return ORJSONResponse( + {"success": False, "message": "model_path is required"}, + status_code=400, + ) + + req = UpdateWeightFromDiskReqInput( + model_path=model_path, + flush_cache=body.get("flush_cache", True), + target_modules=body.get("target_modules"), + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse( + {"success": False, "message": str(e)}, + status_code=500, + ) + + result = response.output + success = result.get("success", False) + message = result.get("message", "Unknown status") + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else 400, + ) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 0c308c8839ed..dc0201924637 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -43,7 +43,6 @@ import gc import os import time -from dataclasses import dataclass import torch @@ -65,15 +64,6 @@ logger = init_logger(__name__) -@dataclass -class UpdateWeightsFromDiskReq: - """Request to update model weights from disk for diffusion models.""" - - model_path: str - flush_cache: bool = True - target_modules: list[str] | None = None - - class WeightsUpdater: """In-place weight updates for diffusion pipeline modules. 
diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index a8f3011eda84..0de7b85fb1d4 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -20,8 +20,8 @@ _parse_size, save_image_to_path, ) -from sglang.multimodal_gen.runtime.loader.weights_updater import ( - UpdateWeightsFromDiskReq, +from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import ( + UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker from sglang.multimodal_gen.runtime.pipelines_core import Req @@ -92,7 +92,7 @@ def __init__( List[Req]: self._handle_generation, ListLorasReq: self._handle_list_loras, ShutdownReq: self._handle_shutdown, - UpdateWeightsFromDiskReq: self._handle_update_weights_from_disk, + UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk, } # FIFO, new reqs are appended From 07cc6791cd06fa49dcb9a8fccc67412b7bb83a65 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Tue, 10 Feb 2026 05:49:49 +0000 Subject: [PATCH 06/30] [diffusion] refactor: extract get_updatable_modules and harden WeightsUpdater - Extract get_updatable_modules() as a public module-level function that unifies module lookup across ComposedPipelineBase and DiffusersPipeline, replacing the diffusers fallback hack in _collect_modules - Validate unknown target_modules upfront with ValueError and informative error listing available modules - Improve _apply_weights error messages: include failing module name and list of modules being rolled back - Simplify docstrings while preserving key information --- .../runtime/loader/weights_updater.py | 97 ++++++++++--------- .../runtime/managers/scheduler.py | 1 + .../server/test_update_weights_from_disk.py | 18 ++++ 3 files changed, 70 insertions(+), 46 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py 
b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index dc0201924637..e73508405fe4 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -51,6 +51,7 @@ from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, ) +from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -64,6 +65,23 @@ logger = init_logger(__name__) +def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]: + """Return updatable nn.Module components for the given pipeline. + + Works with both the native ComposedPipelineBase backend and the + DiffusersPipeline wrapper. + """ + if isinstance(pipeline, DiffusersPipeline): + diffusers_pipe = pipeline.get_module("diffusers_pipeline") + if diffusers_pipe is not None and diffusers_pipe.components is not None: + raw = diffusers_pipe.components + else: + raw = {} + else: + raw = pipeline.modules + return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)} + + class WeightsUpdater: """In-place weight updates for diffusion pipeline modules. @@ -100,12 +118,17 @@ def update_weights_from_disk( self._original_model_path = original_model_path logger.info(f"Updating weights from disk: {model_path}") - modules_to_update = self._collect_modules(target_modules) + try: + modules_to_update = self._collect_modules(target_modules) + except ValueError as e: + logger.error(str(e)) + return False, str(e) + if not modules_to_update: - available = list(self.pipeline.modules.keys()) error_msg = ( f"No matching modules found for update. " - f"Requested: {target_modules}. 
Available in pipeline: {available}" + f"Requested: {target_modules}. " + f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}" ) logger.error(error_msg) return False, error_msg @@ -154,39 +177,23 @@ def _collect_modules( ) -> list[tuple[str, torch.nn.Module]]: """Resolve target_modules to (name, module) pairs. - For ComposedPipelineBase pipelines, modules are looked up via - pipeline.modules. For DiffusersPipeline (where pipeline.modules - is empty), we fall back to diffusers_pipe.components. + Raises: + ValueError: If target_modules contains names not found in the pipeline. """ - available = self.pipeline.modules.keys() + components = get_updatable_modules(self.pipeline) + if target_modules is None or target_modules == ["all"]: - names = [ - n - for n in available - if isinstance(self.pipeline.get_module(n), torch.nn.Module) - ] + names = list(components.keys()) else: + unknown = [n for n in target_modules if n not in components] + if unknown: + raise ValueError( + f"Module(s) requested for update not found in pipeline: {unknown}. " + f"Available Module(s): {list(components.keys())}" + ) names = target_modules - result: list[tuple[str, torch.nn.Module]] = [] - for name in names: - module = self.pipeline.get_module(name) - if module is not None and isinstance(module, torch.nn.Module): - result.append((name, module)) - - # Fallback for DiffusersPipeline: modules live on the diffusers pipe, - # not in self.pipeline.modules. 
- if not result: - diffusers_pipe = self.pipeline.get_module("diffusers_pipeline") - if diffusers_pipe is not None and hasattr(diffusers_pipe, "components"): - components = diffusers_pipe.components - if target_modules is None or target_modules == ["all"]: - names = list(components.keys()) - for name in names: - module = components.get(name) - if module is not None and isinstance(module, torch.nn.Module): - result.append((name, module)) - return result + return [(name, components[name]) for name in names] def _apply_weights( self, @@ -202,10 +209,17 @@ def _apply_weights( _load_weights_into_module(module, weights_iter) updated_modules.append(module_name) except Exception as e: - error_msg = f"Failed to update {module_name}: {e}. Rolling back." - logger.error(error_msg, exc_info=True) + logger.error( + f"Weight update failed for module '{module_name}': {e}. " + f"Rolling back {len(updated_modules)} already updated module(s): " + f"{updated_modules}.", + exc_info=True, + ) self._rollback(updated_modules) - return False, error_msg + return False, ( + f"Failed to update module '{module_name}': {e}. " + f"All modules rolled back to original weights." + ) names = ", ".join(updated_modules) return True, f"Updated {len(updated_modules)} modules ({names})." @@ -291,10 +305,8 @@ def _get_offload_managers(module: torch.nn.Module) -> list: def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: """Load weights into a module, handling offload-managed parameters. - When layerwise offload is active, block-layer parameters are stored as - small placeholders on GPU while the real weights live in consolidated - CPU buffers. This function updates those CPU buffers directly and - falls back to the normal in-place copy for non-offloaded parameters. + For offloaded modules, updates CPU buffers directly via + update_cpu_weights(); non-offloaded parameters use in-place copy. 
""" offload_managers = _get_offload_managers(module) if offload_managers: @@ -309,14 +321,7 @@ def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: def load_weights_into_model(weights_iter, model_params: dict) -> None: - """Copy weights from weights_iter into model_params in-place. - - Handles DTensor parameters by re-distributing the loaded weight - according to the existing device mesh and placements. - - Raises: - ValueError: On shape mismatch between model parameter and loaded weight. - """ + """Copy weights from weights_iter into model_params in-place.""" for name, loaded_weight in weights_iter: if name not in model_params: continue diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 0de7b85fb1d4..3aa254832f21 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -132,6 +132,7 @@ def _handle_list_loras(self, _reqs: List[Any]) -> OutputBatch: def _handle_shutdown(self, _reqs: List[Any]) -> OutputBatch: self._running = False return OutputBatch() + def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: """Handle update_weights_from_disk request for RL workflows.""" req = reqs[0] diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index cfed4e4f07d9..8bb7bc1c1ac0 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -169,6 +169,24 @@ def test_update_weights_specific_modules( # The test verifies the API handles target_modules parameter assert status_code == 200 + def test_update_weights_nonexistent_module( + self, diffusion_server_for_weight_update: ServerContext + ): + """Test that requesting a non-existent module name fails 
with a clear error.""" + base_url = self._get_base_url(diffusion_server_for_weight_update) + + result, status_code = self._update_weights( + base_url, + DEFAULT_DIFFUSION_MODEL, + target_modules=["nonexistent_module"], + timeout=60, + ) + logger.info(f"Update nonexistent module result: {result}") + + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent module" + assert "not found in pipeline" in result.get("message", "") + class TestUpdateWeightsFromDiskWithOffload: """Test update_weights_from_disk with layerwise offload enabled.""" From c442af250410c54baf1d6256c4801d7b97a142ce Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Wed, 11 Feb 2026 08:30:56 +0000 Subject: [PATCH 07/30] [diffusion] address comments --- docs/advanced_features/sglang_for_rl.md | 2 + .../post_training/{utils.py => io_struct.py} | 0 .../entrypoints/post_training/weights_api.py | 4 +- .../multimodal_gen/runtime/loader/utils.py | 12 +++ .../runtime/loader/weights_updater.py | 88 ++++++------------- .../runtime/managers/gpu_worker.py | 1 - .../runtime/managers/scheduler.py | 2 +- .../runtime/utils/layerwise_offload.py | 18 +++- 8 files changed, 57 insertions(+), 70 deletions(-) rename python/sglang/multimodal_gen/runtime/entrypoints/post_training/{utils.py => io_struct.py} (100%) diff --git a/docs/advanced_features/sglang_for_rl.md b/docs/advanced_features/sglang_for_rl.md index 57b3861d781a..12eb41540339 100644 --- a/docs/advanced_features/sglang_for_rl.md +++ b/docs/advanced_features/sglang_for_rl.md @@ -127,6 +127,8 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates | `success` | Whether the update succeeded. | - | Type: bool | | `message` | Status / error message. | - | Type: str | +> **Note:** The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). 
The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently. + ### Update Weights from Tensor **When to use:** diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py similarity index 100% rename from python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py rename to python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py index a26f8b65116c..c1e53a9c47d2 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -1,9 +1,9 @@ -"""Post-training APIs: weight updates and related operations.""" +"""Weight update API for the diffusion engine.""" from fastapi import APIRouter, Request from fastapi.responses import ORJSONResponse -from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import ( +from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client diff --git a/python/sglang/multimodal_gen/runtime/loader/utils.py b/python/sglang/multimodal_gen/runtime/loader/utils.py index 1b3cefa651eb..725cf32265ab 100644 --- a/python/sglang/multimodal_gen/runtime/loader/utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/utils.py @@ -145,6 +145,18 @@ def _list_safetensors_files(model_path: str) -> list[str]: return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors"))) +def find_weights_dir(local_path: str, module_name: str) -> str | None: + """Locate the safetensors directory for 
module_name under local_path. + + Diffusion models store weights in per-module subdirectories (e.g. + transformer/, vae/, text_encoder/). + """ + dir_path = os.path.join(local_path, module_name) + if os.path.exists(dir_path): + return dir_path + return None + + def get_memory_usage_of_component(module) -> float | None: """ returned value is in GB, rounded to 2 decimal digits diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index e73508405fe4..30e680fae7b0 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -5,19 +5,23 @@ without restarting the server. It is the diffusion-engine counterpart of the LLM engine's ModelRunner.update_weights_from_disk. -Typical usage (from GPUWorker): +Typical usage (from GPUWorker.update_weights_from_disk): updater = WeightsUpdater(self.pipeline) success, message = updater.update_weights_from_disk( model_path, - original_model_path=self.server_args.model_path, + flush_cache=flush_cache, + target_modules=target_modules, ) + if success: + self.server_args.model_path = model_path + return success, message Key design decisions: - All-or-nothing: if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from - original_model_path. No partial updates are left behind. + pipeline.model_path. No partial updates are left behind. - Rollback failures propagate: if rollback itself fails, the exception is not caught so the caller knows the model is in an inconsistent state. 
@@ -41,13 +45,15 @@ from __future__ import annotations import gc -import os -import time import torch +from torch.distributed.tensor import DTensor, distribute_tensor from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin -from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, + find_weights_dir, +) from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, ) @@ -56,12 +62,6 @@ from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger -try: - from torch.distributed.tensor import DTensor, distribute_tensor -except ImportError: - DTensor = None - distribute_tensor = None - logger = init_logger(__name__) @@ -87,7 +87,8 @@ class WeightsUpdater: Args: pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance - whose modules will be updated. + whose modules will be updated. The pipeline's model_path + attribute is used for rollback on failure. """ def __init__(self, pipeline): @@ -96,7 +97,6 @@ def __init__(self, pipeline): def update_weights_from_disk( self, model_path: str, - original_model_path: str, flush_cache: bool = True, target_modules: list[str] | None = None, ) -> tuple[bool, str]: @@ -104,18 +104,14 @@ def update_weights_from_disk( Args: model_path: HF repo id or local path to the new weights. - original_model_path: Path to the currently loaded weights (used - for rollback on failure). flush_cache: If True, reset TeaCache state after a successful update so that stale cached residuals are not reused. target_modules: Explicit list of module names to update. None - or ["all"] updates every nn.Module in the pipeline. + updates every nn.Module in the pipeline. Returns: - (success, message) tuple. + (success, message) tuple where success is True on success. 
""" - tic = time.perf_counter() - self._original_model_path = original_model_path logger.info(f"Updating weights from disk: {model_path}") try: @@ -161,10 +157,9 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - _reset_cache_state(module) + if isinstance(module, TeaCacheMixin): + module.reset_teacache_state() - elapsed = time.perf_counter() - tic - message = f"{message} elapsed={elapsed:.2f}s" logger.info(message) return success, message @@ -182,7 +177,7 @@ def _collect_modules( """ components = get_updatable_modules(self.pipeline) - if target_modules is None or target_modules == ["all"]: + if target_modules is None: names = list(components.keys()) else: unknown = [n for n in target_modules if n not in components] @@ -232,7 +227,7 @@ def _rollback(self, updated_modules: list[str]) -> None: """ if not updated_modules: return - original_path = maybe_download_model(self._original_model_path) + original_path = maybe_download_model(self.pipeline.model_path) for name in updated_modules: module = self.pipeline.get_module(name) if module is None: @@ -249,23 +244,6 @@ def _rollback(self, updated_modules: list[str]) -> None: # --------------------------------------------------------------------------- -def find_weights_dir(local_path: str, module_name: str) -> str | None: - """Locate the safetensors directory for module_name under local_path. - - Diffusion models store weights in per-module subdirectories (e.g. - transformer/, vae/, text_encoder/). This function tries - // first, then falls back to local_path - itself if it directly contains safetensors files (common for RL - checkpoints that save weights in a flat directory). 
- """ - dir_path = os.path.join(local_path, module_name) - if os.path.exists(dir_path): - return dir_path - if _list_safetensors_files(local_path): - return local_path - return None - - def _get_weights_iter(weights_dir: str): """Return a (name, tensor) iterator over safetensors in weights_dir.""" safetensors_files = _list_safetensors_files(weights_dir) @@ -295,25 +273,21 @@ def _validate_weight_files( return weights_map, missing -def _get_offload_managers(module: torch.nn.Module) -> list: - """Return active offload managers for the given module, if any.""" - if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: - return [m for m in module.layerwise_offload_managers if m.enabled] - return [] - - def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: """Load weights into a module, handling offload-managed parameters. For offloaded modules, updates CPU buffers directly via update_cpu_weights(); non-offloaded parameters use in-place copy. """ - offload_managers = _get_offload_managers(module) + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + if offload_managers: weight_dict = dict(weights_iter) offloaded_names: set[str] = set() for manager in offload_managers: - offloaded_names |= manager.update_cpu_weights(weight_dict) + offloaded_names.update(manager.update_cpu_weights(weight_dict)) remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) load_weights_into_model(remaining, dict(module.named_parameters())) else: @@ -330,7 +304,7 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None: raise ValueError( f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" ) - if DTensor is not None and isinstance(param, DTensor): + if isinstance(param, DTensor): distributed_weight = distribute_tensor( 
loaded_weight.to(param.dtype), param.device_mesh, @@ -339,13 +313,3 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None: param._local_tensor.copy_(distributed_weight._local_tensor) else: param.data.copy_(loaded_weight.to(param.dtype)) - - -def _reset_cache_state(module: torch.nn.Module) -> None: - """Reset Cache state after weight updates. - - After weights change, any cached residuals from previous denoising steps - are invalid and must be cleared. - """ - if isinstance(module, TeaCacheMixin): - module.reset_teacache_state() diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 55704dcd546a..75f298f8dbae 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -356,7 +356,6 @@ def update_weights_from_disk( updater = WeightsUpdater(self.pipeline) success, message = updater.update_weights_from_disk( model_path, - original_model_path=self.server_args.model_path, flush_cache=flush_cache, target_modules=target_modules, ) diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 3aa254832f21..304b33d52dd3 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -20,7 +20,7 @@ _parse_size, save_image_to_path, ) -from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import ( +from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 85b5d0aaf7de..7c6508de7526 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ 
b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -277,11 +277,21 @@ def sync_all_layers_to_cpu(self) -> None: self.sync_layer_to_cpu(layer_idx) @torch.compiler.disable - def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]: + def update_cpu_weights( + self, weight_dict: Dict[str, torch.Tensor] + ) -> Set[str] | None: """Update consolidated CPU buffers with new weights. - For layers currently on GPU, the live GPU parameter is also updated - so the change takes effect immediately. + When layerwise offload (--dit-layerwise-offload) is enabled, the + offload manager replaces GPU parameters with small torch.empty((1,)) + placeholders while real weights live in consolidated pinned CPU + buffers. A naive param.data.copy_() would fail with a shape + mismatch. Instead, this method writes new weights directly into + the CPU buffers, bypassing the placeholders entirely. For any + layer that happens to be resident on GPU at update time, the live + GPU tensor is also updated so the change takes effect immediately. + This requires no extra GPU memory and does not disturb the offload + state. Args: weight_dict: Mapping of parameter name to new weight tensor. @@ -294,7 +304,7 @@ def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]: metadata (i.e. the real shape, not the placeholder shape). 
""" if not self.enabled: - return set() + return None updated_names: Set[str] = set() for name, loaded_weight in weight_dict.items(): From 881d8b360e07faa7dae5dd0ce67b91eb895dcb84 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 11 Feb 2026 14:27:33 -0800 Subject: [PATCH 08/30] adds doc string to diffusion rifit test --- .../server/test_update_weights_from_disk.py | 112 +++++++++++++++++- .../test/server/testcase_configs.py | 4 +- 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 8bb7bc1c1ac0..8076e6956396 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -1,8 +1,116 @@ """ Tests for update_weights_from_disk API in SGLang-D (diffusion engine). -This tests the ability to dynamically update model weights without restarting the server, -which is critical for RL workflows and iterative fine-tuning scenarios. +This module verifies the ability to hot update model weights without restarting +the server, which is critical for RL workflows and iterative fine-tuning scenarios. + +Author: + +Menyang Liu, https://github.com/dreamyang-liu +Chenyang Zhao, https://github.com/zhaochenyang20 + +============================================================================= +Test organization: 9 test cases in 3 classes +============================================================================= + +Each test class uses a single long-lived server (pytest fixture with scope="class"). +The server is started once when the first test in that class runs; all tests in the +class share the same server and send multiple POST /update_weights_from_disk +requests to it. This reflects real usage: one running diffusion service, many weight +updates over time. 
+ +Class 1: TestUpdateWeightsFromDisk (7 tests) — API contract & error handling +Class 2: TestUpdateWeightsFromDiskWithOffload (1 test) — Offload-aware update +Class 3: TestUpdateWeightsEndToEnd (1 test) — Generation after update + +----------------------------------------------------------------------------- +Class 1: TestUpdateWeightsFromDisk +----------------------------------------------------------------------------- +Purpose: Validate the update_weights_from_disk API contract, request/response shape, +and error handling. All 7 tests run against one server (fixture: +diffusion_server_for_weight_update). + + • test_update_weights_same_model + Same model path as the one already loaded; must succeed (200, success=True). + Exercises the basic "hot reload same checkpoint" path. + + • test_update_weights_with_flush_cache + Explicit flush_cache=True; must succeed. Ensures the flush_cache parameter + is accepted and applied. + + • test_update_weights_without_flush_cache + Explicit flush_cache=False; must succeed. Ensures updates work when not + flushing TeaCache. + + • test_update_weights_nonexistent_model + model_path set to a non-existent path; must fail (success=False). Verifies + all-or-nothing / rollback semantics when load fails. + + • test_update_weights_missing_model_path + Request body empty (no model_path); must return 400. Validates required + parameter checks. + + • test_update_weights_specific_modules + target_modules=["transformer"]; must return 200. Verifies partial module + update (target_modules parameter). + + • test_update_weights_nonexistent_module + target_modules=["nonexistent_module"]; must return 400 and message containing + "not found in pipeline". Validates rejection of invalid module names. 
+ +----------------------------------------------------------------------------- +Class 2: TestUpdateWeightsFromDiskWithOffload +----------------------------------------------------------------------------- +Purpose: Ensure weight updates work when layerwise offload is enabled +(--dit-layerwise-offload). With offload, parameters live in CPU buffers and +placeholders on GPU; the updater must write into CPU buffers and update +prefetched GPU tensors without shape mismatch. + + • test_update_weights_with_offload_enabled + Server started with --dit-layerwise-offload true. Call update_weights_from_disk + with the same model; must succeed (200, success=True) and message must not + contain "Shape mismatch". + +----------------------------------------------------------------------------- +Class 3: TestUpdateWeightsEndToEnd +----------------------------------------------------------------------------- +Purpose: End-to-end check that the model remains in a consistent, usable state +after a weight update: inference (image generation) works both before and after +the update. + + • test_generation_after_weight_update + (1) Generate an image (e.g. "a beautiful sunset") via /v1/images/generations. + (2) Call POST /update_weights_from_disk (same model, flush_cache=True). + (3) Generate another image (e.g. "a beautiful sunrise"). + Both generations must succeed; this confirms no partial or broken state + after update. + +============================================================================= +Relation to RL scenarios and reference implementation +============================================================================= + +In RL or iterative training, a typical pattern is: + + 1. Run a diffusion (or LLM) server for inference. + 2. Periodically pull new weights (e.g., from a training run or from disk) + without restarting the server. + 3. Continue serving with the updated model. 
+ +The diffusion engine supports this via POST /update_weights_from_disk: it loads +weights from a model_path (HF repo or local) and applies them in-place, with +rollback on failure and support for layerwise offload and DTensor. + +For a distributed RL setup where the training process broadcasts weights to +inference engines (rather than loading from disk), see the SGLang LLM test that +simulates rank 0 as trainer and other ranks as inference engines, using +update_weights_from_distributed and init_weights_update_group: + + https://github.com/sgl-project/sglang/blob/main/test/registered/rl/test_update_weights_from_distributed.py + +That test verifies weight synchronization across ranks (instruct vs base model) +and optional pause_generation/continue_generation during update. This diffusion +test suite focuses on the disk-based update path and offload/consistency +behavior of the diffusion engine only. """ from __future__ import annotations diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index 9379dcc9d5ce..f66b8801cfa0 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -28,6 +28,8 @@ from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord +DEFAULT_SMALL_MODEL = "Tongyi-MAI/Z-Image-Turbo" + @dataclass class ToleranceConfig: @@ -339,8 +341,6 @@ def from_req_perf_record( fps=4, ) -DEFAULT_SMALL_MODEL = "Tongyi-MAI/Z-Image-Turbo" - # All test cases with clean default values # To test different models, simply add more DiffusionCase entries ONE_GPU_CASES_A: list[DiffusionTestCase] = [ From dfc93f6b85fcc5e52e3470b998a02074ec5ad535 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Thu, 12 Feb 2026 09:13:26 +0000 Subject: [PATCH 09/30] [diffusion] Add /get_weights_checksum endpoint for SHA-256 weight 
verification Supports DTensor and layerwise-offloaded modules by reading real weights from CPU buffers instead of GPU placeholders. Simplifies tests from 3 classes to 2 with checksum-based verification. --- .../entrypoints/post_training/io_struct.py | 7 + .../entrypoints/post_training/weights_api.py | 17 + .../runtime/loader/weight_utils.py | 25 +- .../runtime/loader/weights_updater.py | 1 + .../runtime/managers/gpu_worker.py | 33 +- .../runtime/managers/scheduler.py | 8 + .../runtime/utils/layerwise_offload.py | 47 +++ .../server/test_update_weights_from_disk.py | 295 +++++++++++------- 8 files changed, 324 insertions(+), 109 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py index cad7c7ce927e..749a39e817f8 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py @@ -10,3 +10,10 @@ class UpdateWeightFromDiskReqInput: model_path: str flush_cache: bool = True target_modules: list[str] | None = None + + +@dataclass +class GetWeightsChecksumReqInput: + """Request to compute SHA-256 checksum of loaded module weights.""" + + module_names: list[str] | None = None diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py index c1e53a9c47d2..1b9312d8ea0f 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py @@ -4,6 +4,7 @@ from fastapi.responses import ORJSONResponse from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client @@ -43,3 +44,19 
@@ async def update_weights_from_disk(request: Request): {"success": success, "message": message}, status_code=200 if success else 400, ) + + +@router.post("/get_weights_checksum") +async def get_weights_checksum(request: Request): + """Return SHA-256 checksum of each requested module's weights.""" + body = await request.json() + req = GetWeightsChecksumReqInput( + module_names=body.get("module_names"), + ) + + try: + response = await async_scheduler_client.forward(req) + except Exception as e: + return ORJSONResponse({"error": str(e)}, status_code=500) + + return ORJSONResponse(response.output, status_code=200) diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index 89d22b31e84a..9cdfa28676b3 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -2,18 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py -"""Utilities for downloading and initializing model weights.""" +"""Utilities for downloading, loading, and verifying model weights.""" import hashlib import json import os import tempfile -from collections.abc import Generator +from collections.abc import Generator, Iterable from pathlib import Path import filelock import huggingface_hub.constants import torch from safetensors.torch import safe_open +from torch.distributed.tensor import DTensor from tqdm.auto import tqdm try: @@ -335,3 +336,23 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: # If there were no matches, return the untouched param name return name + + +def compute_weights_checksum( + named_params: Iterable[tuple[str, torch.Tensor]], +) -> str: + """Compute SHA-256 checksum over (name, tensor) pairs. 
+ + Parameters are sorted by name so the digest is deterministic + regardless of iteration order. Raw bytes are hashed directly + (no dtype conversion) for speed and fidelity. + """ + hasher = hashlib.sha256() + for name, tensor in sorted(named_params, key=lambda x: x[0]): + hasher.update(name.encode()) + t = tensor.detach() + # DTensor doesn't support .numpy(); extract the local tensor. + if isinstance(t, DTensor): + t = t._local_tensor + hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data) + return hasher.hexdigest() diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 30e680fae7b0..5aae8fe0fcb2 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -15,6 +15,7 @@ ) if success: self.server_args.model_path = model_path + self.pipeline.model_path = model_path return success, message Key design decisions: diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 75f298f8dbae..18818dc1f737 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -29,7 +29,11 @@ get_ulysses_parallel_world_size, ) from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs -from sglang.multimodal_gen.runtime.loader.weights_updater import WeightsUpdater +from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + WeightsUpdater, + get_updatable_modules, +) from sglang.multimodal_gen.runtime.pipelines_core import ( ComposedPipelineBase, LoRAPipeline, @@ -40,7 +44,10 @@ from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import PortArgs, 
ServerArgs from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch -from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.layerwise_offload import ( + OffloadableDiTMixin, + iter_materialized_weights, +) from sglang.multimodal_gen.runtime.utils.logging_utils import ( configure_logger, globally_suppress_loggers, @@ -361,8 +368,30 @@ def update_weights_from_disk( ) if success: self.server_args.model_path = model_path + self.pipeline.model_path = model_path return success, message + def get_weights_checksum( + self, module_names: list[str] | None = None + ) -> dict[str, str]: + """Compute SHA-256 checksum of each module's weights.""" + if not self.pipeline: + return {"error": "Pipeline is not initialized"} + + all_modules = get_updatable_modules(self.pipeline) + names = module_names if module_names is not None else list(all_modules.keys()) + + checksums: dict[str, str] = {} + for name in names: + module = all_modules.get(name) + if module is None: + checksums[name] = "not_found" + continue + checksums[name] = compute_weights_checksum( + iter_materialized_weights(module) + ) + return checksums + OOM_MSG = f""" OOM detected. 
Possible solutions: diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 304b33d52dd3..b05c543165d5 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -21,6 +21,7 @@ save_image_to_path, ) from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + GetWeightsChecksumReqInput, UpdateWeightFromDiskReqInput, ) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker @@ -93,6 +94,7 @@ def __init__( ListLorasReq: self._handle_list_loras, ShutdownReq: self._handle_shutdown, UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk, + GetWeightsChecksumReqInput: self._handle_get_weights_checksum, } # FIFO, new reqs are appended @@ -146,6 +148,12 @@ def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch: error=None if success else message, ) + def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch: + """Handle get_weights_checksum request.""" + req = reqs[0] + checksums = self.worker.get_weights_checksum(module_names=req.module_names) + return OutputBatch(output=checksums) + def _handle_generation(self, reqs: List[Req]): warmup_reqs = [req for req in reqs if req.is_warmup] if warmup_reqs: diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 7c6508de7526..089cc608ab69 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -340,6 +340,24 @@ def update_cpu_weights( return updated_names + def iter_cpu_weights(self): + """Yield (name, tensor) pairs from consolidated CPU buffers. + + This reconstructs the original weight tensors (with correct shapes) + from the flat CPU buffers using stored metadata. 
Unlike + model.named_parameters(), which returns (1,) placeholders + when offload is enabled, this method returns the real weights and + can be used for checksum computation. + """ + for layer_idx in sorted(self._weight_metadata): + for name, meta in self._weight_metadata[layer_idx].items(): + dtype = meta["dtype"] + offset = meta["offset"] + numel = meta["numel"] + shape = meta["shape"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + yield name, cpu_buffer[offset : offset + numel].reshape(shape) + def register_forward_hooks(self) -> None: if not self.enabled: return @@ -447,3 +465,32 @@ def enable_offload(self) -> None: manager.sync_all_layers_to_cpu() manager.release_all() manager.register_forward_hooks() + + +def iter_materialized_weights(module: torch.nn.Module): + """Yield (name, tensor) pairs with materialized weights, even under offload. + + When layerwise offload is active, module.named_parameters() returns + (1,) placeholders for offloaded layers. This helper reads the + actual data from the offload manager's CPU buffers and chains it with + the non-offloaded parameters so callers always see real tensors. + """ + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + + if not offload_managers: + yield from module.named_parameters() + return + + # Collect offloaded names and their real tensors from CPU buffers. + offloaded_names: set[str] = set() + for manager in offload_managers: + for name, tensor in manager.iter_cpu_weights(): + offloaded_names.add(name) + yield name, tensor + + # Yield non-offloaded parameters (e.g. final norms, embeddings). 
+ for name, param in module.named_parameters(): + if name not in offloaded_names: + yield name, param diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 8076e6956396..5ce7b77c7673 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -10,7 +10,7 @@ Chenyang Zhao, https://github.com/zhaochenyang20 ============================================================================= -Test organization: 9 test cases in 3 classes +Test organization: 9 test cases in 2 classes ============================================================================= Each test class uses a single long-lived server (pytest fixture with scope="class"). @@ -19,20 +19,15 @@ class share the same server and send multiple POST /update_weights_from_disk requests to it. This reflects real usage: one running diffusion service, many weight updates over time. -Class 1: TestUpdateWeightsFromDisk (7 tests) — API contract & error handling -Class 2: TestUpdateWeightsFromDiskWithOffload (1 test) — Offload-aware update -Class 3: TestUpdateWeightsEndToEnd (1 test) — Generation after update +Class 1: TestUpdateWeightsFromDisk (7 tests) — API contract & checksum +Class 2: TestUpdateWeightsFromDiskWithOffload (2 tests) — Offload-aware update ----------------------------------------------------------------------------- Class 1: TestUpdateWeightsFromDisk ----------------------------------------------------------------------------- Purpose: Validate the update_weights_from_disk API contract, request/response shape, -and error handling. All 7 tests run against one server (fixture: -diffusion_server_for_weight_update). - - • test_update_weights_same_model - Same model path as the one already loaded; must succeed (200, success=True). - Exercises the basic "hot reload same checkpoint" path. 
+error handling, and checksum verification. All 7 tests run against one server +(fixture: diffusion_server_for_weight_update). • test_update_weights_with_flush_cache Explicit flush_cache=True; must succeed. Ensures the flush_cache parameter @@ -58,32 +53,29 @@ class share the same server and send multiple POST /update_weights_from_disk target_modules=["nonexistent_module"]; must return 400 and message containing "not found in pipeline". Validates rejection of invalid module names. + • test_update_weights_checksum_matches + Fetches checksum before update (base model), then updates weights and fetches + checksum again (update model). Verifies the post-update checksum matches the + update model's disk checksum, and differs from the pre-update checksum. + ----------------------------------------------------------------------------- Class 2: TestUpdateWeightsFromDiskWithOffload ----------------------------------------------------------------------------- -Purpose: Ensure weight updates work when layerwise offload is enabled -(--dit-layerwise-offload). With offload, parameters live in CPU buffers and -placeholders on GPU; the updater must write into CPU buffers and update -prefetched GPU tensors without shape mismatch. +Purpose: Ensure weight updates and checksum verification work when layerwise +offload is enabled (--dit-layerwise-offload). With offload, parameters live in +CPU buffers and placeholders on GPU; the updater must write into CPU buffers and +update prefetched GPU tensors without shape mismatch. The checksum endpoint must +read from CPU buffers (not the (1,) placeholders) to produce correct results. • test_update_weights_with_offload_enabled Server started with --dit-layerwise-offload true. Call update_weights_from_disk with the same model; must succeed (200, success=True) and message must not contain "Shape mismatch". 
------------------------------------------------------------------------------ -Class 3: TestUpdateWeightsEndToEnd ------------------------------------------------------------------------------ -Purpose: End-to-end check that the model remains in a consistent, usable state -after a weight update: inference (image generation) works both before and after -the update. - - • test_generation_after_weight_update - (1) Generate an image (e.g. "a beautiful sunset") via /v1/images/generations. - (2) Call POST /update_weights_from_disk (same model, flush_cache=True). - (3) Generate another image (e.g. "a beautiful sunrise"). - Both generations must succeed; this confirms no partial or broken state - after update. + • test_update_weights_checksum_matches + Fetches checksum before update (base model), then updates weights and fetches + checksum again (update model). Verifies the post-update checksum matches the + update model's disk checksum, and differs from the pre-update checksum. ============================================================================= Relation to RL scenarios and reference implementation @@ -120,6 +112,15 @@ class share the same server and send multiple POST /update_weights_from_disk import pytest import requests +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, + find_weights_dir, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + compute_weights_checksum, + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger from sglang.multimodal_gen.test.server.test_server_utils import ( ServerContext, @@ -129,11 +130,32 @@ class share the same server and send multiple POST /update_weights_from_disk logger = init_logger(__name__) -# Default model for testing - use a small/fast model, need to be an image diffusion model +# Base model the server starts with 
DEFAULT_DIFFUSION_MODEL = os.environ.get( - "SGLANG_TEST_DIFFUSION_MODEL", "black-forest-labs/FLUX.2-klein-4B" + "SGLANG_TEST_DIFFUSION_MODEL", "black-forest-labs/FLUX.2-klein-base-4B" ) +# Model used for weight updates (same architecture, different weights) +UPDATE_DIFFUSION_MODEL = os.environ.get( + "SGLANG_TEST_UPDATE_MODEL", "black-forest-labs/FLUX.2-klein-4B" +) + + +def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: + """Compute SHA-256 checksum from safetensors files on disk. + + Uses the same compute_weights_checksum function as the server, + so the checksums are directly comparable. + """ + local_path = maybe_download_model(model_path) + weights_dir = find_weights_dir(local_path, module_name) + assert weights_dir is not None, f"No weights dir for {module_name} in {local_path}" + + safetensors_files = _list_safetensors_files(weights_dir) + assert safetensors_files, f"No safetensors files in {weights_dir}" + + return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) + @pytest.fixture(scope="class") def diffusion_server_for_weight_update(): @@ -185,17 +207,26 @@ def _update_weights( ) return response.json(), response.status_code - def test_update_weights_same_model( - self, diffusion_server_for_weight_update: ServerContext - ): - """Test updating weights with the same model (should succeed).""" - base_url = self._get_base_url(diffusion_server_for_weight_update) - - result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) - logger.info(f"Update result: {result}") + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + """Call get_weights_checksum API and return the checksum dict.""" + payload = {} + if module_names is not None: + payload["module_names"] = module_names - assert status_code == 200, f"Expected 200, got {status_code}" - assert result.get("success", False), f"Update failed: {result.get('message')}" + 
response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, + ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() def test_update_weights_with_flush_cache( self, diffusion_server_for_weight_update: ServerContext @@ -205,7 +236,7 @@ def test_update_weights_with_flush_cache( result, status_code = self._update_weights( base_url, - DEFAULT_DIFFUSION_MODEL, + UPDATE_DIFFUSION_MODEL, flush_cache=True, ) @@ -220,7 +251,7 @@ def test_update_weights_without_flush_cache( result, status_code = self._update_weights( base_url, - DEFAULT_DIFFUSION_MODEL, + UPDATE_DIFFUSION_MODEL, flush_cache=False, ) @@ -267,7 +298,7 @@ def test_update_weights_specific_modules( # Try to update only transformer module result, status_code = self._update_weights( base_url, - DEFAULT_DIFFUSION_MODEL, + UPDATE_DIFFUSION_MODEL, target_modules=["transformer"], ) logger.info(f"Update specific modules result: {result}") @@ -285,7 +316,7 @@ def test_update_weights_nonexistent_module( result, status_code = self._update_weights( base_url, - DEFAULT_DIFFUSION_MODEL, + UPDATE_DIFFUSION_MODEL, target_modules=["nonexistent_module"], timeout=60, ) @@ -295,6 +326,58 @@ def test_update_weights_nonexistent_module( assert not result.get("success", True), "Should fail for nonexistent module" assert "not found in pipeline" in result.get("message", "") + def test_update_weights_checksum_matches( + self, diffusion_server_for_weight_update: ServerContext + ): + """Verify GPU checksum matches disk after weight update. + + 1. Fetch the pre-update (base model) checksum from the server. + 2. Update weights to a different model. + 3. Fetch the post-update checksum and compare with disk. + 4. Verify post-update checksum differs from pre-update (different model). + """ + base_url = self._get_base_url(diffusion_server_for_weight_update) + + # Update to base model. 
+ result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) + + # Checksum before update (base model already loaded by the fixture). + pre_update_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + + # Update to a different model. + result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + assert status_code == 200 and result.get( + "success" + ), f"Update failed: {result.get('message')}" + + # Checksum after update — must match the update model on disk. + post_update_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + update_disk_checksum = _compute_checksum_from_disk( + UPDATE_DIFFUSION_MODEL, "transformer" + ) + + print(f"\n{'='*60}") + print(f"Checksum test") + print(f" pre-update (base): {pre_update_checksum}") + print(f" post-update (gpu): {post_update_checksum}") + print(f" post-update (disk): {update_disk_checksum}") + print(f" gpu == disk: {post_update_checksum == update_disk_checksum}") + print(f" changed: {pre_update_checksum != post_update_checksum}") + print(f"{'='*60}") + + assert post_update_checksum == update_disk_checksum, ( + f"GPU checksum does not match disk checksum for update model\n" + f" disk: {update_disk_checksum}\n" + f" gpu: {post_update_checksum}" + ) + assert ( + pre_update_checksum != post_update_checksum + ), "Checksum did not change after updating to a different model" + class TestUpdateWeightsFromDiskWithOffload: """Test update_weights_from_disk with layerwise offload enabled.""" @@ -341,7 +424,7 @@ def test_update_weights_with_offload_enabled( logger.info("Testing weight update with offload enabled") - result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) logger.info(f"Update result: {result}") assert status_code == 200, f"Expected 200, got {status_code}" @@ -351,76 +434,78 @@ def 
test_update_weights_with_offload_enabled( message = result.get("message", "") assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + """Call get_weights_checksum API and return the checksum dict.""" + payload = {} + if module_names is not None: + payload["module_names"] = module_names -class TestUpdateWeightsEndToEnd: - """End-to-end tests: verify generation works after weight update.""" - - @pytest.fixture(scope="class") - def diffusion_server_e2e(self): - """Start a diffusion server for E2E tests.""" - port = get_dynamic_server_port() - wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) - - manager = ServerManager( - model=DEFAULT_DIFFUSION_MODEL, - port=port, - wait_deadline=wait_deadline, - extra_args="--num-gpus 1", + response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() - ctx = manager.start() - - try: - yield ctx - finally: - ctx.cleanup() - - def _get_base_url(self, ctx: ServerContext) -> str: - return f"http://localhost:{ctx.port}" + def test_update_weights_checksum_matches( + self, diffusion_server_with_offload: ServerContext + ): + """Verify checksum from offloaded CPU buffers matches disk after update. - def _generate_image(self, base_url: str, prompt: str = "a cat") -> dict: - """Generate an image using the OpenAI-compatible API.""" - from openai import OpenAI + 1. Fetch the pre-update (base model) checksum from the server. + 2. Update weights to a different model. + 3. Fetch the post-update checksum and compare with disk. + 4. Verify post-update checksum differs from pre-update (different model). 
+ """ + base_url = self._get_base_url(diffusion_server_with_offload) - client = OpenAI( - api_key="sglang-test", - base_url=f"{base_url}/v1", - ) + # Update to base model. + result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) - response = client.images.generate( - model="default", - prompt=prompt, - n=1, - size="512x512", - response_format="b64_json", # Avoid needing cloud storage + # Checksum before update (base model already loaded by the fixture). + pre_update_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + + # Update to a different model. + result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + assert status_code == 200 and result.get( + "success" + ), f"Update failed: {result.get('message')}" + + # Checksum after update — must match the update model on disk. + post_update_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + update_disk_checksum = _compute_checksum_from_disk( + UPDATE_DIFFUSION_MODEL, "transformer" ) - return response - - def test_generation_after_weight_update(self, diffusion_server_e2e: ServerContext): - """Test that generation still works after updating weights.""" - base_url = self._get_base_url(diffusion_server_e2e) - - # Generate before update - logger.info("Generating image before weight update...") - response_before = self._generate_image(base_url, "a beautiful sunset") - assert response_before.data, "Generation before update failed" - logger.info("Generation before update succeeded") - - # Update weights - update_response = requests.post( - f"{base_url}/update_weights_from_disk", - json={"model_path": DEFAULT_DIFFUSION_MODEL, "flush_cache": True}, - timeout=300, + print(f"\n{'='*60}") + print(f"Offload checksum test") + print(f" pre-update (base): {pre_update_checksum}") + print(f" post-update (gpu): {post_update_checksum}") + print(f" post-update (disk): {update_disk_checksum}") + 
print(f" gpu == disk: {post_update_checksum == update_disk_checksum}") + print(f" changed: {pre_update_checksum != post_update_checksum}") + print(f"{'='*60}") + + assert post_update_checksum == update_disk_checksum, ( + f"GPU checksum does not match disk checksum for update model\n" + f" disk: {update_disk_checksum}\n" + f" gpu: {post_update_checksum}" ) - assert update_response.json().get("success"), "Weight update failed" - logger.info("Weight update succeeded") - - # Generate after update - logger.info("Generating image after weight update...") - response_after = self._generate_image(base_url, "a beautiful sunrise") - assert response_after.data, "Generation after update failed" - logger.info("Generation after update succeeded") + assert ( + pre_update_checksum != post_update_checksum + ), "Checksum did not change after updating to a different model" if __name__ == "__main__": From 68902f3373d05c4e18747e83fe4bab4f9741c3c5 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Thu, 12 Feb 2026 20:16:45 +0000 Subject: [PATCH 10/30] [diffusion] Add corrupted-weight rollback test for update_weights_from_disk --- .../server/test_update_weights_from_disk.py | 245 +++++++++++++++++- 1 file changed, 244 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 5ce7b77c7673..48bf790ddb33 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -10,7 +10,7 @@ Chenyang Zhao, https://github.com/zhaochenyang20 ============================================================================= -Test organization: 9 test cases in 2 classes +Test organization: 10 test cases in 3 classes ============================================================================= Each test class uses a single long-lived server (pytest fixture with 
scope="class"). @@ -21,6 +21,7 @@ class share the same server and send multiple POST /update_weights_from_disk Class 1: TestUpdateWeightsFromDisk (7 tests) — API contract & checksum Class 2: TestUpdateWeightsFromDiskWithOffload (2 tests) — Offload-aware update +Class 3: TestUpdateWeightsCorruptedRollback (1 test) — Corrupted weight rollback ----------------------------------------------------------------------------- Class 1: TestUpdateWeightsFromDisk @@ -77,6 +78,25 @@ class share the same server and send multiple POST /update_weights_from_disk checksum again (update model). Verifies the post-update checksum matches the update model's disk checksum, and differs from the pre-update checksum. +----------------------------------------------------------------------------- +Class 3: TestUpdateWeightsCorruptedRollback +----------------------------------------------------------------------------- +Purpose: Verify all-or-nothing rollback semantics when loading corrupted weights. +The test builds a corrupted model directory by copying the base model and +truncating the vae safetensors. It then requests an update with +target_modules=["transformer", "vae"]. The transformer updates successfully +first; the corrupted vae then fails during safetensors validation, triggering a +rollback that restores the transformer to its previous weights. + + • test_corrupted_weights_rollback + 1. Server starts with the base model (DEFAULT_DIFFUSION_MODEL). + 2. Updates weights to the update model (UPDATE_DIFFUSION_MODEL); must succeed. + 3. Records checksums for all modules after the update. + 4. Prepares a corrupted model (base model copy with truncated vae). + 5. Attempts to load the corrupted model; must fail with rollback message. + 6. Verifies all module checksums still match the update model, confirming + the rollback restored every module that was partially updated. 
+ ============================================================================= Relation to RL scenarios and reference implementation ============================================================================= @@ -108,6 +128,8 @@ class share the same server and send multiple POST /update_weights_from_disk from __future__ import annotations import os +import shutil +import tempfile import pytest import requests @@ -508,5 +530,226 @@ def test_update_weights_checksum_matches( ), "Checksum did not change after updating to a different model" +def _prepare_corrupted_model( + src_model: str, dst_model: str, corrupt_module: str +) -> None: + """Build a corrupted model directory from src_model. + + The root-level files (model_index.json, config.json, …) are copied so + that maybe_download_model recognises dst_model as a valid local + model. Every module sub-directory that contains safetensors is copied + verbatim, except corrupt_module whose safetensors are truncated so + that safetensors_weights_iterator detects corruption at load time + (after earlier modules have already been updated), triggering a rollback. + + Must be called before every test attempt because the server deletes + corrupted files on detection. + """ + # Copy root-level files (model_index.json, etc.) so the server + # recognises the directory as a valid local model. + for fname in os.listdir(src_model): + src_path = os.path.join(src_model, fname) + if os.path.isfile(src_path): + shutil.copy2(src_path, os.path.join(dst_model, fname)) + + # Copy module sub-directories; corrupt the designated one. + for module_dir in sorted(os.listdir(src_model)): + src_dir = os.path.join(src_model, module_dir) + if not os.path.isdir(src_dir): + continue + safetensors = [f for f in os.listdir(src_dir) if f.endswith(".safetensors")] + if not safetensors: + # Still copy non-safetensors dirs (config.json, tokenizer, etc.) 
+ dst_dir = os.path.join(dst_model, module_dir) + if not os.path.exists(dst_dir): + shutil.copytree(src_dir, dst_dir) + continue + + dst_dir = os.path.join(dst_model, module_dir) + os.makedirs(dst_dir, exist_ok=True) + + # Copy config.json and other non-safetensors files in the module dir + for f in os.listdir(src_dir): + if not f.endswith(".safetensors"): + src_f = os.path.join(src_dir, f) + if os.path.isfile(src_f): + shutil.copy2(src_f, os.path.join(dst_dir, f)) + + for fname in safetensors: + src_file = os.path.join(src_dir, fname) + dst_file = os.path.join(dst_dir, fname) + shutil.copy2(src_file, dst_file) + + if module_dir == corrupt_module: + # Truncate 1000 bytes from the end to corrupt the tensor data + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 1000) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 1000, + ) + else: + logger.info("Copied valid safetensors: %s", dst_file) + + +class TestUpdateWeightsCorruptedRollback: + """Test that loading corrupted weights triggers rollback to the last good model.""" + + @pytest.fixture(scope="class") + def corrupted_model_dir(self): + """Create a temporary directory for the corrupted model.""" + tmpdir = tempfile.mkdtemp(prefix="sglang_corrupted_model_") + yield tmpdir + shutil.rmtree(tmpdir, ignore_errors=True) + + @pytest.fixture(scope="class") + def diffusion_server_for_rollback(self): + """Start a diffusion server with the base model.""" + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=DEFAULT_DIFFUSION_MODEL, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + ctx = manager.start() + try: + yield ctx + finally: + ctx.cleanup() + + def _get_base_url(self, ctx: ServerContext) -> str: + return f"http://localhost:{ctx.port}" + + def _update_weights( + self, + base_url: str, + model_path: str, + 
flush_cache: bool = True, + target_modules: list[str] | None = None, + timeout: int = 300, + ) -> tuple[dict, int]: + payload = { + "model_path": model_path, + "flush_cache": flush_cache, + } + if target_modules is not None: + payload["target_modules"] = target_modules + + response = requests.post( + f"{base_url}/update_weights_from_disk", + json=payload, + timeout=timeout, + ) + return response.json(), response.status_code + + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + payload = {} + if module_names is not None: + payload["module_names"] = module_names + + response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, + ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() + + def test_corrupted_weights_rollback( + self, + diffusion_server_for_rollback: ServerContext, + corrupted_model_dir: str, + ): + """Load base → update weights → attempt corrupted → verify rollback. + + Checksums are verified for ALL modules, not just the transformer, + to ensure the entire pipeline is consistent after rollback. 
+ """ + base_url = self._get_base_url(diffusion_server_for_rollback) + + # --- Step 1: Get base-model checksums for all modules --- + base_checksums = self._get_weights_checksum(base_url) + logger.info(f"Base model checksums: {base_checksums}") + + # --- Step 2: Update to the update model --- + result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + assert status_code == 200 + assert result.get( + "success", False + ), f"Weight update failed: {result.get('message')}" + + # --- Step 3: Record update-model checksums for all modules --- + update_checksums = self._get_weights_checksum(base_url) + logger.info(f"Update model checksums: {update_checksums}") + + assert ( + update_checksums != base_checksums + ), "Base and update checksums should differ" + + # --- Step 4: Recreate corrupted model, then attempt load --- + # Copy all modules from the base model (valid), but corrupt only the + # vae. With target_modules=["transformer", "vae"], the transformer + # updates successfully first, then vae fails, giving a meaningful + # rollback that actually restores the transformer. 
+ local_base = maybe_download_model(DEFAULT_DIFFUSION_MODEL) + _prepare_corrupted_model(local_base, corrupted_model_dir, corrupt_module="vae") + + result, status_code = self._update_weights( + base_url, + corrupted_model_dir, + target_modules=["transformer", "vae"], + timeout=120, + ) + logger.info(f"Corrupted update result: status={status_code}, body={result}") + + assert not result.get("success", True), "Loading corrupted weights should fail" + assert ( + "rolled back" in result.get("message", "").lower() + ), f"Expected rollback message, got: {result.get('message')}" + + # --- Step 5: Verify rollback — all module checksums must match update model --- + post_rollback_checksums = self._get_weights_checksum(base_url) + logger.info(f"Post-rollback checksums: {post_rollback_checksums}") + + print(f"\n{'='*80}") + print("Corrupted-weight rollback test (all modules)") + for module in sorted(update_checksums.keys()): + update_cs = update_checksums.get(module, "N/A") + rollback_cs = post_rollback_checksums.get(module, "N/A") + base_cs = base_checksums.get(module, "N/A") + match = "OK" if update_cs == rollback_cs else "MISMATCH" + print(f" [{match}] {module}") + print(f" base: {base_cs}") + print(f" update: {update_cs}") + print(f" rollback: {rollback_cs}") + print(f"{'='*80}") + + for module in update_checksums: + assert post_rollback_checksums.get(module) == update_checksums[module], ( + f"Module '{module}' checksum mismatch after rollback\n" + f" update: {update_checksums[module]}\n" + f" post-rollback: {post_rollback_checksums.get(module)}" + ) + + assert post_rollback_checksums != base_checksums, ( + "Post-rollback checksums should not match base model " + "(rollback target is update model, not base)" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) From 719d31f3d71668371d65c52bf5348690b8434b2d Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Fri, 13 Feb 2026 09:07:54 +0000 Subject: [PATCH 11/30] [diffusion] Parametrize weight-update tests over 
FLUX and Qwen model pairs - Replace single DEFAULT_DIFFUSION_MODEL / UPDATE_DIFFUSION_MODEL constants with _ALL_MODEL_PAIRS list containing FLUX klein and Qwen Image pairs. - In CI, weighted random selection picks one pair (FLUX 5:1 vs Qwen) to save resources; locally both pairs run. (FLUX took 2-3 minutes, Qwen took 10 minutes to finish its own suite) - Parametrize all three test class fixtures so each model pair gets its own server and isolated corrupted_model_dir (prevents cross-pair symlink contamination). - Use symlinks instead of full copies in _prepare_corrupted_model to save disk space; only the corrupted module's safetensors are physically copied. - Improve docstrings for compute_weights_checksum and WeightsUpdater rollback logic; --- .../runtime/loader/weight_utils.py | 7 +- .../runtime/loader/weights_updater.py | 10 +- .../server/test_update_weights_from_disk.py | 255 ++++++++++-------- 3 files changed, 156 insertions(+), 116 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index 9cdfa28676b3..6b7f24bafdbe 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -341,7 +341,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: def compute_weights_checksum( named_params: Iterable[tuple[str, torch.Tensor]], ) -> str: - """Compute SHA-256 checksum over (name, tensor) pairs. + """Compute a SHA-256 checksum for a set of (name, tensor) pairs. + + Helper function for verifying the correctness of weight refitting + (update_weights_from_disk). After a refit, callers can compare the + checksum of the in-GPU model weights against the checksum of the + on-disk tensors to confirm they match exactly. Parameters are sorted by name so the digest is deterministic regardless of iteration order. 
Raw bytes are hashed directly diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 5aae8fe0fcb2..4834523e3d4e 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -20,9 +20,13 @@ Key design decisions: -- All-or-nothing: if any module fails to load, all previously updated - modules are rolled back to the original weights by reloading from - pipeline.model_path. No partial updates are left behind. +- All-or-nothing with rollback: modules are updated sequentially. If + any module fails (shape mismatch, corrupted file, etc.), every module + that was already updated is rolled back by reloading its weights from + pipeline.model_path (the last successfully-loaded checkpoint). On + success, pipeline.model_path is updated to the new model_path so + that future rollbacks target the latest good checkpoint, not the + originally-launched model. - Rollback failures propagate: if rollback itself fails, the exception is not caught so the caller knows the model is in an inconsistent state. 
diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 48bf790ddb33..f4e63b3790ac 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -128,6 +128,7 @@ class share the same server and send multiple POST /update_weights_from_disk from __future__ import annotations import os +import random import shutil import tempfile @@ -148,19 +149,46 @@ class share the same server and send multiple POST /update_weights_from_disk ServerContext, ServerManager, ) -from sglang.multimodal_gen.test.test_utils import get_dynamic_server_port +from sglang.multimodal_gen.test.test_utils import get_dynamic_server_port, is_in_ci logger = init_logger(__name__) -# Base model the server starts with -DEFAULT_DIFFUSION_MODEL = os.environ.get( - "SGLANG_TEST_DIFFUSION_MODEL", "black-forest-labs/FLUX.2-klein-base-4B" -) +# Model pairs for weight update tests: (default_model, update_model, ci_weight). +# The server starts with default_model; tests update weights to update_model. +# ci_weight controls how likely each pair is to be selected in CI runs. +_ALL_MODEL_PAIRS: list[tuple[str, str, float]] = [ + ( + "black-forest-labs/FLUX.2-klein-base-4B", + "black-forest-labs/FLUX.2-klein-4B", + 5.0, + ), + ( + "Qwen/Qwen-Image", + "Qwen/Qwen-Image-2512", + 1.0, # Qwen Image is large; run it less often in CI. + ), +] + + +def _select_model_pairs() -> list[tuple[str, str]]: + """Return the (default, update) model pairs to test. + + When SGLANG_TEST_DIFFUSION_MODEL / SGLANG_TEST_UPDATE_MODEL env vars + are set, use them as a single explicit pair. Otherwise, run both + pairs locally, or randomly pick one in CI (weighted) to save resources. 
+ """ + default_env = os.environ.get("SGLANG_TEST_DIFFUSION_MODEL") + update_env = os.environ.get("SGLANG_TEST_UPDATE_MODEL") + if default_env and update_env: + return [(default_env, update_env)] + pairs = [(d, u) for d, u, _ in _ALL_MODEL_PAIRS] + if is_in_ci(): + weights = [w for _, _, w in _ALL_MODEL_PAIRS] + return random.choices(pairs, weights=weights, k=1) + return pairs -# Model used for weight updates (same architecture, different weights) -UPDATE_DIFFUSION_MODEL = os.environ.get( - "SGLANG_TEST_UPDATE_MODEL", "black-forest-labs/FLUX.2-klein-4B" -) + +_ACTIVE_MODEL_PAIRS = _select_model_pairs() def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: @@ -179,14 +207,19 @@ def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) -@pytest.fixture(scope="class") -def diffusion_server_for_weight_update(): +@pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], +) +def diffusion_server_for_weight_update(request): """Start a diffusion server for weight update tests.""" + default_model, update_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) manager = ServerManager( - model=DEFAULT_DIFFUSION_MODEL, + model=default_model, port=port, wait_deadline=wait_deadline, extra_args="--num-gpus 1", @@ -195,7 +228,7 @@ def diffusion_server_for_weight_update(): ctx = manager.start() try: - yield ctx + yield ctx, default_model, update_model finally: ctx.cleanup() @@ -250,15 +283,14 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def test_update_weights_with_flush_cache( - self, diffusion_server_for_weight_update: ServerContext - ): + def test_update_weights_with_flush_cache(self, diffusion_server_for_weight_update): """Test updating weights 
with flush_cache=True.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) result, status_code = self._update_weights( base_url, - UPDATE_DIFFUSION_MODEL, + update_model, flush_cache=True, ) @@ -266,25 +298,25 @@ def test_update_weights_with_flush_cache( assert result.get("success", False), f"Update failed: {result.get('message')}" def test_update_weights_without_flush_cache( - self, diffusion_server_for_weight_update: ServerContext + self, diffusion_server_for_weight_update ): """Test updating weights with flush_cache=False.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) result, status_code = self._update_weights( base_url, - UPDATE_DIFFUSION_MODEL, + update_model, flush_cache=False, ) assert status_code == 200 assert result.get("success", False), f"Update failed: {result.get('message')}" - def test_update_weights_nonexistent_model( - self, diffusion_server_for_weight_update: ServerContext - ): + def test_update_weights_nonexistent_model(self, diffusion_server_for_weight_update): """Test that updating with non-existent model fails gracefully.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, _update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) result, status_code = self._update_weights( base_url, @@ -297,10 +329,11 @@ def test_update_weights_nonexistent_model( assert not result.get("success", True), "Should fail for nonexistent model" def test_update_weights_missing_model_path( - self, diffusion_server_for_weight_update: ServerContext + self, diffusion_server_for_weight_update ): """Test that request without model_path returns 400.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, 
_update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) response = requests.post( f"{base_url}/update_weights_from_disk", @@ -311,16 +344,15 @@ def test_update_weights_missing_model_path( # Should return 400 Bad Request assert response.status_code == 400, f"Expected 400, got {response.status_code}" - def test_update_weights_specific_modules( - self, diffusion_server_for_weight_update: ServerContext - ): + def test_update_weights_specific_modules(self, diffusion_server_for_weight_update): """Test updating only specific modules (e.g., transformer only).""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) # Try to update only transformer module result, status_code = self._update_weights( base_url, - UPDATE_DIFFUSION_MODEL, + update_model, target_modules=["transformer"], ) logger.info(f"Update specific modules result: {result}") @@ -331,14 +363,15 @@ def test_update_weights_specific_modules( assert status_code == 200 def test_update_weights_nonexistent_module( - self, diffusion_server_for_weight_update: ServerContext + self, diffusion_server_for_weight_update ): """Test that requesting a non-existent module name fails with a clear error.""" - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, _default_model, update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) result, status_code = self._update_weights( base_url, - UPDATE_DIFFUSION_MODEL, + update_model, target_modules=["nonexistent_module"], timeout=60, ) @@ -348,9 +381,7 @@ def test_update_weights_nonexistent_module( assert not result.get("success", True), "Should fail for nonexistent module" assert "not found in pipeline" in result.get("message", "") - def test_update_weights_checksum_matches( - self, diffusion_server_for_weight_update: ServerContext - ): + def test_update_weights_checksum_matches(self, 
diffusion_server_for_weight_update): """Verify GPU checksum matches disk after weight update. 1. Fetch the pre-update (base model) checksum from the server. @@ -358,10 +389,11 @@ def test_update_weights_checksum_matches( 3. Fetch the post-update checksum and compare with disk. 4. Verify post-update checksum differs from pre-update (different model). """ - base_url = self._get_base_url(diffusion_server_for_weight_update) + ctx, default_model, update_model = diffusion_server_for_weight_update + base_url = self._get_base_url(ctx) # Update to base model. - result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, default_model) # Checksum before update (base model already loaded by the fixture). pre_update_checksum = self._get_weights_checksum( @@ -369,7 +401,7 @@ def test_update_weights_checksum_matches( )["transformer"] # Update to a different model. - result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, update_model) assert status_code == 200 and result.get( "success" ), f"Update failed: {result.get('message')}" @@ -378,9 +410,7 @@ def test_update_weights_checksum_matches( post_update_checksum = self._get_weights_checksum( base_url, module_names=["transformer"] )["transformer"] - update_disk_checksum = _compute_checksum_from_disk( - UPDATE_DIFFUSION_MODEL, "transformer" - ) + update_disk_checksum = _compute_checksum_from_disk(update_model, "transformer") print(f"\n{'='*60}") print(f"Checksum test") @@ -404,14 +434,19 @@ def test_update_weights_checksum_matches( class TestUpdateWeightsFromDiskWithOffload: """Test update_weights_from_disk with layerwise offload enabled.""" - @pytest.fixture(scope="class") - def diffusion_server_with_offload(self): + @pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], + ) + def diffusion_server_with_offload(self, 
request): """Start a diffusion server with layerwise offload enabled.""" + default_model, update_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) manager = ServerManager( - model=DEFAULT_DIFFUSION_MODEL, + model=default_model, port=port, wait_deadline=wait_deadline, extra_args="--num-gpus 1 --dit-layerwise-offload true", @@ -420,7 +455,7 @@ def diffusion_server_with_offload(self): ctx = manager.start() try: - yield ctx + yield ctx, default_model, update_model finally: ctx.cleanup() @@ -438,15 +473,14 @@ def _update_weights( ) return response.json(), response.status_code - def test_update_weights_with_offload_enabled( - self, diffusion_server_with_offload: ServerContext - ): + def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): """Test that weight update works correctly when layerwise offload is enabled.""" - base_url = self._get_base_url(diffusion_server_with_offload) + ctx, _default_model, update_model = diffusion_server_with_offload + base_url = self._get_base_url(ctx) logger.info("Testing weight update with offload enabled") - result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, update_model) logger.info(f"Update result: {result}") assert status_code == 200, f"Expected 200, got {status_code}" @@ -477,9 +511,7 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def test_update_weights_checksum_matches( - self, diffusion_server_with_offload: ServerContext - ): + def test_update_weights_checksum_matches(self, diffusion_server_with_offload): """Verify checksum from offloaded CPU buffers matches disk after update. 1. Fetch the pre-update (base model) checksum from the server. @@ -487,10 +519,11 @@ def test_update_weights_checksum_matches( 3. Fetch the post-update checksum and compare with disk. 4. 
Verify post-update checksum differs from pre-update (different model). """ - base_url = self._get_base_url(diffusion_server_with_offload) + ctx, default_model, update_model = diffusion_server_with_offload + base_url = self._get_base_url(ctx) # Update to base model. - result, status_code = self._update_weights(base_url, DEFAULT_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, default_model) # Checksum before update (base model already loaded by the fixture). pre_update_checksum = self._get_weights_checksum( @@ -498,7 +531,7 @@ def test_update_weights_checksum_matches( )["transformer"] # Update to a different model. - result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, update_model) assert status_code == 200 and result.get( "success" ), f"Update failed: {result.get('message')}" @@ -507,9 +540,7 @@ def test_update_weights_checksum_matches( post_update_checksum = self._get_weights_checksum( base_url, module_names=["transformer"] )["transformer"] - update_disk_checksum = _compute_checksum_from_disk( - UPDATE_DIFFUSION_MODEL, "transformer" - ) + update_disk_checksum = _compute_checksum_from_disk(update_model, "transformer") print(f"\n{'='*60}") print(f"Offload checksum test") @@ -535,84 +566,83 @@ def _prepare_corrupted_model( ) -> None: """Build a corrupted model directory from src_model. - The root-level files (model_index.json, config.json, …) are copied so - that maybe_download_model recognises dst_model as a valid local - model. Every module sub-directory that contains safetensors is copied - verbatim, except corrupt_module whose safetensors are truncated so - that safetensors_weights_iterator detects corruption at load time - (after earlier modules have already been updated), triggering a rollback. + Uses symlinks for everything except the corrupt_module directory to + save disk space and time. 
Only the corrupt_module's safetensors are + physically copied and then truncated so that safetensors_weights_iterator + detects corruption at load time, triggering a rollback. Must be called before every test attempt because the server deletes corrupted files on detection. """ - # Copy root-level files (model_index.json, etc.) so the server - # recognises the directory as a valid local model. + # Symlink root-level files (model_index.json, etc.). for fname in os.listdir(src_model): src_path = os.path.join(src_model, fname) - if os.path.isfile(src_path): - shutil.copy2(src_path, os.path.join(dst_model, fname)) + dst_path = os.path.join(dst_model, fname) + if os.path.isfile(src_path) and not os.path.exists(dst_path): + os.symlink(src_path, dst_path) - # Copy module sub-directories; corrupt the designated one. for module_dir in sorted(os.listdir(src_model)): src_dir = os.path.join(src_model, module_dir) + dst_dir = os.path.join(dst_model, module_dir) if not os.path.isdir(src_dir): continue - safetensors = [f for f in os.listdir(src_dir) if f.endswith(".safetensors")] - if not safetensors: - # Still copy non-safetensors dirs (config.json, tokenizer, etc.) - dst_dir = os.path.join(dst_model, module_dir) + + # Non-corrupted modules: symlink the entire directory. + if module_dir != corrupt_module: if not os.path.exists(dst_dir): - shutil.copytree(src_dir, dst_dir) + os.symlink(src_dir, dst_dir) continue - dst_dir = os.path.join(dst_model, module_dir) + # Corrupted module: create a real directory, symlink non-safetensors + # files, and copy + truncate safetensors files. 
os.makedirs(dst_dir, exist_ok=True) - - # Copy config.json and other non-safetensors files in the module dir - for f in os.listdir(src_dir): - if not f.endswith(".safetensors"): - src_f = os.path.join(src_dir, f) - if os.path.isfile(src_f): - shutil.copy2(src_f, os.path.join(dst_dir, f)) - - for fname in safetensors: + for fname in os.listdir(src_dir): src_file = os.path.join(src_dir, fname) dst_file = os.path.join(dst_dir, fname) - shutil.copy2(src_file, dst_file) + if not os.path.isfile(src_file): + continue + + if not fname.endswith(".safetensors"): + if not os.path.exists(dst_file): + os.symlink(src_file, dst_file) + continue - if module_dir == corrupt_module: - # Truncate 1000 bytes from the end to corrupt the tensor data - size = os.path.getsize(dst_file) - with open(dst_file, "r+b") as f: - f.truncate(size - 1000) - logger.info( - "Created corrupted safetensors: %s (%d -> %d bytes)", - dst_file, - size, - size - 1000, - ) - else: - logger.info("Copied valid safetensors: %s", dst_file) + # Copy safetensors then truncate to corrupt it. 
+ shutil.copy2(src_file, dst_file) + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 1000) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 1000, + ) class TestUpdateWeightsCorruptedRollback: """Test that loading corrupted weights triggers rollback to the last good model.""" @pytest.fixture(scope="class") - def corrupted_model_dir(self): - """Create a temporary directory for the corrupted model.""" + def corrupted_model_dir(self, diffusion_server_for_rollback): + """Create a separate temporary directory per parametrized model pair.""" tmpdir = tempfile.mkdtemp(prefix="sglang_corrupted_model_") yield tmpdir shutil.rmtree(tmpdir, ignore_errors=True) - @pytest.fixture(scope="class") - def diffusion_server_for_rollback(self): + @pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], + ) + def diffusion_server_for_rollback(self, request): """Start a diffusion server with the base model.""" + default_model, update_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) manager = ServerManager( - model=DEFAULT_DIFFUSION_MODEL, + model=default_model, port=port, wait_deadline=wait_deadline, extra_args="--num-gpus 1", @@ -620,7 +650,7 @@ def diffusion_server_for_rollback(self): ctx = manager.start() try: - yield ctx + yield ctx, default_model, update_model finally: ctx.cleanup() @@ -671,7 +701,7 @@ def _get_weights_checksum( def test_corrupted_weights_rollback( self, - diffusion_server_for_rollback: ServerContext, + diffusion_server_for_rollback, corrupted_model_dir: str, ): """Load base → update weights → attempt corrupted → verify rollback. @@ -679,14 +709,15 @@ def test_corrupted_weights_rollback( Checksums are verified for ALL modules, not just the transformer, to ensure the entire pipeline is consistent after rollback. 
""" - base_url = self._get_base_url(diffusion_server_for_rollback) + ctx, default_model, update_model = diffusion_server_for_rollback + base_url = self._get_base_url(ctx) # --- Step 1: Get base-model checksums for all modules --- base_checksums = self._get_weights_checksum(base_url) logger.info(f"Base model checksums: {base_checksums}") # --- Step 2: Update to the update model --- - result, status_code = self._update_weights(base_url, UPDATE_DIFFUSION_MODEL) + result, status_code = self._update_weights(base_url, update_model) assert status_code == 200 assert result.get( "success", False @@ -705,7 +736,7 @@ def test_corrupted_weights_rollback( # vae. With target_modules=["transformer", "vae"], the transformer # updates successfully first, then vae fails, giving a meaningful # rollback that actually restores the transformer. - local_base = maybe_download_model(DEFAULT_DIFFUSION_MODEL) + local_base = maybe_download_model(default_model) _prepare_corrupted_model(local_base, corrupted_model_dir, corrupt_module="vae") result, status_code = self._update_weights( From 81d585b43e530c0bb589b614d728419678ff1d89 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Fri, 13 Feb 2026 23:39:28 +0000 Subject: [PATCH 12/30] [diffusion] Optimize weight-update tests --- .../runtime/loader/weights_updater.py | 8 +- .../server/test_update_weights_from_disk.py | 625 ++++++++---------- 2 files changed, 281 insertions(+), 352 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 4834523e3d4e..f4422b74bc2f 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -209,13 +209,15 @@ def _apply_weights( _load_weights_into_module(module, weights_iter) updated_modules.append(module_name) except Exception as e: + rollback_list = updated_modules + [module_name] logger.error( f"Weight update failed for module 
'{module_name}': {e}. " - f"Rolling back {len(updated_modules)} already updated module(s): " - f"{updated_modules}.", + f"Rolling back {len(rollback_list)} module(s) " + f"(including partially-loaded '{module_name}'): " + f"{rollback_list}.", exc_info=True, ) - self._rollback(updated_modules) + self._rollback(rollback_list) return False, ( f"Failed to update module '{module_name}': {e}. " f"All modules rolled back to original weights." diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index f4e63b3790ac..f91b11792f24 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -10,25 +10,24 @@ Chenyang Zhao, https://github.com/zhaochenyang20 ============================================================================= -Test organization: 10 test cases in 3 classes +Test organization: 10 test cases in 2 classes ============================================================================= -Each test class uses a single long-lived server (pytest fixture with scope="class"). -The server is started once when the first test in that class runs; all tests in the -class share the same server and send multiple POST /update_weights_from_disk -requests to it. This reflects real usage: one running diffusion service, many weight -updates over time. +Class 1 uses a class-scoped server fixture (diffusion_server_no_offload) that +is torn down when the class ends, freeing both the port and GPU memory before +Class 2 starts its own offload-enabled server on the same port. 
-Class 1: TestUpdateWeightsFromDisk (7 tests) — API contract & checksum +Class 1: TestUpdateWeightsFromDisk (8 tests) — API contract, checksum & rollback Class 2: TestUpdateWeightsFromDiskWithOffload (2 tests) — Offload-aware update -Class 3: TestUpdateWeightsCorruptedRollback (1 test) — Corrupted weight rollback + +Tests are ordered lighter-first so developers get fast feedback during iteration. ----------------------------------------------------------------------------- Class 1: TestUpdateWeightsFromDisk ----------------------------------------------------------------------------- Purpose: Validate the update_weights_from_disk API contract, request/response shape, -error handling, and checksum verification. All 7 tests run against one server -(fixture: diffusion_server_for_weight_update). +error handling, checksum verification, and corrupted-weight rollback. All 8 tests +run against one server (fixture: diffusion_server_no_offload). • test_update_weights_with_flush_cache Explicit flush_cache=True; must succeed. Ensures the flush_cache parameter @@ -55,9 +54,16 @@ class share the same server and send multiple POST /update_weights_from_disk "not found in pipeline". Validates rejection of invalid module names. • test_update_weights_checksum_matches - Fetches checksum before update (base model), then updates weights and fetches - checksum again (update model). Verifies the post-update checksum matches the - update model's disk checksum, and differs from the pre-update checksum. + Updates weights to the update model. Verifies the post-update checksum + matches the update model's disk checksum. + + • test_corrupted_weights_rollback + Verify all-or-nothing rollback semantics when loading corrupted weights. + Builds a corrupted model directory by copying the base model and truncating + the vae safetensors. Requests an update with target_modules=["transformer", + "vae"]. 
The transformer updates successfully first; the corrupted vae then + fails during safetensors validation, triggering a rollback that restores + the transformer to its previous weights. ----------------------------------------------------------------------------- Class 2: TestUpdateWeightsFromDiskWithOffload @@ -78,25 +84,6 @@ class share the same server and send multiple POST /update_weights_from_disk checksum again (update model). Verifies the post-update checksum matches the update model's disk checksum, and differs from the pre-update checksum. ------------------------------------------------------------------------------ -Class 3: TestUpdateWeightsCorruptedRollback ------------------------------------------------------------------------------ -Purpose: Verify all-or-nothing rollback semantics when loading corrupted weights. -The test builds a corrupted model directory by copying the base model and -truncating the vae safetensors. It then requests an update with -target_modules=["transformer", "vae"]. The transformer updates successfully -first; the corrupted vae then fails during safetensors validation, triggering a -rollback that restores the transformer to its previous weights. - - • test_corrupted_weights_rollback - 1. Server starts with the base model (DEFAULT_DIFFUSION_MODEL). - 2. Updates weights to the update model (UPDATE_DIFFUSION_MODEL); must succeed. - 3. Records checksums for all modules after the update. - 4. Prepares a corrupted model (base model copy with truncated vae). - 5. Attempts to load the corrupted model; must fail with rollback message. - 6. Verifies all module checksums still match the update model, confirming - the rollback restored every module that was partially updated. 
- ============================================================================= Relation to RL scenarios and reference implementation ============================================================================= @@ -127,10 +114,12 @@ class share the same server and send multiple POST /update_weights_from_disk from __future__ import annotations +import functools import os import random import shutil import tempfile +import threading import pytest import requests @@ -191,11 +180,15 @@ def _select_model_pairs() -> list[tuple[str, str]]: _ACTIVE_MODEL_PAIRS = _select_model_pairs() +@functools.lru_cache(maxsize=None) def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: """Compute SHA-256 checksum from safetensors files on disk. Uses the same compute_weights_checksum function as the server, so the checksums are directly comparable. + + Results are cached (keyed on model_path and module_name) because the + same disk checksum is requested multiple times across tests. """ local_path = maybe_download_model(model_path) weights_dir = find_weights_dir(local_path, module_name) @@ -207,34 +200,120 @@ def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) -@pytest.fixture( - scope="class", - params=_ACTIVE_MODEL_PAIRS, - ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], -) -def diffusion_server_for_weight_update(request): - """Start a diffusion server for weight update tests.""" - default_model, update_model = request.param - port = get_dynamic_server_port() - wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) - - manager = ServerManager( - model=default_model, - port=port, - wait_deadline=wait_deadline, - extra_args="--num-gpus 1", - ) +def _prepare_corrupted_model( + src_model: str, dst_model: str, corrupt_module: str +) -> None: + """Build a corrupted model directory from src_model. 
+ + Uses symlinks for everything except the corrupt_module directory to + save disk space and time. Only the corrupt_module's safetensors are + physically copied and then truncated so that safetensors_weights_iterator + detects corruption at load time, triggering a rollback. + + Must be called before every test attempt because the server deletes + corrupted files on detection. + """ + # Symlink root-level files (model_index.json, etc.). + for fname in os.listdir(src_model): + src_path = os.path.join(src_model, fname) + dst_path = os.path.join(dst_model, fname) + if os.path.isfile(src_path) and not os.path.exists(dst_path): + os.symlink(src_path, dst_path) + + for module_dir in sorted(os.listdir(src_model)): + src_dir = os.path.join(src_model, module_dir) + dst_dir = os.path.join(dst_model, module_dir) + if not os.path.isdir(src_dir): + continue + + # Non-corrupted modules: symlink the entire directory. + if module_dir != corrupt_module: + if not os.path.exists(dst_dir): + os.symlink(src_dir, dst_dir) + continue + + # Corrupted module: create a real directory, symlink non-safetensors + # files, and copy + truncate safetensors files. + os.makedirs(dst_dir, exist_ok=True) + for fname in os.listdir(src_dir): + src_file = os.path.join(src_dir, fname) + dst_file = os.path.join(dst_dir, fname) + if not os.path.isfile(src_file): + continue - ctx = manager.start() + if not fname.endswith(".safetensors"): + if not os.path.exists(dst_file): + os.symlink(src_file, dst_file) + continue - try: - yield ctx, default_model, update_model - finally: - ctx.cleanup() + # Copy safetensors then truncate to corrupt it. 
+ shutil.copy2(src_file, dst_file) + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 1000) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 1000, + ) class TestUpdateWeightsFromDisk: - """Test suite for update_weights_from_disk API.""" + """Test suite for update_weights_from_disk API and corrupted-weight rollback. + + Uses a class-scoped server fixture so the server is torn down at class end, + freeing the port and GPU memory before the offload class starts. + """ + + @pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], + ) + def diffusion_server_no_offload(self, request): + """Start a diffusion server (no offload) for this test class. + + Precomputes disk checksums for the update model in background threads + while the server is starting, so they are already cached (via lru_cache) + by the time tests need them. + """ + default_model, update_model = request.param + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=default_model, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + # Warm the lru_cache while the server boots (disk I/O is independent). 
+ checksum_threads = [ + threading.Thread( + target=_compute_checksum_from_disk, args=(update_model, module) + ) + for module in ("transformer", "vae") + ] + for t in checksum_threads: + t.start() + + ctx = manager.start() + for t in checksum_threads: + t.join() + + try: + yield ctx, default_model, update_model + finally: + ctx.cleanup() + + @pytest.fixture(scope="class") + def corrupted_model_dir(self, diffusion_server_no_offload): + """Create a separate temporary directory per parametrized model pair.""" + tmpdir = tempfile.mkdtemp(prefix="sglang_corrupted_model_") + yield tmpdir + shutil.rmtree(tmpdir, ignore_errors=True) def _get_base_url(self, ctx: ServerContext) -> str: return f"http://localhost:{ctx.port}" @@ -246,7 +325,7 @@ def _update_weights( flush_cache: bool = True, target_modules: list[str] | None = None, timeout: int = 300, - ) -> dict: + ) -> tuple[dict, int]: """Call update_weights_from_disk API.""" payload = { "model_path": model_path, @@ -283,9 +362,9 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def test_update_weights_with_flush_cache(self, diffusion_server_for_weight_update): + def test_update_weights_with_flush_cache(self, diffusion_server_no_offload): """Test updating weights with flush_cache=True.""" - ctx, _default_model, update_model = diffusion_server_for_weight_update + ctx, _default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) result, status_code = self._update_weights( @@ -297,11 +376,9 @@ def test_update_weights_with_flush_cache(self, diffusion_server_for_weight_updat assert status_code == 200 assert result.get("success", False), f"Update failed: {result.get('message')}" - def test_update_weights_without_flush_cache( - self, diffusion_server_for_weight_update - ): + def test_update_weights_without_flush_cache(self, diffusion_server_no_offload): """Test updating weights with flush_cache=False.""" - ctx, 
_default_model, update_model = diffusion_server_for_weight_update + ctx, _default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) result, status_code = self._update_weights( @@ -313,9 +390,9 @@ def test_update_weights_without_flush_cache( assert status_code == 200 assert result.get("success", False), f"Update failed: {result.get('message')}" - def test_update_weights_nonexistent_model(self, diffusion_server_for_weight_update): + def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): """Test that updating with non-existent model fails gracefully.""" - ctx, _default_model, _update_model = diffusion_server_for_weight_update + ctx, _default_model, _update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) result, status_code = self._update_weights( @@ -328,11 +405,9 @@ def test_update_weights_nonexistent_model(self, diffusion_server_for_weight_upda # Should fail gracefully assert not result.get("success", True), "Should fail for nonexistent model" - def test_update_weights_missing_model_path( - self, diffusion_server_for_weight_update - ): + def test_update_weights_missing_model_path(self, diffusion_server_no_offload): """Test that request without model_path returns 400.""" - ctx, _default_model, _update_model = diffusion_server_for_weight_update + ctx, _default_model, _update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) response = requests.post( @@ -344,9 +419,9 @@ def test_update_weights_missing_model_path( # Should return 400 Bad Request assert response.status_code == 400, f"Expected 400, got {response.status_code}" - def test_update_weights_specific_modules(self, diffusion_server_for_weight_update): + def test_update_weights_specific_modules(self, diffusion_server_no_offload): """Test updating only specific modules (e.g., transformer only).""" - ctx, _default_model, update_model = diffusion_server_for_weight_update + ctx, _default_model, update_model = 
diffusion_server_no_offload base_url = self._get_base_url(ctx) # Try to update only transformer module @@ -362,11 +437,9 @@ def test_update_weights_specific_modules(self, diffusion_server_for_weight_updat # The test verifies the API handles target_modules parameter assert status_code == 200 - def test_update_weights_nonexistent_module( - self, diffusion_server_for_weight_update - ): + def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): """Test that requesting a non-existent module name fails with a clear error.""" - ctx, _default_model, update_model = diffusion_server_for_weight_update + ctx, _default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) result, status_code = self._update_weights( @@ -381,54 +454,138 @@ def test_update_weights_nonexistent_module( assert not result.get("success", True), "Should fail for nonexistent module" assert "not found in pipeline" in result.get("message", "") - def test_update_weights_checksum_matches(self, diffusion_server_for_weight_update): + def test_update_weights_checksum_matches(self, diffusion_server_no_offload): """Verify GPU checksum matches disk after weight update. - 1. Fetch the pre-update (base model) checksum from the server. - 2. Update weights to a different model. - 3. Fetch the post-update checksum and compare with disk. - 4. Verify post-update checksum differs from pre-update (different model). + Resets to the base model first (shared fixture may be in any state), + then updates to the update model and compares the server-side + checksum with the disk checksum. """ - ctx, default_model, update_model = diffusion_server_for_weight_update + ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) - # Update to base model. - result, status_code = self._update_weights(base_url, default_model) - - # Checksum before update (base model already loaded by the fixture). 
- pre_update_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] + # Reset to base model so the subsequent update is a real change. + self._update_weights(base_url, default_model) - # Update to a different model. result, status_code = self._update_weights(base_url, update_model) assert status_code == 200 and result.get( "success" ), f"Update failed: {result.get('message')}" - # Checksum after update — must match the update model on disk. - post_update_checksum = self._get_weights_checksum( + gpu_checksum = self._get_weights_checksum( base_url, module_names=["transformer"] )["transformer"] - update_disk_checksum = _compute_checksum_from_disk(update_model, "transformer") + disk_checksum = _compute_checksum_from_disk(update_model, "transformer") print(f"\n{'='*60}") print(f"Checksum test") - print(f" pre-update (base): {pre_update_checksum}") - print(f" post-update (gpu): {post_update_checksum}") - print(f" post-update (disk): {update_disk_checksum}") - print(f" gpu == disk: {post_update_checksum == update_disk_checksum}") - print(f" changed: {pre_update_checksum != post_update_checksum}") + print(f" gpu: {gpu_checksum}") + print(f" disk: {disk_checksum}") + print(f" match: {gpu_checksum == disk_checksum}") print(f"{'='*60}") - assert post_update_checksum == update_disk_checksum, ( + assert gpu_checksum == disk_checksum, ( f"GPU checksum does not match disk checksum for update model\n" - f" disk: {update_disk_checksum}\n" - f" gpu: {post_update_checksum}" + f" disk: {disk_checksum}\n" + f" gpu: {gpu_checksum}" + ) + + def test_corrupted_weights_rollback( + self, + diffusion_server_no_offload, + corrupted_model_dir: str, + ): + """Load base -> update weights -> attempt corrupted -> verify rollback. + + Checksums are restricted to ["transformer", "vae"] — the modules + involved in the partial update — to avoid computing checksums for + unrelated modules. 
+ """ + ctx, default_model, update_model = diffusion_server_no_offload + base_url = self._get_base_url(ctx) + rollback_modules = ["transformer", "vae"] + + # --- Step 0: Reset to default model --- + # Previous tests may have left the server on a different model. + result, status_code = self._update_weights(base_url, default_model) + assert status_code == 200 and result.get( + "success" + ), f"Failed to reset to default model: {result.get('message')}" + + # --- Step 1: Get base-model checksums for rollback modules --- + base_checksums = self._get_weights_checksum( + base_url, module_names=rollback_modules ) + logger.info(f"Base model checksums: {base_checksums}") + + # --- Step 2: Update to the update model --- + result, status_code = self._update_weights(base_url, update_model) + assert status_code == 200 + assert result.get( + "success", False + ), f"Weight update failed: {result.get('message')}" + + # --- Step 3: Record update-model checksums for rollback modules --- + update_checksums = self._get_weights_checksum( + base_url, module_names=rollback_modules + ) + logger.info(f"Update model checksums: {update_checksums}") + assert ( - pre_update_checksum != post_update_checksum - ), "Checksum did not change after updating to a different model" + update_checksums != base_checksums + ), "Base and update checksums should differ" + + # --- Step 4: Recreate corrupted model, then attempt load --- + # Copy all modules from the base model (valid), but corrupt only the + # vae. With target_modules=["transformer", "vae"], the transformer + # updates successfully first, then vae fails, giving a meaningful + # rollback that actually restores the transformer. 
+ local_base = maybe_download_model(default_model) + _prepare_corrupted_model(local_base, corrupted_model_dir, corrupt_module="vae") + + result, status_code = self._update_weights( + base_url, + corrupted_model_dir, + target_modules=rollback_modules, + timeout=120, + ) + logger.info(f"Corrupted update result: status={status_code}, body={result}") + + assert not result.get("success", True), "Loading corrupted weights should fail" + assert ( + "rolled back" in result.get("message", "").lower() + ), f"Expected rollback message, got: {result.get('message')}" + + # --- Step 5: Verify rollback — rollback module checksums must match update model --- + post_rollback_checksums = self._get_weights_checksum( + base_url, module_names=rollback_modules + ) + logger.info(f"Post-rollback checksums: {post_rollback_checksums}") + + print(f"\n{'='*80}") + print("Corrupted-weight rollback test (transformer, vae)") + for module in sorted(update_checksums.keys()): + update_cs = update_checksums.get(module, "N/A") + rollback_cs = post_rollback_checksums.get(module, "N/A") + base_cs = base_checksums.get(module, "N/A") + match = "OK" if update_cs == rollback_cs else "MISMATCH" + print(f" [{match}] {module}") + print(f" base: {base_cs}") + print(f" update: {update_cs}") + print(f" rollback: {rollback_cs}") + print(f"{'='*80}") + + for module in update_checksums: + assert post_rollback_checksums.get(module) == update_checksums[module], ( + f"Module '{module}' checksum mismatch after rollback\n" + f" update: {update_checksums[module]}\n" + f" post-rollback: {post_rollback_checksums.get(module)}" + ) + + assert post_rollback_checksums != base_checksums, ( + "Post-rollback checksums should not match base model " + "(rollback target is update model, not base)" + ) class TestUpdateWeightsFromDiskWithOffload: @@ -440,7 +597,11 @@ class TestUpdateWeightsFromDiskWithOffload: ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], ) def diffusion_server_with_offload(self, request): - """Start a 
diffusion server with layerwise offload enabled.""" + """Start a diffusion server with layerwise offload enabled. + + Disk checksums are already cached by diffusion_server_no_offload + (which runs first), so no background precomputation is needed here. + """ default_model, update_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) @@ -514,271 +675,37 @@ def _get_weights_checksum( def test_update_weights_checksum_matches(self, diffusion_server_with_offload): """Verify checksum from offloaded CPU buffers matches disk after update. - 1. Fetch the pre-update (base model) checksum from the server. - 2. Update weights to a different model. - 3. Fetch the post-update checksum and compare with disk. - 4. Verify post-update checksum differs from pre-update (different model). + Resets to the base model first, then updates to the update model + and compares the server-side checksum (read from CPU buffers) with + the disk checksum. """ ctx, default_model, update_model = diffusion_server_with_offload base_url = self._get_base_url(ctx) - # Update to base model. - result, status_code = self._update_weights(base_url, default_model) - - # Checksum before update (base model already loaded by the fixture). - pre_update_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] + # Reset to base model so the subsequent update is a real change. + self._update_weights(base_url, default_model) - # Update to a different model. result, status_code = self._update_weights(base_url, update_model) assert status_code == 200 and result.get( "success" ), f"Update failed: {result.get('message')}" - # Checksum after update — must match the update model on disk. 
- post_update_checksum = self._get_weights_checksum( + gpu_checksum = self._get_weights_checksum( base_url, module_names=["transformer"] )["transformer"] - update_disk_checksum = _compute_checksum_from_disk(update_model, "transformer") + disk_checksum = _compute_checksum_from_disk(update_model, "transformer") print(f"\n{'='*60}") print(f"Offload checksum test") - print(f" pre-update (base): {pre_update_checksum}") - print(f" post-update (gpu): {post_update_checksum}") - print(f" post-update (disk): {update_disk_checksum}") - print(f" gpu == disk: {post_update_checksum == update_disk_checksum}") - print(f" changed: {pre_update_checksum != post_update_checksum}") + print(f" gpu: {gpu_checksum}") + print(f" disk: {disk_checksum}") + print(f" match: {gpu_checksum == disk_checksum}") print(f"{'='*60}") - assert post_update_checksum == update_disk_checksum, ( + assert gpu_checksum == disk_checksum, ( f"GPU checksum does not match disk checksum for update model\n" - f" disk: {update_disk_checksum}\n" - f" gpu: {post_update_checksum}" - ) - assert ( - pre_update_checksum != post_update_checksum - ), "Checksum did not change after updating to a different model" - - -def _prepare_corrupted_model( - src_model: str, dst_model: str, corrupt_module: str -) -> None: - """Build a corrupted model directory from src_model. - - Uses symlinks for everything except the corrupt_module directory to - save disk space and time. Only the corrupt_module's safetensors are - physically copied and then truncated so that safetensors_weights_iterator - detects corruption at load time, triggering a rollback. - - Must be called before every test attempt because the server deletes - corrupted files on detection. - """ - # Symlink root-level files (model_index.json, etc.). 
- for fname in os.listdir(src_model): - src_path = os.path.join(src_model, fname) - dst_path = os.path.join(dst_model, fname) - if os.path.isfile(src_path) and not os.path.exists(dst_path): - os.symlink(src_path, dst_path) - - for module_dir in sorted(os.listdir(src_model)): - src_dir = os.path.join(src_model, module_dir) - dst_dir = os.path.join(dst_model, module_dir) - if not os.path.isdir(src_dir): - continue - - # Non-corrupted modules: symlink the entire directory. - if module_dir != corrupt_module: - if not os.path.exists(dst_dir): - os.symlink(src_dir, dst_dir) - continue - - # Corrupted module: create a real directory, symlink non-safetensors - # files, and copy + truncate safetensors files. - os.makedirs(dst_dir, exist_ok=True) - for fname in os.listdir(src_dir): - src_file = os.path.join(src_dir, fname) - dst_file = os.path.join(dst_dir, fname) - if not os.path.isfile(src_file): - continue - - if not fname.endswith(".safetensors"): - if not os.path.exists(dst_file): - os.symlink(src_file, dst_file) - continue - - # Copy safetensors then truncate to corrupt it. 
- shutil.copy2(src_file, dst_file) - size = os.path.getsize(dst_file) - with open(dst_file, "r+b") as f: - f.truncate(size - 1000) - logger.info( - "Created corrupted safetensors: %s (%d -> %d bytes)", - dst_file, - size, - size - 1000, - ) - - -class TestUpdateWeightsCorruptedRollback: - """Test that loading corrupted weights triggers rollback to the last good model.""" - - @pytest.fixture(scope="class") - def corrupted_model_dir(self, diffusion_server_for_rollback): - """Create a separate temporary directory per parametrized model pair.""" - tmpdir = tempfile.mkdtemp(prefix="sglang_corrupted_model_") - yield tmpdir - shutil.rmtree(tmpdir, ignore_errors=True) - - @pytest.fixture( - scope="class", - params=_ACTIVE_MODEL_PAIRS, - ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], - ) - def diffusion_server_for_rollback(self, request): - """Start a diffusion server with the base model.""" - default_model, update_model = request.param - port = get_dynamic_server_port() - wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) - - manager = ServerManager( - model=default_model, - port=port, - wait_deadline=wait_deadline, - extra_args="--num-gpus 1", - ) - - ctx = manager.start() - try: - yield ctx, default_model, update_model - finally: - ctx.cleanup() - - def _get_base_url(self, ctx: ServerContext) -> str: - return f"http://localhost:{ctx.port}" - - def _update_weights( - self, - base_url: str, - model_path: str, - flush_cache: bool = True, - target_modules: list[str] | None = None, - timeout: int = 300, - ) -> tuple[dict, int]: - payload = { - "model_path": model_path, - "flush_cache": flush_cache, - } - if target_modules is not None: - payload["target_modules"] = target_modules - - response = requests.post( - f"{base_url}/update_weights_from_disk", - json=payload, - timeout=timeout, - ) - return response.json(), response.status_code - - def _get_weights_checksum( - self, - base_url: str, - module_names: list[str] | None = None, - timeout: int = 
300, - ) -> dict: - payload = {} - if module_names is not None: - payload["module_names"] = module_names - - response = requests.post( - f"{base_url}/get_weights_checksum", - json=payload, - timeout=timeout, - ) - assert ( - response.status_code == 200 - ), f"get_weights_checksum failed: {response.status_code} {response.text}" - return response.json() - - def test_corrupted_weights_rollback( - self, - diffusion_server_for_rollback, - corrupted_model_dir: str, - ): - """Load base → update weights → attempt corrupted → verify rollback. - - Checksums are verified for ALL modules, not just the transformer, - to ensure the entire pipeline is consistent after rollback. - """ - ctx, default_model, update_model = diffusion_server_for_rollback - base_url = self._get_base_url(ctx) - - # --- Step 1: Get base-model checksums for all modules --- - base_checksums = self._get_weights_checksum(base_url) - logger.info(f"Base model checksums: {base_checksums}") - - # --- Step 2: Update to the update model --- - result, status_code = self._update_weights(base_url, update_model) - assert status_code == 200 - assert result.get( - "success", False - ), f"Weight update failed: {result.get('message')}" - - # --- Step 3: Record update-model checksums for all modules --- - update_checksums = self._get_weights_checksum(base_url) - logger.info(f"Update model checksums: {update_checksums}") - - assert ( - update_checksums != base_checksums - ), "Base and update checksums should differ" - - # --- Step 4: Recreate corrupted model, then attempt load --- - # Copy all modules from the base model (valid), but corrupt only the - # vae. With target_modules=["transformer", "vae"], the transformer - # updates successfully first, then vae fails, giving a meaningful - # rollback that actually restores the transformer. 
- local_base = maybe_download_model(default_model) - _prepare_corrupted_model(local_base, corrupted_model_dir, corrupt_module="vae") - - result, status_code = self._update_weights( - base_url, - corrupted_model_dir, - target_modules=["transformer", "vae"], - timeout=120, - ) - logger.info(f"Corrupted update result: status={status_code}, body={result}") - - assert not result.get("success", True), "Loading corrupted weights should fail" - assert ( - "rolled back" in result.get("message", "").lower() - ), f"Expected rollback message, got: {result.get('message')}" - - # --- Step 5: Verify rollback — all module checksums must match update model --- - post_rollback_checksums = self._get_weights_checksum(base_url) - logger.info(f"Post-rollback checksums: {post_rollback_checksums}") - - print(f"\n{'='*80}") - print("Corrupted-weight rollback test (all modules)") - for module in sorted(update_checksums.keys()): - update_cs = update_checksums.get(module, "N/A") - rollback_cs = post_rollback_checksums.get(module, "N/A") - base_cs = base_checksums.get(module, "N/A") - match = "OK" if update_cs == rollback_cs else "MISMATCH" - print(f" [{match}] {module}") - print(f" base: {base_cs}") - print(f" update: {update_cs}") - print(f" rollback: {rollback_cs}") - print(f"{'='*80}") - - for module in update_checksums: - assert post_rollback_checksums.get(module) == update_checksums[module], ( - f"Module '{module}' checksum mismatch after rollback\n" - f" update: {update_checksums[module]}\n" - f" post-rollback: {post_rollback_checksums.get(module)}" - ) - - assert post_rollback_checksums != base_checksums, ( - "Post-rollback checksums should not match base model " - "(rollback target is update model, not base)" + f" disk: {disk_checksum}\n" + f" gpu: {gpu_checksum}" ) From 32a743b5ae11663f49a4e53522576f809f027de6 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 13 Feb 2026 21:39:17 -0800 Subject: [PATCH 13/30] [TODO] model weights is updated only once --- 
.../server/test_update_weights_from_disk.py   | 142 ++++++++++++++----
 1 file changed, 112 insertions(+), 30 deletions(-)

diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py
index f91b11792f24..e9c3925a8253 100644
--- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py
+++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py
@@ -1,63 +1,83 @@
 """
-Tests for update_weights_from_disk API in SGLang-D (diffusion engine).
+Tests for update_weights_from_disk API in SGLang diffusion server.
 
-This module verifies the ability to hot update model weights without restarting
+This module verifies the ability to update model weights in place without restarting
 the server, which is critical for RL workflows and iterative fine-tuning
 scenarios.
 
+We use two model pairs for testing (before / after update model pairs):
+
+- FLUX.2-klein-base-4B / FLUX.2-klein-4B
+- Qwen/Qwen-Image / Qwen/Qwen-Image-2512
+
+These model pairs share the same model architecture and the same number
+of parameters; only the weight values differ.
+
 Author: Menyang Liu, https://github.com/dreamyang-liu
         Chenyang Zhao, https://github.com/zhaochenyang20
 
 =============================================================================
-Test organization: 10 test cases in 2 classes
-=============================================================================
-Class 1 uses a class-scoped server fixture (diffusion_server_no_offload) that
-is torn down when the class ends, freeing both the port and GPU memory before
-Class 2 starts its own offload-enabled server on the same port.
+Test organization:
+
+10 test cases in 2 classes;
+two model pairs are tested locally, one in CI.
+
+=============================================================================
 Class 1: TestUpdateWeightsFromDisk (8 tests) — API contract, checksum & rollback
 Class 2: TestUpdateWeightsFromDiskWithOffload (2 tests) — Offload-aware update
 
-Tests are ordered lighter-first so developers get fast feedback during iteration.
-
------------------------------------------------------------------------------
+
 Class 1: TestUpdateWeightsFromDisk
------------------------------------------------------------------------------
-Purpose: Validate the update_weights_from_disk API contract, request/response shape,
-error handling, checksum verification, and corrupted-weight rollback. All 8 tests
-run against one server (fixture: diffusion_server_no_offload).
+
+Validate the update_weights_from_disk API contract, request/response shape,
+error handling, checksum verification, and corrupted-weight rollback.
 
   • test_update_weights_with_flush_cache
-    Explicit flush_cache=True; must succeed. Ensures the flush_cache parameter
-    is accepted and applied.
+
+    Explicit flush_cache=True; must succeed (200, success=True). Ensures the
+    flush_cache parameter is accepted and the update completes.
+
+    TODO: Currently it cannot be verified whether TeaCache was flushed,
+    since no cache-state API is exposed.
 
   • test_update_weights_without_flush_cache
+
     Explicit flush_cache=False; must succeed. Ensures updates work when not
-    flushing TeaCache.
+    requesting TeaCache flush.
 
   • test_update_weights_nonexistent_model
-    model_path set to a non-existent path; must fail (success=False). Verifies
-    all-or-nothing / rollback semantics when load fails.
+
+    model_path set to a non-existent path; must fail (400, success=False).
+    Also verifies that the update fails and the model is rolled back to the
+    original weights.
 
   • test_update_weights_missing_model_path
+
     Request body empty (no model_path); must return 400. Validates required
     parameter checks.
• test_update_weights_specific_modules - target_modules=["transformer"]; must return 200. Verifies partial module - update (target_modules parameter). + + Randomly selects a subset of pipeline modules as target_modules, then asserts: + (1) updated modules' checksums match the update model's disk checksums; + (2) non-updated modules' checksums are unchanged (same before and after). • test_update_weights_nonexistent_module + target_modules=["nonexistent_module"]; must return 400 and message containing "not found in pipeline". Validates rejection of invalid module names. • test_update_weights_checksum_matches + Updates weights to the update model. Verifies the post-update checksum matches the update model's disk checksum. • test_corrupted_weights_rollback + Verify all-or-nothing rollback semantics when loading corrupted weights. Builds a corrupted model directory by copying the base model and truncating the vae safetensors. Requests an update with target_modules=["transformer", @@ -200,6 +220,19 @@ def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) +def _get_modules_with_weights_on_disk( + model_path: str, module_names: list[str] +) -> list[str]: + """Return module names that have safetensors on disk for the given model.""" + local_path = maybe_download_model(model_path) + result = [] + for name in module_names: + weights_dir = find_weights_dir(local_path, name) + if weights_dir and _list_safetensors_files(weights_dir): + result.append(name) + return result + + def _prepare_corrupted_model( src_model: str, dst_model: str, corrupt_module: str ) -> None: @@ -363,7 +396,11 @@ def _get_weights_checksum( return response.json() def test_update_weights_with_flush_cache(self, diffusion_server_no_offload): - """Test updating weights with flush_cache=True.""" + """Test updating weights with flush_cache=True. 
+ + Verifies the API accepts flush_cache=True and returns success; does not + assert that TeaCache was actually reset (server does not expose cache state). + """ ctx, _default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) @@ -420,22 +457,67 @@ def test_update_weights_missing_model_path(self, diffusion_server_no_offload): assert response.status_code == 400, f"Expected 400, got {response.status_code}" def test_update_weights_specific_modules(self, diffusion_server_no_offload): - """Test updating only specific modules (e.g., transformer only).""" - ctx, _default_model, update_model = diffusion_server_no_offload + """Partial update: random subset of modules updated; checksums verified. + + Randomly picks a non-empty subset of modules that have weights on disk + for the update model, performs update_weights_from_disk with that + target_modules, then asserts: + - Updated modules: in-memory checksum == update model disk checksum. + - Non-updated modules: checksum unchanged (before == after). + """ + ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) - # Try to update only transformer module + # Reset to base model so we start from a known state. + self._update_weights(base_url, default_model) + + # All pipeline module names (from server). + all_checksums = self._get_weights_checksum(base_url, module_names=None) + all_module_names = [k for k in all_checksums if all_checksums[k] != "not_found"] + if not all_module_names: + pytest.skip("No updatable modules reported by server") + + # Only consider modules that exist on disk for the update model. + candidates = _get_modules_with_weights_on_disk(update_model, all_module_names) + if not candidates: + pytest.skip("Update model has no weight dirs for any pipeline module") + + # Random non-empty subset (fixed seed for reproducibility). 
+ random.seed(42) + k = random.randint(1, len(candidates)) + target_modules = random.sample(candidates, k) + target_set = set(target_modules) + logger.info( + "Partial update test: target_modules=%s (unchanged: %s)", + target_modules, + [m for m in all_module_names if m not in target_set], + ) + + before_checksums = self._get_weights_checksum(base_url, module_names=None) + result, status_code = self._update_weights( base_url, update_model, - target_modules=["transformer"], + target_modules=target_modules, ) - logger.info(f"Update specific modules result: {result}") + assert status_code == 200, f"Update failed: {result}" + assert result.get("success", False), f"Update failed: {result.get('message')}" - # This might fail if the model doesn't have a transformer module - # or if weights for only transformer aren't available - # The test verifies the API handles target_modules parameter - assert status_code == 200 + after_checksums = self._get_weights_checksum(base_url, module_names=None) + + for name in all_module_names: + if name in target_set: + disk_cs = _compute_checksum_from_disk(update_model, name) + assert after_checksums.get(name) == disk_cs, ( + f"Updated module '{name}': checksum should match update model disk\n" + f" disk: {disk_cs}\n gpu: {after_checksums.get(name)}" + ) + else: + assert after_checksums.get(name) == before_checksums.get(name), ( + f"Non-updated module '{name}': checksum must be unchanged\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): """Test that requesting a non-existent module name fails with a clear error.""" From c35eecd25c28ede5866c1482f5c3c0ce9fe307f2 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 13 Feb 2026 22:34:10 -0800 Subject: [PATCH 14/30] Deduplicated tests; Should clean up --- .../server/test_update_weights_from_disk.py | 494 ++++++++++-------- 1 file changed, 269 insertions(+), 225 
deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index e9c3925a8253..abd7c8677b89 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -4,13 +4,14 @@ This module verifies the ability to update model weights in place without restarting the server, which is critical for RL workflows and iterative fine-tuning scenarios. -We use two model pairs for testing (before / after update model pairs): +We use two model pairs for testing (base model / instruct model pairs): - FLUX.2-klein-base-4B / FLUX.2-klein-4B - Qwen/Qwen-Image / Qwen/Qwen-Image-2512 -These models are with the same model architecture and different number -of parameters. Only weights are different. +These model pairs share the same architecture, but not every module is +guaranteed to have different weights between base and update models. +Some modules can be identical across the pair. Author: @@ -21,13 +22,13 @@ Test organization: -10 test cases in 2 classes; +7 test cases in 2 classes; two model pairs are tested locally, one in CI. ============================================================================= -Class 1: TestUpdateWeightsFromDisk (8 tests) — API contract, checksum & rollback -Class 2: TestUpdateWeightsFromDiskWithOffload (2 tests) — Offload-aware update +Class 1: TestUpdateWeightsFromDisk (6 tests) — API contract, checksum & rollback +Class 2: TestUpdateWeightsFromDiskWithOffload (1 test) — Offload-aware update + checksum ----------------------------------------------------------------------------- @@ -36,100 +37,75 @@ Validate the update_weights_from_disk API contract, request/response shape, error handling, checksum verification, and corrupted-weight rollback. 
- • test_update_weights_with_flush_cache +All tests share one class-scoped server (same process, same in-memory weights). +Tests that require "base model then update" should be explicitly reset to +default_model first so behavior is order-independent and updates are real + (base→update), not no-ops (update→update). - Explicit flush_cache=True; must succeed (200, success=True). Ensures the - flush_cache parameter is accepted and the update completes. + • test_update_weights_from_disk_default - TODO: Currently, TeaCache can not be verified whether it was flushed - since no cache-state API is exposed. + base -> instruct with flush_cache=True. Verifies: + (1) before-update checksum == base model disk checksum; + (2) after-update checksum == instruct model disk checksum; + (3) before != after (update actually changed weights). + rollback to base model after update. - • test_update_weights_without_flush_cache + • test_update_weights_specific_modules - Explicit flush_cache=False; must succeed. Ensures updates work when not - requesting TeaCache flush. + base -> instruct with flush_cache=False: randomly selects target_modules, + updates only those from base to instruct model. Verifies: + (1) updated modules' checksums match instruct model disk checksum; + (2) non-updated modules' checksums are unchanged (before == after == disk). + rollback to base model after update. • test_update_weights_nonexistent_model model_path set to a non-existent path; must fail (400, success=False). - Also, verifies that the update fails and the model is rolled back to the - original weights. - • test_update_weights_missing_model_path + Ensure server is healthy after inaccurate update and server's checksums + equals to base model's disk checksums. - Request body empty (no model_path); must return 400. Validates required - parameter checks. + • test_update_weights_missing_model_path - • test_update_weights_specific_modules + Request body empty (no model_path); must fail (400, success=False). 
- Randomly selects a subset of pipeline modules as target_modules, then asserts: - (1) updated modules' checksums match the update model's disk checksums; - (2) non-updated modules' checksums are unchanged (same before and after). + Ensure server is healthy after inaccurate update and server's checksums + equals to base model's disk checksums. • test_update_weights_nonexistent_module - target_modules=["nonexistent_module"]; must return 400 and message containing - "not found in pipeline". Validates rejection of invalid module names. - - • test_update_weights_checksum_matches + target_modules=["nonexistent_module"]; must fail (400, success=False). - Updates weights to the update model. Verifies the post-update checksum - matches the update model's disk checksum. + Verify server is healthy after inaccurate update and server's checksums + equals to base model's disk checksums. • test_corrupted_weights_rollback - Verify all-or-nothing rollback semantics when loading corrupted weights. - Builds a corrupted model directory by copying the base model and truncating - the vae safetensors. Requests an update with target_modules=["transformer", - "vae"]. The transformer updates successfully first; the corrupted vae then - fails during safetensors validation, triggering a rollback that restores + Verify base -> instruct rollback after loading corrupted instruct model. + Builds a corrupted model directory by copying the instruct model and + truncating the vae safetensors. Updates with target_modules=["transformer", + "vae"]. The transformer updates successfully first; the corrupted vae module + then fails during safetensors validation, triggering a rollback that restores the transformer to its previous weights. 
------------------------------------------------------------------------------ -Class 2: TestUpdateWeightsFromDiskWithOffload ------------------------------------------------------------------------------ -Purpose: Ensure weight updates and checksum verification work when layerwise -offload is enabled (--dit-layerwise-offload). With offload, parameters live in -CPU buffers and placeholders on GPU; the updater must write into CPU buffers and -update prefetched GPU tensors without shape mismatch. The checksum endpoint must -read from CPU buffers (not the (1,) placeholders) to produce correct results. - - • test_update_weights_with_offload_enabled - Server started with --dit-layerwise-offload true. Call update_weights_from_disk - with the same model; must succeed (200, success=True) and message must not - contain "Shape mismatch". - - • test_update_weights_checksum_matches - Fetches checksum before update (base model), then updates weights and fetches - checksum again (update model). Verifies the post-update checksum matches the - update model's disk checksum, and differs from the pre-update checksum. + Ensure server is healthy after rollback and server's checksums equals to + base model's disk checksums. -============================================================================= -Relation to RL scenarios and reference implementation -============================================================================= - -In RL or iterative training, a typical pattern is: +----------------------------------------------------------------------------- - 1. Run a diffusion (or LLM) server for inference. - 2. Periodically pull new weights (e.g., from a training run or from disk) - without restarting the server. - 3. Continue serving with the updated model. 
+Class 2: TestUpdateWeightsFromDiskWithOffload -The diffusion engine supports this via POST /update_weights_from_disk: it loads -weights from a model_path (HF repo or local) and applies them in-place, with -rollback on failure and support for layerwise offload and DTensor. -For a distributed RL setup where the training process broadcasts weights to -inference engines (rather than loading from disk), see the SGLang LLM test that -simulates rank 0 as trainer and other ranks as inference engines, using -update_weights_from_distributed and init_weights_update_group: +Ensure weight updates and checksum verification work when layerwise offload is enabled +(--dit-layerwise-offload). With offload, parameters live in CPU buffers and only left +small torch.empty((1,)) as placeholders on GPU; the updater must write into CPU buffers +and update prefetched GPU tensors without shape mismatch. - https://github.com/sgl-project/sglang/blob/main/test/registered/rl/test_update_weights_from_distributed.py + • test_update_weights_with_offload_enabled -That test verifies weight synchronization across ranks (instruct vs base model) -and optional pause_generation/continue_generation during update. This diffusion -test suite focuses on the disk-based update path and offload/consistency -behavior of the diffusion engine only. + Server with --dit-layerwise-offload (base). Update to instruct; must succeed + (200, success=True), message must not contain "Shape mismatch". Assert server + checksums == instruct model disk checksums (server healthy). """ from __future__ import annotations @@ -171,11 +147,11 @@ "black-forest-labs/FLUX.2-klein-4B", 5.0, ), - ( - "Qwen/Qwen-Image", - "Qwen/Qwen-Image-2512", - 1.0, # Qwen Image is large; run it less often in CI. - ), + # ( + # "Qwen/Qwen-Image", + # "Qwen/Qwen-Image-2512", + # 1.0, # Qwen Image is large; run it less often in CI. 
+ # ), ] @@ -233,6 +209,23 @@ def _get_modules_with_weights_on_disk( return result +def _get_modules_with_different_checksums( + base_model: str, update_model: str, module_names: list[str] +) -> list[str]: + """Return shared modules whose disk checksums differ across model pair.""" + base_modules = set(_get_modules_with_weights_on_disk(base_model, module_names)) + update_modules = set(_get_modules_with_weights_on_disk(update_model, module_names)) + shared_modules = sorted(base_modules & update_modules) + + changed_modules = [] + for name in shared_modules: + base_cs = _compute_checksum_from_disk(base_model, name) + update_cs = _compute_checksum_from_disk(update_model, name) + if base_cs != update_cs: + changed_modules.append(name) + return changed_modules + + def _prepare_corrupted_model( src_model: str, dst_model: str, corrupt_module: str ) -> None: @@ -395,75 +388,95 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def test_update_weights_with_flush_cache(self, diffusion_server_no_offload): - """Test updating weights with flush_cache=True. 
+ def _assert_server_checksums_match_base_disk( + self, base_url: str, default_model: str, update_model: str + ) -> None: + """Assert changed modules on server match base model disk (server healthy).""" + all_checksums = self._get_weights_checksum(base_url, module_names=None) + module_names = [k for k in all_checksums if all_checksums.get(k) != "not_found"] + changed_modules = _get_modules_with_different_checksums( + default_model, update_model, module_names + ) + if not changed_modules: + pytest.skip("No checksum-different shared modules in model pair") + + for name in changed_modules: + server_cs = all_checksums.get(name) + base_disk_cs = _compute_checksum_from_disk(default_model, name) + assert server_cs == base_disk_cs, ( + f"Server checksum for '{name}' should match base model disk (server healthy)\n" + f" base_disk: {base_disk_cs}\n server: {server_cs}" + ) + + def test_update_weights_from_disk_default(self, diffusion_server_no_offload): + """Base→instruct with flush_cache=True; verify before/after; rollback to base. - Verifies the API accepts flush_cache=True and returns success; does not - assert that TeaCache was actually reset (server does not expose cache state). + Resets to base, records before checksum. Updates to instruct with + flush_cache=True. Asserts: (1) before == base disk; (2) after == instruct + disk; (3) before != after. Then rollback to base so server ends on base. """ - ctx, _default_model, update_model = diffusion_server_no_offload + ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) - result, status_code = self._update_weights( - base_url, - update_model, - flush_cache=True, - ) - - assert status_code == 200 - assert result.get("success", False), f"Update failed: {result.get('message')}" + # Reset to base so we have a real base→instruct. 
+ self._update_weights(base_url, default_model) - def test_update_weights_without_flush_cache(self, diffusion_server_no_offload): - """Test updating weights with flush_cache=False.""" - ctx, _default_model, update_model = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + before_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + base_disk = _compute_checksum_from_disk(default_model, "transformer") result, status_code = self._update_weights( base_url, update_model, - flush_cache=False, + flush_cache=True, ) + assert status_code == 200 and result.get( + "success" + ), f"Update failed: {result.get('message')}" - assert status_code == 200 - assert result.get("success", False), f"Update failed: {result.get('message')}" + after_checksum = self._get_weights_checksum( + base_url, module_names=["transformer"] + )["transformer"] + instruct_disk = _compute_checksum_from_disk(update_model, "transformer") - def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): - """Test that updating with non-existent model fails gracefully.""" - ctx, _default_model, _update_model = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + print(f"\n{'='*60}") + print("Checksum test (base→instruct with flush_cache=True)") + print(f" before (gpu): {before_checksum}") + print(f" base (disk): {base_disk}") + print(f" after (gpu): {after_checksum}") + print(f" instruct (disk): {instruct_disk}") + print(f" before==base_disk: {before_checksum == base_disk}") + print(f" after==instruct_disk: {after_checksum == instruct_disk}") + print(f" before!=after: {before_checksum != after_checksum}") + print(f"{'='*60}") - result, status_code = self._update_weights( - base_url, - "/nonexistent/path/to/model", - timeout=60, + assert before_checksum == base_disk, ( + f"Before-update checksum should match base model disk\n" + f" base_disk: {base_disk}\n before: {before_checksum}" ) - logger.info(f"Update result for 
nonexistent model: {result}") - - # Should fail gracefully - assert not result.get("success", True), "Should fail for nonexistent model" - - def test_update_weights_missing_model_path(self, diffusion_server_no_offload): - """Test that request without model_path returns 400.""" - ctx, _default_model, _update_model = diffusion_server_no_offload - base_url = self._get_base_url(ctx) - - response = requests.post( - f"{base_url}/update_weights_from_disk", - json={}, - timeout=30, + assert after_checksum == instruct_disk, ( + f"After-update checksum should match instruct model disk\n" + f" instruct_disk: {instruct_disk}\n after: {after_checksum}" ) + assert ( + before_checksum != after_checksum + ), "Before and after checksums should differ (update changed weights)" - # Should return 400 Bad Request - assert response.status_code == 400, f"Expected 400, got {response.status_code}" + # Rollback to base so server ends in known state. + self._update_weights(base_url, default_model) def test_update_weights_specific_modules(self, diffusion_server_no_offload): - """Partial update: random subset of modules updated; checksums verified. - - Randomly picks a non-empty subset of modules that have weights on disk - for the update model, performs update_weights_from_disk with that - target_modules, then asserts: - - Updated modules: in-memory checksum == update model disk checksum. - - Non-updated modules: checksum unchanged (before == after). + """Partial update base→instruct with flush_cache=False; verify checksums; rollback to base. + + Randomly picks target_modules, updates only those to instruct with + flush_cache=False. Asserts: + (1) for modules whose base/update disk checksums differ, updated modules + match update-model disk and actually change; + (2) for modules with identical base/update checksums, updating them keeps + checksums unchanged; + (3) non-updated modules remain unchanged (before == after). + Then rollback to base. 
""" ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) @@ -477,20 +490,35 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): if not all_module_names: pytest.skip("No updatable modules reported by server") - # Only consider modules that exist on disk for the update model. - candidates = _get_modules_with_weights_on_disk(update_model, all_module_names) + # Only consider modules that have weights on disk in both models. + base_modules = set( + _get_modules_with_weights_on_disk(default_model, all_module_names) + ) + update_modules = set( + _get_modules_with_weights_on_disk(update_model, all_module_names) + ) + candidates = sorted(base_modules & update_modules) if not candidates: - pytest.skip("Update model has no weight dirs for any pipeline module") + pytest.skip("No shared modules with weights on disk in model pair") - # Random non-empty subset (fixed seed for reproducibility). + changed_modules = _get_modules_with_different_checksums( + default_model, update_model, candidates + ) + if not changed_modules: + pytest.skip("No checksum-different shared modules in model pair") + + # Random non-empty subset (fixed seed) that always includes one changed module. 
random.seed(42) - k = random.randint(1, len(candidates)) - target_modules = random.sample(candidates, k) + must_include = random.choice(changed_modules) + optional = [m for m in candidates if m != must_include] + k_extra = random.randint(0, len(optional)) + target_modules = [must_include] + random.sample(optional, k_extra) target_set = set(target_modules) + changed_set = set(changed_modules) logger.info( - "Partial update test: target_modules=%s (unchanged: %s)", + "Partial update test (flush_cache=False): target_modules=%s (checksum-different modules: %s)", target_modules, - [m for m in all_module_names if m not in target_set], + changed_modules, ) before_checksums = self._get_weights_checksum(base_url, module_names=None) @@ -499,6 +527,7 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): base_url, update_model, target_modules=target_modules, + flush_cache=False, ) assert status_code == 200, f"Update failed: {result}" assert result.get("success", False), f"Update failed: {result.get('message')}" @@ -507,11 +536,24 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): for name in all_module_names: if name in target_set: - disk_cs = _compute_checksum_from_disk(update_model, name) - assert after_checksums.get(name) == disk_cs, ( - f"Updated module '{name}': checksum should match update model disk\n" - f" disk: {disk_cs}\n gpu: {after_checksums.get(name)}" - ) + if name in changed_set: + disk_cs = _compute_checksum_from_disk(update_model, name) + assert after_checksums.get(name) == disk_cs, ( + f"Updated module '{name}': checksum should match update model disk\n" + f" disk: {disk_cs}\n gpu: {after_checksums.get(name)}" + ) + assert after_checksums.get(name) != before_checksums.get(name), ( + f"Updated module '{name}' should change checksum (base != update)\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) + else: + assert after_checksums.get(name) == 
before_checksums.get(name), ( + f"Updated module '{name}' has identical base/update disk checksum, " + "so it should remain unchanged\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) else: assert after_checksums.get(name) == before_checksums.get(name), ( f"Non-updated module '{name}': checksum must be unchanged\n" @@ -519,57 +561,67 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): f" after: {after_checksums.get(name)}" ) - def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): - """Test that requesting a non-existent module name fails with a clear error.""" - ctx, _default_model, update_model = diffusion_server_no_offload + # Rollback to base so server ends in known state. + self._update_weights(base_url, default_model) + + def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): + """Nonexistent model path must fail (400). Server healthy, checksums == base disk.""" + ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) + self._update_weights(base_url, default_model) + result, status_code = self._update_weights( base_url, - update_model, - target_modules=["nonexistent_module"], + "/nonexistent/path/to/model", timeout=60, ) - logger.info(f"Update nonexistent module result: {result}") + logger.info(f"Update result for nonexistent model: {result}") assert status_code == 400, f"Expected 400, got {status_code}" - assert not result.get("success", True), "Should fail for nonexistent module" - assert "not found in pipeline" in result.get("message", "") - - def test_update_weights_checksum_matches(self, diffusion_server_no_offload): - """Verify GPU checksum matches disk after weight update. 
+ assert not result.get("success", True), "Should fail for nonexistent model" + self._assert_server_checksums_match_base_disk( + base_url, default_model, update_model + ) - Resets to the base model first (shared fixture may be in any state), - then updates to the update model and compares the server-side - checksum with the disk checksum. - """ + def test_update_weights_missing_model_path(self, diffusion_server_no_offload): + """Request without model_path must fail (400). Server healthy, checksums == base disk.""" ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) - # Reset to base model so the subsequent update is a real change. self._update_weights(base_url, default_model) - result, status_code = self._update_weights(base_url, update_model) - assert status_code == 200 and result.get( - "success" - ), f"Update failed: {result.get('message')}" + response = requests.post( + f"{base_url}/update_weights_from_disk", + json={}, + timeout=30, + ) - gpu_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] - disk_checksum = _compute_checksum_from_disk(update_model, "transformer") + assert response.status_code == 400, f"Expected 400, got {response.status_code}" + self._assert_server_checksums_match_base_disk( + base_url, default_model, update_model + ) - print(f"\n{'='*60}") - print(f"Checksum test") - print(f" gpu: {gpu_checksum}") - print(f" disk: {disk_checksum}") - print(f" match: {gpu_checksum == disk_checksum}") - print(f"{'='*60}") + def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): + """Nonexistent module must fail (400). 
Server healthy, checksums == base disk.""" + ctx, default_model, update_model = diffusion_server_no_offload + base_url = self._get_base_url(ctx) + + self._update_weights(base_url, default_model) + + result, status_code = self._update_weights( + base_url, + update_model, + target_modules=["nonexistent_module"], + timeout=60, + ) + logger.info(f"Update nonexistent module result: {result}") - assert gpu_checksum == disk_checksum, ( - f"GPU checksum does not match disk checksum for update model\n" - f" disk: {disk_checksum}\n" - f" gpu: {gpu_checksum}" + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent module" + assert "not found in pipeline" in result.get("message", "") + self._assert_server_checksums_match_base_disk( + base_url, default_model, update_model ) def test_corrupted_weights_rollback( @@ -577,11 +629,11 @@ def test_corrupted_weights_rollback( diffusion_server_no_offload, corrupted_model_dir: str, ): - """Load base -> update weights -> attempt corrupted -> verify rollback. + """Base→instruct then load corrupted instruct; verify rollback. - Checksums are restricted to ["transformer", "vae"] — the modules - involved in the partial update — to avoid computing checksums for - unrelated modules. + Updates to instruct, then attempts load from corrupted instruct dir + (vae safetensors truncated). Rollback restores to instruct state. + Ensures server healthy: reset to base and assert checksums == base disk. """ ctx, default_model, update_model = diffusion_server_no_offload base_url = self._get_base_url(ctx) @@ -666,7 +718,16 @@ def test_corrupted_weights_rollback( assert post_rollback_checksums != base_checksums, ( "Post-rollback checksums should not match base model " - "(rollback target is update model, not base)" + "(rollback target is instruct model, not base)" + ) + + # Ensure server healthy: reset to base and verify checksums == base disk. 
+ result, status_code = self._update_weights(base_url, default_model) + assert status_code == 200 and result.get( + "success" + ), f"Failed to reset to base after rollback: {result.get('message')}" + self._assert_server_checksums_match_base_disk( + base_url, default_model, update_model ) @@ -716,23 +777,6 @@ def _update_weights( ) return response.json(), response.status_code - def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): - """Test that weight update works correctly when layerwise offload is enabled.""" - ctx, _default_model, update_model = diffusion_server_with_offload - base_url = self._get_base_url(ctx) - - logger.info("Testing weight update with offload enabled") - - result, status_code = self._update_weights(base_url, update_model) - logger.info(f"Update result: {result}") - - assert status_code == 200, f"Expected 200, got {status_code}" - assert result.get("success", False), f"Update failed: {result.get('message')}" - - # Verify no shape mismatch warnings in the message - message = result.get("message", "") - assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" - def _get_weights_checksum( self, base_url: str, @@ -754,40 +798,40 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def test_update_weights_checksum_matches(self, diffusion_server_with_offload): - """Verify checksum from offloaded CPU buffers matches disk after update. 
+ def _assert_server_checksums_match_instruct_disk( + self, base_url: str, default_model: str, update_model: str + ) -> None: + """Assert changed modules on server match update model disk (server healthy).""" + all_checksums = self._get_weights_checksum(base_url, module_names=None) + module_names = [k for k in all_checksums if all_checksums.get(k) != "not_found"] + changed_modules = _get_modules_with_different_checksums( + default_model, update_model, module_names + ) + if not changed_modules: + pytest.skip("No checksum-different shared modules in model pair") + + for name in changed_modules: + server_cs = all_checksums.get(name) + instruct_disk_cs = _compute_checksum_from_disk(update_model, name) + assert server_cs == instruct_disk_cs, ( + f"Server checksum for '{name}' should match instruct model disk\n" + f" instruct_disk: {instruct_disk_cs}\n server: {server_cs}" + ) - Resets to the base model first, then updates to the update model - and compares the server-side checksum (read from CPU buffers) with - the disk checksum. - """ + def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): + """Offload: base→instruct update; no Shape mismatch; checksums == instruct disk.""" ctx, default_model, update_model = diffusion_server_with_offload base_url = self._get_base_url(ctx) - # Reset to base model so the subsequent update is a real change. 
- self._update_weights(base_url, default_model) - result, status_code = self._update_weights(base_url, update_model) - assert status_code == 200 and result.get( - "success" - ), f"Update failed: {result.get('message')}" - - gpu_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] - disk_checksum = _compute_checksum_from_disk(update_model, "transformer") + assert status_code == 200, f"Expected 200, got {status_code}" + assert result.get("success", False), f"Update failed: {result.get('message')}" - print(f"\n{'='*60}") - print(f"Offload checksum test") - print(f" gpu: {gpu_checksum}") - print(f" disk: {disk_checksum}") - print(f" match: {gpu_checksum == disk_checksum}") - print(f"{'='*60}") + message = result.get("message", "") + assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" - assert gpu_checksum == disk_checksum, ( - f"GPU checksum does not match disk checksum for update model\n" - f" disk: {disk_checksum}\n" - f" gpu: {gpu_checksum}" + self._assert_server_checksums_match_instruct_disk( + base_url, default_model, update_model ) From 41148c40139a8ca322787922a37a3ab21b01a5da Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 13 Feb 2026 22:59:06 -0800 Subject: [PATCH 15/30] clean up codes with mixin; currently spanning 16mins; too long; should reduce it again --- .../server/test_update_weights_from_disk.py | 276 ++++++------------ 1 file changed, 92 insertions(+), 184 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index abd7c8677b89..a7eecf80100f 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -1,9 +1,13 @@ -""" -Tests for update_weights_from_disk API in SGLang diffusion server. +"""Tests for diffusion `update_weights_from_disk`. 
This module verifies the ability to update model weights in place without restarting the server, which is critical for RL workflows and iterative fine-tuning scenarios. +Author: + +Menyang Liu, https://github.com/dreamyang-liu +Chenyang Zhao, https://github.com/zhaochenyang20 + We use two model pairs for testing (base model / instruct model pairs): - FLUX.2-klein-base-4B / FLUX.2-klein-4B @@ -13,11 +17,6 @@ guaranteed to have different weights between base and update models. Some modules can be identical across the pair. -Author: - -Menyang Liu, https://github.com/dreamyang-liu -Chenyang Zhao, https://github.com/zhaochenyang20 - ============================================================================= Test organization: @@ -138,20 +137,17 @@ logger = init_logger(__name__) -# Model pairs for weight update tests: (default_model, update_model, ci_weight). -# The server starts with default_model; tests update weights to update_model. -# ci_weight controls how likely each pair is to be selected in CI runs. _ALL_MODEL_PAIRS: list[tuple[str, str, float]] = [ ( "black-forest-labs/FLUX.2-klein-base-4B", "black-forest-labs/FLUX.2-klein-4B", 5.0, ), - # ( - # "Qwen/Qwen-Image", - # "Qwen/Qwen-Image-2512", - # 1.0, # Qwen Image is large; run it less often in CI. - # ), + ( + "Qwen/Qwen-Image", + "Qwen/Qwen-Image-2512", + 1.0, # Qwen Image is large; run it less often in CI. 
+ ), ] @@ -174,6 +170,7 @@ def _select_model_pairs() -> list[tuple[str, str]]: _ACTIVE_MODEL_PAIRS = _select_model_pairs() +_PAIR_IDS = [p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS] @functools.lru_cache(maxsize=None) @@ -285,7 +282,72 @@ def _prepare_corrupted_model( ) -class TestUpdateWeightsFromDisk: +class _UpdateWeightsApiMixin: + def _get_base_url(self, ctx: ServerContext) -> str: + return f"http://localhost:{ctx.port}" + + def _update_weights( + self, + base_url: str, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + timeout: int = 300, + ) -> tuple[dict, int]: + payload = {"model_path": model_path, "flush_cache": flush_cache} + if target_modules is not None: + payload["target_modules"] = target_modules + response = requests.post( + f"{base_url}/update_weights_from_disk", + json=payload, + timeout=timeout, + ) + return response.json(), response.status_code + + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + payload = {} + if module_names is not None: + payload["module_names"] = module_names + response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, + ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() + + def _assert_server_matches_model_on_changed_modules( + self, + base_url: str, + base_model: str, + update_model: str, + expected_model: str, + ) -> None: + all_checksums = self._get_weights_checksum(base_url) + module_names = [k for k, v in all_checksums.items() if v != "not_found"] + changed_modules = _get_modules_with_different_checksums( + base_model, update_model, module_names + ) + if not changed_modules: + pytest.skip("No checksum-different shared modules in model pair") + for name in changed_modules: + server_cs = all_checksums.get(name) + expected_cs = 
_compute_checksum_from_disk(expected_model, name) + assert server_cs == expected_cs, ( + f"Checksum mismatch on '{name}'\n" + f" expected({expected_model}): {expected_cs}\n" + f" server: {server_cs}" + ) + + +class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): """Test suite for update_weights_from_disk API and corrupted-weight rollback. Uses a class-scoped server fixture so the server is torn down at class end, @@ -295,7 +357,7 @@ class TestUpdateWeightsFromDisk: @pytest.fixture( scope="class", params=_ACTIVE_MODEL_PAIRS, - ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], + ids=_PAIR_IDS, ) def diffusion_server_no_offload(self, request): """Start a diffusion server (no offload) for this test class. @@ -341,73 +403,6 @@ def corrupted_model_dir(self, diffusion_server_no_offload): yield tmpdir shutil.rmtree(tmpdir, ignore_errors=True) - def _get_base_url(self, ctx: ServerContext) -> str: - return f"http://localhost:{ctx.port}" - - def _update_weights( - self, - base_url: str, - model_path: str, - flush_cache: bool = True, - target_modules: list[str] | None = None, - timeout: int = 300, - ) -> tuple[dict, int]: - """Call update_weights_from_disk API.""" - payload = { - "model_path": model_path, - "flush_cache": flush_cache, - } - if target_modules is not None: - payload["target_modules"] = target_modules - - response = requests.post( - f"{base_url}/update_weights_from_disk", - json=payload, - timeout=timeout, - ) - return response.json(), response.status_code - - def _get_weights_checksum( - self, - base_url: str, - module_names: list[str] | None = None, - timeout: int = 300, - ) -> dict: - """Call get_weights_checksum API and return the checksum dict.""" - payload = {} - if module_names is not None: - payload["module_names"] = module_names - - response = requests.post( - f"{base_url}/get_weights_checksum", - json=payload, - timeout=timeout, - ) - assert ( - response.status_code == 200 - ), f"get_weights_checksum failed: {response.status_code} 
{response.text}" - return response.json() - - def _assert_server_checksums_match_base_disk( - self, base_url: str, default_model: str, update_model: str - ) -> None: - """Assert changed modules on server match base model disk (server healthy).""" - all_checksums = self._get_weights_checksum(base_url, module_names=None) - module_names = [k for k in all_checksums if all_checksums.get(k) != "not_found"] - changed_modules = _get_modules_with_different_checksums( - default_model, update_model, module_names - ) - if not changed_modules: - pytest.skip("No checksum-different shared modules in model pair") - - for name in changed_modules: - server_cs = all_checksums.get(name) - base_disk_cs = _compute_checksum_from_disk(default_model, name) - assert server_cs == base_disk_cs, ( - f"Server checksum for '{name}' should match base model disk (server healthy)\n" - f" base_disk: {base_disk_cs}\n server: {server_cs}" - ) - def test_update_weights_from_disk_default(self, diffusion_server_no_offload): """Base→instruct with flush_cache=True; verify before/after; rollback to base. 
@@ -440,17 +435,6 @@ def test_update_weights_from_disk_default(self, diffusion_server_no_offload): )["transformer"] instruct_disk = _compute_checksum_from_disk(update_model, "transformer") - print(f"\n{'='*60}") - print("Checksum test (base→instruct with flush_cache=True)") - print(f" before (gpu): {before_checksum}") - print(f" base (disk): {base_disk}") - print(f" after (gpu): {after_checksum}") - print(f" instruct (disk): {instruct_disk}") - print(f" before==base_disk: {before_checksum == base_disk}") - print(f" after==instruct_disk: {after_checksum == instruct_disk}") - print(f" before!=after: {before_checksum != after_checksum}") - print(f"{'='*60}") - assert before_checksum == base_disk, ( f"Before-update checksum should match base model disk\n" f" base_disk: {base_disk}\n before: {before_checksum}" @@ -580,8 +564,8 @@ def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): assert status_code == 400, f"Expected 400, got {status_code}" assert not result.get("success", True), "Should fail for nonexistent model" - self._assert_server_checksums_match_base_disk( - base_url, default_model, update_model + self._assert_server_matches_model_on_changed_modules( + base_url, default_model, update_model, default_model ) def test_update_weights_missing_model_path(self, diffusion_server_no_offload): @@ -598,8 +582,8 @@ def test_update_weights_missing_model_path(self, diffusion_server_no_offload): ) assert response.status_code == 400, f"Expected 400, got {response.status_code}" - self._assert_server_checksums_match_base_disk( - base_url, default_model, update_model + self._assert_server_matches_model_on_changed_modules( + base_url, default_model, update_model, default_model ) def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): @@ -620,8 +604,8 @@ def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): assert status_code == 400, f"Expected 400, got {status_code}" assert not result.get("success", True), 
"Should fail for nonexistent module" assert "not found in pipeline" in result.get("message", "") - self._assert_server_checksums_match_base_disk( - base_url, default_model, update_model + self._assert_server_matches_model_on_changed_modules( + base_url, default_model, update_model, default_model ) def test_corrupted_weights_rollback( @@ -696,19 +680,6 @@ def test_corrupted_weights_rollback( ) logger.info(f"Post-rollback checksums: {post_rollback_checksums}") - print(f"\n{'='*80}") - print("Corrupted-weight rollback test (transformer, vae)") - for module in sorted(update_checksums.keys()): - update_cs = update_checksums.get(module, "N/A") - rollback_cs = post_rollback_checksums.get(module, "N/A") - base_cs = base_checksums.get(module, "N/A") - match = "OK" if update_cs == rollback_cs else "MISMATCH" - print(f" [{match}] {module}") - print(f" base: {base_cs}") - print(f" update: {update_cs}") - print(f" rollback: {rollback_cs}") - print(f"{'='*80}") - for module in update_checksums: assert post_rollback_checksums.get(module) == update_checksums[module], ( f"Module '{module}' checksum mismatch after rollback\n" @@ -726,25 +697,17 @@ def test_corrupted_weights_rollback( assert status_code == 200 and result.get( "success" ), f"Failed to reset to base after rollback: {result.get('message')}" - self._assert_server_checksums_match_base_disk( - base_url, default_model, update_model + self._assert_server_matches_model_on_changed_modules( + base_url, default_model, update_model, default_model ) -class TestUpdateWeightsFromDiskWithOffload: +class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): """Test update_weights_from_disk with layerwise offload enabled.""" - @pytest.fixture( - scope="class", - params=_ACTIVE_MODEL_PAIRS, - ids=[p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS], - ) + @pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS) def diffusion_server_with_offload(self, request): - """Start a diffusion server with layerwise offload 
enabled. - - Disk checksums are already cached by diffusion_server_no_offload - (which runs first), so no background precomputation is needed here. - """ + """Start a diffusion server with layerwise offload enabled.""" default_model, update_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) @@ -763,61 +726,6 @@ def diffusion_server_with_offload(self, request): finally: ctx.cleanup() - def _get_base_url(self, ctx: ServerContext) -> str: - return f"http://localhost:{ctx.port}" - - def _update_weights( - self, base_url: str, model_path: str, **kwargs - ) -> tuple[dict, int]: - payload = {"model_path": model_path, **kwargs} - response = requests.post( - f"{base_url}/update_weights_from_disk", - json=payload, - timeout=kwargs.get("timeout", 300), - ) - return response.json(), response.status_code - - def _get_weights_checksum( - self, - base_url: str, - module_names: list[str] | None = None, - timeout: int = 300, - ) -> dict: - """Call get_weights_checksum API and return the checksum dict.""" - payload = {} - if module_names is not None: - payload["module_names"] = module_names - - response = requests.post( - f"{base_url}/get_weights_checksum", - json=payload, - timeout=timeout, - ) - assert ( - response.status_code == 200 - ), f"get_weights_checksum failed: {response.status_code} {response.text}" - return response.json() - - def _assert_server_checksums_match_instruct_disk( - self, base_url: str, default_model: str, update_model: str - ) -> None: - """Assert changed modules on server match update model disk (server healthy).""" - all_checksums = self._get_weights_checksum(base_url, module_names=None) - module_names = [k for k in all_checksums if all_checksums.get(k) != "not_found"] - changed_modules = _get_modules_with_different_checksums( - default_model, update_model, module_names - ) - if not changed_modules: - pytest.skip("No checksum-different shared modules in model pair") - - for name in 
changed_modules: - server_cs = all_checksums.get(name) - instruct_disk_cs = _compute_checksum_from_disk(update_model, name) - assert server_cs == instruct_disk_cs, ( - f"Server checksum for '{name}' should match instruct model disk\n" - f" instruct_disk: {instruct_disk_cs}\n server: {server_cs}" - ) - def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): """Offload: base→instruct update; no Shape mismatch; checksums == instruct disk.""" ctx, default_model, update_model = diffusion_server_with_offload @@ -830,8 +738,8 @@ def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload message = result.get("message", "") assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" - self._assert_server_checksums_match_instruct_disk( - base_url, default_model, update_model + self._assert_server_matches_model_on_changed_modules( + base_url, default_model, update_model, update_model ) From 14e69ec8f07a8180c7feb449ce5b53e71936fe61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sat, 14 Feb 2026 11:44:49 -0800 Subject: [PATCH 16/30] Update docstring for GetWeightsChecksumReqInput --- .../runtime/entrypoints/post_training/io_struct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py index 749a39e817f8..bda72df12a8f 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py @@ -14,6 +14,6 @@ class UpdateWeightFromDiskReqInput: @dataclass class GetWeightsChecksumReqInput: - """Request to compute SHA-256 checksum of loaded module weights.""" + """Compute SHA-256 checksum of loaded module weights for verification.""" module_names: list[str] | None = None From b64b185ff05d678920179e31c521f7aa33b5e1bc Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sat, 14 Feb 2026 11:52:10 -0800 Subject: [PATCH 17/30] Refine docstring for weight checksum verification Updated docstring for compute_checksum function to clarify its purpose and usage. --- .../multimodal_gen/runtime/loader/weight_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index 6b7f24bafdbe..517cecb37a25 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -343,14 +343,9 @@ def compute_weights_checksum( ) -> str: """Compute a SHA-256 checksum for a set of (name, tensor) pairs. - Helper function for verifying the correctness of weight refitting - (update_weights_from_disk). After a refit, callers can compare the - checksum of the in-GPU model weights against the checksum of the - on-disk tensors to confirm they match exactly. - - Parameters are sorted by name so the digest is deterministic - regardless of iteration order. Raw bytes are hashed directly - (no dtype conversion) for speed and fidelity. + Used to verify the correctness of weight refitting. After a refit, + compare the checksum of the in-GPU model weights against the checksum + of the on-disk tensors or the tensors in the training engine. """ hasher = hashlib.sha256() for name, tensor in sorted(named_params, key=lambda x: x[0]): From d68da5ceb43813dd34ffbfdf257922235c11d51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sat, 14 Feb 2026 12:02:13 -0800 Subject: [PATCH 18/30] Simplify comments in layerwise offload method Refactor documentation for clarity and conciseness. 
--- .../runtime/utils/layerwise_offload.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 089cc608ab69..1ab7869e57e5 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -285,13 +285,11 @@ def update_cpu_weights( When layerwise offload (--dit-layerwise-offload) is enabled, the offload manager replaces GPU parameters with small torch.empty((1,)) placeholders while real weights live in consolidated pinned CPU - buffers. A naive param.data.copy_() would fail with a shape - mismatch. Instead, this method writes new weights directly into - the CPU buffers, bypassing the placeholders entirely. For any - layer that happens to be resident on GPU at update time, the live - GPU tensor is also updated so the change takes effect immediately. - This requires no extra GPU memory and does not disturb the offload - state. + buffers. + + The refit process writes new weights directly into the CPU buffers, + bypassing the placeholders. For any layer that happens to be resident + on the GPU at update time, the live GPU tensor is also updated. Args: weight_dict: Mapping of parameter name to new weight tensor. @@ -301,7 +299,7 @@ def update_cpu_weights( Raises: ValueError: If a weight's shape does not match the recorded - metadata (i.e. the real shape, not the placeholder shape). + metadata (i.e., the real shape, not the placeholder shape). """ if not self.enabled: return None From d3728cb385d7e3868c88848f54b6533c2f096376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sat, 14 Feb 2026 12:06:46 -0800 Subject: [PATCH 19/30] Improve documentation for iter_materialized_weights Clarify the function's behavior regarding offloaded layers and ensure that callers see real tensors. 
--- .../sglang/multimodal_gen/runtime/utils/layerwise_offload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 1ab7869e57e5..7ab3bc5da887 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -469,9 +469,9 @@ def iter_materialized_weights(module: torch.nn.Module): """Yield (name, tensor) pairs with materialized weights, even under offload. When layerwise offload is active, module.named_parameters() returns - (1,) placeholders for offloaded layers. This helper reads the + (1,) placeholders for offloaded layers. This function reads the actual data from the offload manager's CPU buffers and chains it with - the non-offloaded parameters so callers always see real tensors. + the non-offloaded parameters. """ offload_managers: list = [] if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: From f831d946b671a359464afcf1a2700defbf81b8e3 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sat, 14 Feb 2026 18:06:45 -0800 Subject: [PATCH 20/30] fix lint --- python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py index 7ab3bc5da887..0bdbd841db2f 100644 --- a/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py @@ -286,7 +286,7 @@ def update_cpu_weights( offload manager replaces GPU parameters with small torch.empty((1,)) placeholders while real weights live in consolidated pinned CPU buffers. - + The refit process writes new weights directly into the CPU buffers, bypassing the placeholders. 
For any layer that happens to be resident on the GPU at update time, the live GPU tensor is also updated. From bd235198db10ebc5ca81ee410cfc5ece80c2ece6 Mon Sep 17 00:00:00 2001 From: Mengyang Liu Date: Sun, 15 Feb 2026 08:19:58 +0000 Subject: [PATCH 21/30] Refactor update_weights_from_disk tests --- .../multimodal_gen/runtime/loader/utils.py | 7 +- .../runtime/loader/weights_updater.py | 151 ++--- .../server/test_update_weights_from_disk.py | 617 ++++++++---------- 3 files changed, 337 insertions(+), 438 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/utils.py b/python/sglang/multimodal_gen/runtime/loader/utils.py index 725cf32265ab..39ca22a34125 100644 --- a/python/sglang/multimodal_gen/runtime/loader/utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/utils.py @@ -8,6 +8,7 @@ import re from collections import defaultdict from collections.abc import Callable, Iterator +from pathlib import Path from typing import Any, Dict, Type import torch @@ -145,14 +146,14 @@ def _list_safetensors_files(model_path: str) -> list[str]: return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors"))) -def find_weights_dir(local_path: str, module_name: str) -> str | None: +def find_weights_dir(local_path: str, module_name: str) -> Path | None: """Locate the safetensors directory for module_name under local_path. Diffusion models store weights in per-module subdirectories (e.g. transformer/, vae/, text_encoder/). 
""" - dir_path = os.path.join(local_path, module_name) - if os.path.exists(dir_path): + dir_path = Path(local_path) / module_name + if dir_path.exists(): return dir_path return None diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index f4422b74bc2f..440a96e32e5a 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -87,6 +87,77 @@ def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]: return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)} +def _get_weights_iter(weights_dir: str): + """Return a (name, tensor) iterator over safetensors in weights_dir.""" + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {weights_dir}") + return safetensors_weights_iterator(safetensors_files) + + +def _validate_weight_files( + local_model_path: str, + modules_to_update: list[tuple[str, torch.nn.Module]], +) -> tuple[dict[str, str], list[str]]: + """Check that every module has a weights directory with safetensors files. + + Returns: + (weights_map, missing) where weights_map maps module name to its + weights directory and missing lists modules without weight files. + """ + weights_map: dict[str, str] = {} + missing: list[str] = [] + for module_name, _ in modules_to_update: + weights_dir = find_weights_dir(local_model_path, module_name) + if weights_dir and _list_safetensors_files(weights_dir): + weights_map[module_name] = weights_dir + else: + missing.append(module_name) + return weights_map, missing + + +def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: + """Load weights into a module, handling offload-managed parameters. + + For offloaded modules, updates CPU buffers directly via + update_cpu_weights(); non-offloaded parameters use in-place copy. 
+ """ + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + + if offload_managers: + weight_dict = dict(weights_iter) + offloaded_names: set[str] = set() + for manager in offload_managers: + offloaded_names.update(manager.update_cpu_weights(weight_dict)) + remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) + load_weights_into_model(remaining, dict(module.named_parameters())) + else: + load_weights_into_model(weights_iter, dict(module.named_parameters())) + + +def load_weights_into_model(weights_iter, model_params: dict) -> None: + """Copy weights from weights_iter into model_params in-place.""" + for name, loaded_weight in weights_iter: + if name not in model_params: + continue + param = model_params[name] + if param.shape != loaded_weight.shape: + raise ValueError( + f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" + ) + if isinstance(param, DTensor): + distributed_weight = distribute_tensor( + loaded_weight.to(param.dtype), + param.device_mesh, + param.placements, + ) + param._local_tensor.copy_(distributed_weight._local_tensor) + else: + param.data.copy_(loaded_weight.to(param.dtype)) + + class WeightsUpdater: """In-place weight updates for diffusion pipeline modules. 
@@ -168,10 +239,6 @@ def update_weights_from_disk( logger.info(message) return success, message - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - def _collect_modules( self, target_modules: list[str] | None ) -> list[tuple[str, torch.nn.Module]]: @@ -244,79 +311,3 @@ def _rollback(self, updated_modules: list[str]) -> None: continue weights_iter = _get_weights_iter(weights_dir) _load_weights_into_module(module, weights_iter) - - -# --------------------------------------------------------------------------- -# Module-level utility functions -# --------------------------------------------------------------------------- - - -def _get_weights_iter(weights_dir: str): - """Return a (name, tensor) iterator over safetensors in weights_dir.""" - safetensors_files = _list_safetensors_files(weights_dir) - if not safetensors_files: - raise FileNotFoundError(f"No safetensors files found in {weights_dir}") - return safetensors_weights_iterator(safetensors_files) - - -def _validate_weight_files( - local_model_path: str, - modules_to_update: list[tuple[str, torch.nn.Module]], -) -> tuple[dict[str, str], list[str]]: - """Check that every module has a weights directory with safetensors files. - - Returns: - (weights_map, missing) where weights_map maps module name to its - weights directory and missing lists modules without weight files. - """ - weights_map: dict[str, str] = {} - missing: list[str] = [] - for module_name, _ in modules_to_update: - weights_dir = find_weights_dir(local_model_path, module_name) - if weights_dir and _list_safetensors_files(weights_dir): - weights_map[module_name] = weights_dir - else: - missing.append(module_name) - return weights_map, missing - - -def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: - """Load weights into a module, handling offload-managed parameters. 
- - For offloaded modules, updates CPU buffers directly via - update_cpu_weights(); non-offloaded parameters use in-place copy. - """ - offload_managers: list = [] - if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: - offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] - - if offload_managers: - weight_dict = dict(weights_iter) - offloaded_names: set[str] = set() - for manager in offload_managers: - offloaded_names.update(manager.update_cpu_weights(weight_dict)) - remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) - load_weights_into_model(remaining, dict(module.named_parameters())) - else: - load_weights_into_model(weights_iter, dict(module.named_parameters())) - - -def load_weights_into_model(weights_iter, model_params: dict) -> None: - """Copy weights from weights_iter into model_params in-place.""" - for name, loaded_weight in weights_iter: - if name not in model_params: - continue - param = model_params[name] - if param.shape != loaded_weight.shape: - raise ValueError( - f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" - ) - if isinstance(param, DTensor): - distributed_weight = distribute_tensor( - loaded_weight.to(param.dtype), - param.device_mesh, - param.placements, - ) - param._local_tensor.copy_(distributed_weight._local_tensor) - else: - param.data.copy_(loaded_weight.to(param.dtype)) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index a7eecf80100f..15a4e051a7e0 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -8,14 +8,22 @@ Menyang Liu, https://github.com/dreamyang-liu Chenyang Zhao, https://github.com/zhaochenyang20 -We use two model pairs for testing (base model / instruct model pairs): +We use two 
model pairs for testing (base model / source model pairs): - FLUX.2-klein-base-4B / FLUX.2-klein-4B - Qwen/Qwen-Image / Qwen/Qwen-Image-2512 -These model pairs share the same architecture, but not every module is -guaranteed to have different weights between base and update models. -Some modules can be identical across the pair. +These model pairs share the same architecture but differ in transformer +weights (all other modules — vae, text_encoder, … — are identical). + +The source model is not used directly by any test. Instead, at fixture +setup time we clone it and perturb its vae weights, producing a synthetic +perturbed checkpoint (perturbed_vae_model_dir) where both transformer AND +vae differ from the base model. This perturbed checkpoint is used in all +tests, giving us two modules with known different checksums to verify. + +NOTE: Disk-vs-server checksum verification currently ONLY covers transformer. +Other modules have weight-name remapping / QKV merge mismatches to resolve first. ============================================================================= @@ -39,56 +47,48 @@ All tests share one class-scoped server (same process, same in-memory weights). Tests that require "base model then update" should be explicitly reset to default_model first so behavior is order-independent and updates are real - (base→update), not no-ops (update→update). +(base→perturbed), not no-ops (perturbed→perturbed). • test_update_weights_from_disk_default - base -> instruct with flush_cache=True. Verifies: - (1) before-update checksum == base model disk checksum; - (2) after-update checksum == instruct model disk checksum; - (3) before != after (update actually changed weights). - rollback to base model after update. + base -> perturbed with flush_cache=True. + Verifies after-update checksum == perturbed checkpoint disk checksum + (implicitly confirms weights changed, since fixture guarantees + base ≠ perturbed). 
• test_update_weights_specific_modules - base -> instruct with flush_cache=False: randomly selects target_modules, - updates only those from base to instruct model. Verifies: - (1) updated modules' checksums match instruct model disk checksum; - (2) non-updated modules' checksums are unchanged (before == after == disk). - rollback to base model after update. + base -> perturbed with flush_cache=False. Randomly selects one module + from _DIFFERING_MODULES (modules whose weights differ between base and + perturbed checkpoint) as target_modules, updates only that module. Verifies: + (1) targeted module's in-memory checksum changed; + (2) non-targeted modules' in-memory checksums are unchanged. • test_update_weights_nonexistent_model model_path set to a non-existent path; must fail (400, success=False). - Ensure server is healthy after inaccurate update and server's checksums - equals to base model's disk checksums. + Ensure server is healthy after failed update and server's checksums + equal base model's disk checksums. • test_update_weights_missing_model_path Request body empty (no model_path); must fail (400, success=False). - Ensure server is healthy after inaccurate update and server's checksums - equals to base model's disk checksums. + Ensure server is healthy after failed update and server's checksums + equal base model's disk checksums. • test_update_weights_nonexistent_module target_modules=["nonexistent_module"]; must fail (400, success=False). - Verify server is healthy after inaccurate update and server's checksums - equals to base model's disk checksums. + Verify server is healthy after failed update and server's checksums + equal base model's disk checksums. • test_corrupted_weights_rollback - Verify base -> instruct rollback after loading corrupted instruct model. - Builds a corrupted model directory by copying the instruct model and - truncating the vae safetensors. Updates with target_modules=["transformer", - "vae"]. 
The transformer updates successfully first; the corrupted vae module - then fails during safetensors validation, triggering a rollback that restores - the transformer to its previous weights. - - Ensure server is healthy after rollback and server's checksums equals to - base model's disk checksums. + All-or-nothing rollback: base→perturbed succeeds, then perturbed→corrupted + fails (truncated vae), server rolls back to the perturbed checkpoint. ----------------------------------------------------------------------------- @@ -102,9 +102,8 @@ • test_update_weights_with_offload_enabled - Server with --dit-layerwise-offload (base). Update to instruct; must succeed - (200, success=True), message must not contain "Shape mismatch". Assert server - checksums == instruct model disk checksums (server healthy). + Server with --dit-layerwise-offload (base). Load perturbed checkpoint; + must succeed (200, success=True), no "Shape mismatch". Checksums match disk. """ from __future__ import annotations @@ -115,6 +114,8 @@ import shutil import tempfile import threading +from collections.abc import Callable +from enum import StrEnum import pytest import requests @@ -137,6 +138,18 @@ logger = init_logger(__name__) + +class _Module(StrEnum): + """Updatable pipeline module names.""" + + TRANSFORMER = "transformer" + VAE = "vae" + + +# Modules whose weights differ between the base model and the synthetic +# perturbed checkpoint +_DIFFERING_MODULES: list[str] = [_Module.TRANSFORMER, _Module.VAE] + _ALL_MODEL_PAIRS: list[tuple[str, str, float]] = [ ( "black-forest-labs/FLUX.2-klein-base-4B", @@ -152,7 +165,7 @@ def _select_model_pairs() -> list[tuple[str, str]]: - """Return the (default, update) model pairs to test. + """Return the (default, source) model pairs to test. When SGLANG_TEST_DIFFUSION_MODEL / SGLANG_TEST_UPDATE_MODEL env vars are set, use them as a single explicit pair. 
Otherwise, run both @@ -193,48 +206,17 @@ def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) -def _get_modules_with_weights_on_disk( - model_path: str, module_names: list[str] -) -> list[str]: - """Return module names that have safetensors on disk for the given model.""" - local_path = maybe_download_model(model_path) - result = [] - for name in module_names: - weights_dir = find_weights_dir(local_path, name) - if weights_dir and _list_safetensors_files(weights_dir): - result.append(name) - return result - - -def _get_modules_with_different_checksums( - base_model: str, update_model: str, module_names: list[str] -) -> list[str]: - """Return shared modules whose disk checksums differ across model pair.""" - base_modules = set(_get_modules_with_weights_on_disk(base_model, module_names)) - update_modules = set(_get_modules_with_weights_on_disk(update_model, module_names)) - shared_modules = sorted(base_modules & update_modules) - - changed_modules = [] - for name in shared_modules: - base_cs = _compute_checksum_from_disk(base_model, name) - update_cs = _compute_checksum_from_disk(update_model, name) - if base_cs != update_cs: - changed_modules.append(name) - return changed_modules - - -def _prepare_corrupted_model( - src_model: str, dst_model: str, corrupt_module: str +def _clone_model_with_modified_module( + src_model: str, + dst_model: str, + target_module: str, + transform_safetensor: Callable[[str, str], None], ) -> None: - """Build a corrupted model directory from src_model. + """Clone a model directory via symlinks, applying transform to one module. - Uses symlinks for everything except the corrupt_module directory to - save disk space and time. Only the corrupt_module's safetensors are - physically copied and then truncated so that safetensors_weights_iterator - detects corruption at load time, triggering a rollback. 
- - Must be called before every test attempt because the server deletes - corrupted files on detection. + Everything is symlinked except the target module's first .safetensors + file, which is transformed (causing a checksum difference or corruption); + remaining files are symlinked for speed. """ # Symlink root-level files (model_index.json, etc.). for fname in os.listdir(src_model): @@ -249,37 +231,52 @@ def _prepare_corrupted_model( if not os.path.isdir(src_dir): continue - # Non-corrupted modules: symlink the entire directory. - if module_dir != corrupt_module: + if module_dir != target_module: if not os.path.exists(dst_dir): os.symlink(src_dir, dst_dir) continue - # Corrupted module: create a real directory, symlink non-safetensors - # files, and copy + truncate safetensors files. os.makedirs(dst_dir, exist_ok=True) - for fname in os.listdir(src_dir): + transformed = False + for fname in sorted(os.listdir(src_dir)): src_file = os.path.join(src_dir, fname) dst_file = os.path.join(dst_dir, fname) if not os.path.isfile(src_file): continue - if not fname.endswith(".safetensors"): + if not fname.endswith(".safetensors") or transformed: if not os.path.exists(dst_file): os.symlink(src_file, dst_file) continue - # Copy safetensors then truncate to corrupt it. 
- shutil.copy2(src_file, dst_file) - size = os.path.getsize(dst_file) - with open(dst_file, "r+b") as f: - f.truncate(size - 1000) - logger.info( - "Created corrupted safetensors: %s (%d -> %d bytes)", - dst_file, - size, - size - 1000, - ) + transform_safetensor(src_file, dst_file) + transformed = True + + +def _truncate_safetensor(src_file: str, dst_file: str) -> None: + """Copy then truncate — produces an invalid safetensors that triggers rollback.""" + shutil.copy2(src_file, dst_file) + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 1000) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 1000, + ) + + +def _perturb_safetensor(src_file: str, dst_file: str) -> None: + """Load, add small perturbation to floating-point tensors, and save.""" + from safetensors.torch import load_file, save_file + + tensors = load_file(src_file) + perturbed = { + k: (t + 0.01 if t.is_floating_point() else t) for k, t in tensors.items() + } + save_file(perturbed, dst_file) + logger.info("Created perturbed safetensors: %s", dst_file) class _UpdateWeightsApiMixin: @@ -323,28 +320,31 @@ def _get_weights_checksum( ), f"get_weights_checksum failed: {response.status_code} {response.text}" return response.json() - def _assert_server_matches_model_on_changed_modules( + def _assert_server_matches_model( self, base_url: str, - base_model: str, - update_model: str, expected_model: str, ) -> None: - all_checksums = self._get_weights_checksum(base_url) - module_names = [k for k, v in all_checksums.items() if v != "not_found"] - changed_modules = _get_modules_with_different_checksums( - base_model, update_model, module_names + """Assert the server's transformer checksum matches expected_model on disk. + + Only the transformer is verified because weight-name remapping and + QKV merge during model loading cause in-memory parameter names/shapes + to diverge from on-disk safetensors for other modules (e.g. 
vae), + making their checksums incomparable. + + TODO: Extend to verify all modules once these + discrepancies are resolved. + """ + server_checksums = self._get_weights_checksum( + base_url, module_names=[_Module.TRANSFORMER] + ) + expected_cs = _compute_checksum_from_disk(expected_model, _Module.TRANSFORMER) + server_cs = server_checksums.get(_Module.TRANSFORMER) + assert server_cs == expected_cs, ( + f"Checksum mismatch on '{_Module.TRANSFORMER}'\n" + f" expected({expected_model}): {expected_cs}\n" + f" server: {server_cs}" ) - if not changed_modules: - pytest.skip("No checksum-different shared modules in model pair") - for name in changed_modules: - server_cs = all_checksums.get(name) - expected_cs = _compute_checksum_from_disk(expected_model, name) - assert server_cs == expected_cs, ( - f"Checksum mismatch on '{name}'\n" - f" expected({expected_model}): {expected_cs}\n" - f" server: {server_cs}" - ) class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): @@ -362,11 +362,17 @@ class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): def diffusion_server_no_offload(self, request): """Start a diffusion server (no offload) for this test class. - Precomputes disk checksums for the update model in background threads - while the server is starting, so they are already cached (via lru_cache) - by the time tests need them. + Builds two synthetic checkpoints from the source model: + - perturbed_vae_model_dir: source model with perturbed vae (both + transformer and vae differ from base). + - corrupted_vae_model_dir: base model with truncated vae — triggers + load failure for rollback testing. + + Checksum cache warmup and synthetic checkpoints building run in background + threads while the server boots, so everything is ready by the time + tests start. 
""" - default_model, update_model = request.param + default_model, source_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) @@ -377,180 +383,129 @@ def diffusion_server_no_offload(self, request): extra_args="--num-gpus 1", ) - # Warm the lru_cache while the server boots (disk I/O is independent). - checksum_threads = [ + # Ensure models are local before spawning threads that need the paths. + local_default = maybe_download_model(default_model) + local_source = maybe_download_model(source_model) + + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + corrupted_vae_model_dir = tempfile.mkdtemp(prefix="sglang_corrupted_") + + # Run all disk I/O in background while the server boots. + bg_threads = [ threading.Thread( - target=_compute_checksum_from_disk, args=(update_model, module) + target=_compute_checksum_from_disk, args=(default_model, module) ) - for module in ("transformer", "vae") + for module in _DIFFERING_MODULES + ] + [ + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _Module.VAE, + _perturb_safetensor, + ), + ), + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_default, + corrupted_vae_model_dir, + _Module.VAE, + _truncate_safetensor, + ), + ), ] - for t in checksum_threads: + for t in bg_threads: t.start() ctx = manager.start() - for t in checksum_threads: + for t in bg_threads: t.join() + # Sanity: all _DIFFERING_MODULES should differ between base and perturbed. 
+ for module in _DIFFERING_MODULES: + assert _compute_checksum_from_disk( + default_model, module + ) != _compute_checksum_from_disk(perturbed_vae_model_dir, module), ( + f"Assumption violated: {module} should differ between " + f"{default_model} and {perturbed_vae_model_dir}" + ) + try: - yield ctx, default_model, update_model + yield ctx, default_model, perturbed_vae_model_dir, corrupted_vae_model_dir finally: ctx.cleanup() - - @pytest.fixture(scope="class") - def corrupted_model_dir(self, diffusion_server_no_offload): - """Create a separate temporary directory per parametrized model pair.""" - tmpdir = tempfile.mkdtemp(prefix="sglang_corrupted_model_") - yield tmpdir - shutil.rmtree(tmpdir, ignore_errors=True) + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) + shutil.rmtree(corrupted_vae_model_dir, ignore_errors=True) def test_update_weights_from_disk_default(self, diffusion_server_no_offload): - """Base→instruct with flush_cache=True; verify before/after; rollback to base. - - Resets to base, records before checksum. Updates to instruct with - flush_cache=True. Asserts: (1) before == base disk; (2) after == instruct - disk; (3) before != after. Then rollback to base so server ends on base. - """ - ctx, default_model, update_model = diffusion_server_no_offload + """Default update (target_modules=None, flush_cache=True): all changed modules updated.""" + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload base_url = self._get_base_url(ctx) - # Reset to base so we have a real base→instruct. 
self._update_weights(base_url, default_model) - before_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] - base_disk = _compute_checksum_from_disk(default_model, "transformer") - - result, status_code = self._update_weights( - base_url, - update_model, - flush_cache=True, - ) - assert status_code == 200 and result.get( - "success" - ), f"Update failed: {result.get('message')}" - - after_checksum = self._get_weights_checksum( - base_url, module_names=["transformer"] - )["transformer"] - instruct_disk = _compute_checksum_from_disk(update_model, "transformer") - - assert before_checksum == base_disk, ( - f"Before-update checksum should match base model disk\n" - f" base_disk: {base_disk}\n before: {before_checksum}" - ) - assert after_checksum == instruct_disk, ( - f"After-update checksum should match instruct model disk\n" - f" instruct_disk: {instruct_disk}\n after: {after_checksum}" - ) - assert ( - before_checksum != after_checksum - ), "Before and after checksums should differ (update changed weights)" + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200 + assert result.get("success", False), f"Update failed: {result.get('message')}" - # Rollback to base so server ends in known state. - self._update_weights(base_url, default_model) + self._assert_server_matches_model(base_url, perturbed_model_dir) def test_update_weights_specific_modules(self, diffusion_server_no_offload): - """Partial update base→instruct with flush_cache=False; verify checksums; rollback to base. - - Randomly picks target_modules, updates only those to instruct with - flush_cache=False. Asserts: - (1) for modules whose base/update disk checksums differ, updated modules - match update-model disk and actually change; - (2) for modules with identical base/update checksums, updating them keeps - checksums unchanged; - (3) non-updated modules remain unchanged (before == after). 
- Then rollback to base. + """Verify target_modules filtering: only the specified module is updated. + + The perturbed checkpoint has different weights for both transformer and + vae. This test randomly picks ONE of them as target_modules and loads + from the perturbed checkpoint. Assertions: + (1) the targeted module's in-memory checksum changed (before != after); + (2) every non-targeted module's in-memory checksum is unchanged, + proving the server only touched what was requested. """ - ctx, default_model, update_model = diffusion_server_no_offload + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload base_url = self._get_base_url(ctx) - # Reset to base model so we start from a known state. + # Reset server to default_model. self._update_weights(base_url, default_model) - - # All pipeline module names (from server). - all_checksums = self._get_weights_checksum(base_url, module_names=None) - all_module_names = [k for k in all_checksums if all_checksums[k] != "not_found"] - if not all_module_names: - pytest.skip("No updatable modules reported by server") - - # Only consider modules that have weights on disk in both models. - base_modules = set( - _get_modules_with_weights_on_disk(default_model, all_module_names) - ) - update_modules = set( - _get_modules_with_weights_on_disk(update_model, all_module_names) - ) - candidates = sorted(base_modules & update_modules) - if not candidates: - pytest.skip("No shared modules with weights on disk in model pair") - - changed_modules = _get_modules_with_different_checksums( - default_model, update_model, candidates - ) - if not changed_modules: - pytest.skip("No checksum-different shared modules in model pair") - - # Random non-empty subset (fixed seed) that always includes one changed module. 
- random.seed(42) - must_include = random.choice(changed_modules) - optional = [m for m in candidates if m != must_include] - k_extra = random.randint(0, len(optional)) - target_modules = [must_include] + random.sample(optional, k_extra) - target_set = set(target_modules) - changed_set = set(changed_modules) - logger.info( - "Partial update test (flush_cache=False): target_modules=%s (checksum-different modules: %s)", - target_modules, - changed_modules, + before_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES ) - before_checksums = self._get_weights_checksum(base_url, module_names=None) - + target_modules = [random.choice(_DIFFERING_MODULES)] result, status_code = self._update_weights( base_url, - update_model, + perturbed_model_dir, target_modules=target_modules, flush_cache=False, ) assert status_code == 200, f"Update failed: {result}" assert result.get("success", False), f"Update failed: {result.get('message')}" - after_checksums = self._get_weights_checksum(base_url, module_names=None) - - for name in all_module_names: - if name in target_set: - if name in changed_set: - disk_cs = _compute_checksum_from_disk(update_model, name) - assert after_checksums.get(name) == disk_cs, ( - f"Updated module '{name}': checksum should match update model disk\n" - f" disk: {disk_cs}\n gpu: {after_checksums.get(name)}" - ) - assert after_checksums.get(name) != before_checksums.get(name), ( - f"Updated module '{name}' should change checksum (base != update)\n" - f" before: {before_checksums.get(name)}\n" - f" after: {after_checksums.get(name)}" - ) - else: - assert after_checksums.get(name) == before_checksums.get(name), ( - f"Updated module '{name}' has identical base/update disk checksum, " - "so it should remain unchanged\n" - f" before: {before_checksums.get(name)}\n" - f" after: {after_checksums.get(name)}" - ) - else: - assert after_checksums.get(name) == before_checksums.get(name), ( - f"Non-updated module '{name}': checksum must be 
unchanged\n" - f" before: {before_checksums.get(name)}\n" - f" after: {after_checksums.get(name)}" - ) - - # Rollback to base so server ends in known state. - self._update_weights(base_url, default_model) + after_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES + ) + + # Targeted module should have changed. + for name in target_modules: + assert after_checksums.get(name) != before_checksums.get(name), ( + f"Targeted module '{name}' checksum should change after update\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) + + # Non-targeted modules should be unchanged. + for name, cs in after_checksums.items(): + if name in target_modules or cs == "not_found": + continue + assert cs == before_checksums.get(name), ( + f"Non-targeted module '{name}' should be unchanged\n" + f" before: {before_checksums.get(name)}\n" + f" after: {cs}" + ) def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): """Nonexistent model path must fail (400). Server healthy, checksums == base disk.""" - ctx, default_model, update_model = diffusion_server_no_offload + ctx, default_model, _, _ = diffusion_server_no_offload base_url = self._get_base_url(ctx) self._update_weights(base_url, default_model) @@ -564,13 +519,11 @@ def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): assert status_code == 400, f"Expected 400, got {status_code}" assert not result.get("success", True), "Should fail for nonexistent model" - self._assert_server_matches_model_on_changed_modules( - base_url, default_model, update_model, default_model - ) + self._assert_server_matches_model(base_url, default_model) def test_update_weights_missing_model_path(self, diffusion_server_no_offload): """Request without model_path must fail (400). 
Server healthy, checksums == base disk.""" - ctx, default_model, update_model = diffusion_server_no_offload + ctx, default_model, _, _ = diffusion_server_no_offload base_url = self._get_base_url(ctx) self._update_weights(base_url, default_model) @@ -582,20 +535,18 @@ def test_update_weights_missing_model_path(self, diffusion_server_no_offload): ) assert response.status_code == 400, f"Expected 400, got {response.status_code}" - self._assert_server_matches_model_on_changed_modules( - base_url, default_model, update_model, default_model - ) + self._assert_server_matches_model(base_url, default_model) def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): """Nonexistent module must fail (400). Server healthy, checksums == base disk.""" - ctx, default_model, update_model = diffusion_server_no_offload + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload base_url = self._get_base_url(ctx) self._update_weights(base_url, default_model) result, status_code = self._update_weights( base_url, - update_model, + perturbed_model_dir, target_modules=["nonexistent_module"], timeout=60, ) @@ -604,102 +555,40 @@ def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): assert status_code == 400, f"Expected 400, got {status_code}" assert not result.get("success", True), "Should fail for nonexistent module" assert "not found in pipeline" in result.get("message", "") - self._assert_server_matches_model_on_changed_modules( - base_url, default_model, update_model, default_model - ) - - def test_corrupted_weights_rollback( - self, - diffusion_server_no_offload, - corrupted_model_dir: str, - ): - """Base→instruct then load corrupted instruct; verify rollback. - - Updates to instruct, then attempts load from corrupted instruct dir - (vae safetensors truncated). Rollback restores to instruct state. - Ensures server healthy: reset to base and assert checksums == base disk. 
+ self._assert_server_matches_model(base_url, default_model) + + def test_corrupted_weights_rollback(self, diffusion_server_no_offload): + """Verify all-or-nothing rollback on corrupted weights. + + Steps: + 1. base → perturbed (succeeds, server now on perturbed checkpoint). + 2. perturbed → corrupted with target_modules=_DIFFERING_MODULES. + The corrupted checkpoint has a truncated vae safetensors file. + Transformer loads first (succeeds), then vae fails during + safetensors parsing, triggering rollback of both modules. + 3. Assert the server rolled back to the perturbed checkpoint, not base. """ - ctx, default_model, update_model = diffusion_server_no_offload - base_url = self._get_base_url(ctx) - rollback_modules = ["transformer", "vae"] - - # --- Step 0: Reset to default model --- - # Previous tests may have left the server on a different model. - result, status_code = self._update_weights(base_url, default_model) - assert status_code == 200 and result.get( - "success" - ), f"Failed to reset to default model: {result.get('message')}" - - # --- Step 1: Get base-model checksums for rollback modules --- - base_checksums = self._get_weights_checksum( - base_url, module_names=rollback_modules + ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( + diffusion_server_no_offload ) - logger.info(f"Base model checksums: {base_checksums}") - - # --- Step 2: Update to the update model --- - result, status_code = self._update_weights(base_url, update_model) - assert status_code == 200 - assert result.get( - "success", False - ), f"Weight update failed: {result.get('message')}" - - # --- Step 3: Record update-model checksums for rollback modules --- - update_checksums = self._get_weights_checksum( - base_url, module_names=rollback_modules - ) - logger.info(f"Update model checksums: {update_checksums}") - - assert ( - update_checksums != base_checksums - ), "Base and update checksums should differ" + base_url = self._get_base_url(ctx) - # --- Step 4: Recreate 
corrupted model, then attempt load --- - # Copy all modules from the base model (valid), but corrupt only the - # vae. With target_modules=["transformer", "vae"], the transformer - # updates successfully first, then vae fails, giving a meaningful - # rollback that actually restores the transformer. - local_base = maybe_download_model(default_model) - _prepare_corrupted_model(local_base, corrupted_model_dir, corrupt_module="vae") + # base → perturbed + self._update_weights(base_url, default_model) + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200 and result.get("success") + # perturbed → corrupted (should fail and rollback) result, status_code = self._update_weights( base_url, - corrupted_model_dir, - target_modules=rollback_modules, - timeout=120, - ) - logger.info(f"Corrupted update result: status={status_code}, body={result}") - - assert not result.get("success", True), "Loading corrupted weights should fail" - assert ( - "rolled back" in result.get("message", "").lower() - ), f"Expected rollback message, got: {result.get('message')}" - - # --- Step 5: Verify rollback — rollback module checksums must match update model --- - post_rollback_checksums = self._get_weights_checksum( - base_url, module_names=rollback_modules + corrupted_vae_model_dir, + target_modules=_DIFFERING_MODULES, ) - logger.info(f"Post-rollback checksums: {post_rollback_checksums}") + assert not result.get("success", True) + assert "rolled back" in result.get("message", "").lower() - for module in update_checksums: - assert post_rollback_checksums.get(module) == update_checksums[module], ( - f"Module '{module}' checksum mismatch after rollback\n" - f" update: {update_checksums[module]}\n" - f" post-rollback: {post_rollback_checksums.get(module)}" - ) - - assert post_rollback_checksums != base_checksums, ( - "Post-rollback checksums should not match base model " - "(rollback target is instruct model, not base)" - ) - - # Ensure server 
healthy: reset to base and verify checksums == base disk. - result, status_code = self._update_weights(base_url, default_model) - assert status_code == 200 and result.get( - "success" - ), f"Failed to reset to base after rollback: {result.get('message')}" - self._assert_server_matches_model_on_changed_modules( - base_url, default_model, update_model, default_model - ) + # Verify: server still on perturbed, not base + self._assert_server_matches_model(base_url, perturbed_model_dir) class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): @@ -707,11 +596,29 @@ class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): @pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS) def diffusion_server_with_offload(self, request): - """Start a diffusion server with layerwise offload enabled.""" - default_model, update_model = request.param + """Start a diffusion server with layerwise offload enabled. + + Also builds perturbed_vae_model_dir in a background thread + while the server boots. 
+ """ + default_model, source_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + local_source = maybe_download_model(source_model) + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + + clone_thread = threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _Module.VAE, + _perturb_safetensor, + ), + ) + clone_thread.start() + manager = ServerManager( model=default_model, port=port, @@ -720,27 +627,27 @@ def diffusion_server_with_offload(self, request): ) ctx = manager.start() + clone_thread.join() try: - yield ctx, default_model, update_model + yield ctx, default_model, perturbed_vae_model_dir finally: ctx.cleanup() + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): - """Offload: base→instruct update; no Shape mismatch; checksums == instruct disk.""" - ctx, default_model, update_model = diffusion_server_with_offload + """Offload: base→perturbed; no Shape mismatch; checksums == perturbed disk.""" + ctx, _, perturbed_model_dir = diffusion_server_with_offload base_url = self._get_base_url(ctx) - result, status_code = self._update_weights(base_url, update_model) + result, status_code = self._update_weights(base_url, perturbed_model_dir) assert status_code == 200, f"Expected 200, got {status_code}" assert result.get("success", False), f"Update failed: {result.get('message')}" message = result.get("message", "") assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" - self._assert_server_matches_model_on_changed_modules( - base_url, default_model, update_model, update_model - ) + self._assert_server_matches_model(base_url, perturbed_model_dir) if __name__ == "__main__": From 52ba35b6d99dd180c428c10608a049eeb148a4c1 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 15 Feb 2026 23:26:25 -0800 
Subject: [PATCH 22/30] new docstring --- .../server/test_update_weights_from_disk.py | 87 ++++++++++++------- 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 15a4e051a7e0..db6ca7dced75 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -8,22 +8,36 @@ Menyang Liu, https://github.com/dreamyang-liu Chenyang Zhao, https://github.com/zhaochenyang20 -We use two model pairs for testing (base model / source model pairs): +We use two model pairs for testing (base model / instruct model pairs): - FLUX.2-klein-base-4B / FLUX.2-klein-4B - Qwen/Qwen-Image / Qwen/Qwen-Image-2512 These model pairs share the same architecture but differ in transformer -weights (all other modules — vae, text_encoder, … — are identical). - -The source model is not used directly by any test. Instead, at fixture -setup time we clone it and perturb its vae weights, producing a synthetic -perturbed checkpoint (perturbed_vae_model_dir) where both transformer AND -vae differ from the base model. This perturbed checkpoint is used in all -tests, giving us two modules with known different checksums to verify. - -NOTE: Disk-vs-server checksum verification currently ONLY covers transformer. -Other modules have weight-name remapping / QKV merge mismatches to resolve first. +weights. The basic testing logic is to refit the instruct model into the +base model and verify that the checksums of the transformer weights are the same, +which simulates the real-world RL scenario. However, since these two model +pairs only differ in transformer weights, and we want to verify updating a +specific module with the update_weights_from_disk API, we need to create a perturbed
In this sense, the instruct +model differs from the base model in vae and transformer weights, while the text +encoder is still the same. + +To strictly verify the correctness of the refit API, we compare the SHA-256 +checksums computed on disk and on the server. + +NOTE and TODO: In the refit-a-specific-module test, we randomly select one module +from the transformer and vae to refit the server and keep other modules the same. +As described above, the vae's weights are perturbed. If we select the vae to be the +target module, ideally speaking, we should assert that the refitted vae's checksum +is the same as the one computed directly from the perturbed vae weights on disk. However, +since there is complex weight-name remapping and QKV merging during model loading, +it is not easy to compare the server-disk checksum for vae and text encoder directly. +Therefore, if the target module is vae, we only verify that the refitted vae's checksum +is different from the base model's vae's checksum. + +It would be a good issue for the community to add comparison of the server-disk +checksums for vae and text encoder in this test. ============================================================================= @@ -46,21 +60,21 @@ All tests share one class-scoped server (same process, same in-memory weights). Tests that require "base model then update" should be explicitly reset to -default_model first so behavior is order-independent and updates are real -(base→perturbed), not no-ops (perturbed→perturbed). +base model first so behavior is order-independent and updates are real +(base -> perturbed), not no-ops (perturbed -> perturbed). • test_update_weights_from_disk_default - base -> perturbed with flush_cache=True. - Verifies after-update checksum == perturbed checkpoint disk checksum - (implicitly confirms weights changed, since fixture guarantees - base ≠ perturbed). + base model -> perturbed model with flush_cache=True.
+ Verifies after-update transformer checksum == perturbed model's + transformer disk checksum + • test_update_weights_specific_modules base -> perturbed with flush_cache=False. Randomly selects one module - from _DIFFERING_MODULES (modules whose weights differ between base and - perturbed checkpoint) as target_modules, updates only that module. Verifies: + from _DIFFERING_MODULES (transformer and vae) as target_modules, updates + only that module. Verifies that: (1) targeted module's in-memory checksum changed; (2) non-targeted modules' in-memory checksums are unchanged. @@ -68,27 +82,39 @@ model_path set to a non-existent path; must fail (400, success=False). - Ensure server is healthy after failed update and server's checksums - equal base model's disk checksums. + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. • test_update_weights_missing_model_path Request body empty (no model_path); must fail (400, success=False). - Ensure server is healthy after failed update and server's checksums - equal base model's disk checksums. + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. • test_update_weights_nonexistent_module target_modules=["nonexistent_module"]; must fail (400, success=False). Verify server is healthy after failed update and server's checksums - equal base model's disk checksums. + equal base model's transformer disk checksum. • test_corrupted_weights_rollback - All-or-nothing rollback: base→perturbed succeeds, then perturbed→corrupted - fails (truncated vae), server rolls back to the perturbed checkpoint. + All-or-nothing rollback: We first refit the server from base model -> + perturbed model. We manually truncate the vae weights of the base + model to get a corrupted model. We then call the refit to update + the server from the perturbed model -> corrupted model. Verify that: + + 1. 
The update fails due to truncated vae, server should roll back to the + perturbed model, i.e., server's transformer weights == perturbed model's + transformer weights != base model's transformer weights. + + 2. After the rollback, server's vae weights == perturbed model's vae + weights != base model's vae weights. + + 3. After the rollback, server's text encoder weights == base model's + text encoder weights == perturbed model's text encoder weights. ----------------------------------------------------------------------------- @@ -103,7 +129,8 @@ • test_update_weights_with_offload_enabled Server with --dit-layerwise-offload (base). Load perturbed checkpoint; - must succeed (200, success=True), no "Shape mismatch". Checksums match disk. + must succeed (200, success=True), no "Shape mismatch". server's transformer checksum + matches perturbed model's transformer disk checksum. """ from __future__ import annotations @@ -146,7 +173,7 @@ class _Module(StrEnum): VAE = "vae" -# Modules whose weights differ between the base model and the synthetic +# Modules whose weights differ between the base model and the perturbed # perturbed checkpoint _DIFFERING_MODULES: list[str] = [_Module.TRANSFORMER, _Module.VAE] @@ -362,13 +389,13 @@ class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): def diffusion_server_no_offload(self, request): """Start a diffusion server (no offload) for this test class. - Builds two synthetic checkpoints from the source model: + Builds two perturbed checkpoints from the source model: - perturbed_vae_model_dir: source model with perturbed vae (both transformer and vae differ from base). - corrupted_vae_model_dir: base model with truncated vae — triggers load failure for rollback testing. - Checksum cache warmup and synthetic checkpoints building run in background + Checksum cache warmup and perturbed checkpoints building run in background threads while the server boots, so everything is ready by the time tests start. 
""" From c3d478e224591c6dab596fbef0de2312dcdbfb7a Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 15 Feb 2026 23:36:41 -0800 Subject: [PATCH 23/30] remove one line function --- .../server/test_update_weights_from_disk.py | 127 ++++++++++-------- 1 file changed, 72 insertions(+), 55 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index db6ca7dced75..f96e23d7800a 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -142,10 +142,10 @@ import tempfile import threading from collections.abc import Callable -from enum import StrEnum import pytest import requests +from safetensors.torch import load_file, save_file from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, @@ -158,7 +158,6 @@ from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger from sglang.multimodal_gen.test.server.test_server_utils import ( - ServerContext, ServerManager, ) from sglang.multimodal_gen.test.test_utils import get_dynamic_server_port, is_in_ci @@ -166,50 +165,30 @@ logger = init_logger(__name__) -class _Module(StrEnum): - """Updatable pipeline module names.""" - - TRANSFORMER = "transformer" - VAE = "vae" +_TRANSFORMER_MODULE = "transformer" +_VAE_MODULE = "vae" +_TEXT_ENCODER_MODULE_PREFIX = "text_encoder" # Modules whose weights differ between the base model and the perturbed # perturbed checkpoint -_DIFFERING_MODULES: list[str] = [_Module.TRANSFORMER, _Module.VAE] +_DIFFERING_MODULES: list[str] = [_TRANSFORMER_MODULE, _VAE_MODULE] -_ALL_MODEL_PAIRS: list[tuple[str, str, float]] = [ +_ALL_MODEL_PAIRS: list[tuple[str, str]] = [ ( "black-forest-labs/FLUX.2-klein-base-4B", "black-forest-labs/FLUX.2-klein-4B", - 
5.0, ), ( "Qwen/Qwen-Image", "Qwen/Qwen-Image-2512", - 1.0, # Qwen Image is large; run it less often in CI. ), ] -def _select_model_pairs() -> list[tuple[str, str]]: - """Return the (default, source) model pairs to test. - - When SGLANG_TEST_DIFFUSION_MODEL / SGLANG_TEST_UPDATE_MODEL env vars - are set, use them as a single explicit pair. Otherwise, run both - pairs locally, or randomly pick one in CI (weighted) to save resources. - """ - default_env = os.environ.get("SGLANG_TEST_DIFFUSION_MODEL") - update_env = os.environ.get("SGLANG_TEST_UPDATE_MODEL") - if default_env and update_env: - return [(default_env, update_env)] - pairs = [(d, u) for d, u, _ in _ALL_MODEL_PAIRS] - if is_in_ci(): - weights = [w for _, _, w in _ALL_MODEL_PAIRS] - return random.choices(pairs, weights=weights, k=1) - return pairs - - -_ACTIVE_MODEL_PAIRS = _select_model_pairs() +_ACTIVE_MODEL_PAIRS = ( + _ALL_MODEL_PAIRS if not is_in_ci() else [random.choice(_ALL_MODEL_PAIRS)] +) _PAIR_IDS = [p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS] @@ -285,18 +264,17 @@ def _truncate_safetensor(src_file: str, dst_file: str) -> None: shutil.copy2(src_file, dst_file) size = os.path.getsize(dst_file) with open(dst_file, "r+b") as f: - f.truncate(size - 1000) + f.truncate(size - 2) logger.info( "Created corrupted safetensors: %s (%d -> %d bytes)", dst_file, size, - size - 1000, + size - 2, ) def _perturb_safetensor(src_file: str, dst_file: str) -> None: """Load, add small perturbation to floating-point tensors, and save.""" - from safetensors.torch import load_file, save_file tensors = load_file(src_file) perturbed = { @@ -307,9 +285,6 @@ def _perturb_safetensor(src_file: str, dst_file: str) -> None: class _UpdateWeightsApiMixin: - def _get_base_url(self, ctx: ServerContext) -> str: - return f"http://localhost:{ctx.port}" - def _update_weights( self, base_url: str, @@ -363,12 +338,12 @@ def _assert_server_matches_model( discrepancies are resolved. 
""" server_checksums = self._get_weights_checksum( - base_url, module_names=[_Module.TRANSFORMER] + base_url, module_names=[_TRANSFORMER_MODULE] ) - expected_cs = _compute_checksum_from_disk(expected_model, _Module.TRANSFORMER) - server_cs = server_checksums.get(_Module.TRANSFORMER) + expected_cs = _compute_checksum_from_disk(expected_model, _TRANSFORMER_MODULE) + server_cs = server_checksums.get(_TRANSFORMER_MODULE) assert server_cs == expected_cs, ( - f"Checksum mismatch on '{_Module.TRANSFORMER}'\n" + f"Checksum mismatch on '{_TRANSFORMER_MODULE}'\n" f" expected({expected_model}): {expected_cs}\n" f" server: {server_cs}" ) @@ -429,7 +404,7 @@ def diffusion_server_no_offload(self, request): args=( local_source, perturbed_vae_model_dir, - _Module.VAE, + _VAE_MODULE, _perturb_safetensor, ), ), @@ -438,7 +413,7 @@ def diffusion_server_no_offload(self, request): args=( local_default, corrupted_vae_model_dir, - _Module.VAE, + _VAE_MODULE, _truncate_safetensor, ), ), @@ -469,11 +444,13 @@ def diffusion_server_no_offload(self, request): def test_update_weights_from_disk_default(self, diffusion_server_no_offload): """Default update (target_modules=None, flush_cache=True): all changed modules updated.""" ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" - self._update_weights(base_url, default_model) + self._update_weights(base_url, default_model, flush_cache=True) - result, status_code = self._update_weights(base_url, perturbed_model_dir) + result, status_code = self._update_weights( + base_url, perturbed_model_dir, flush_cache=True + ) assert status_code == 200 assert result.get("success", False), f"Update failed: {result.get('message')}" @@ -490,7 +467,7 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): proving the server only touched what was requested. 
""" ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" # Reset server to default_model. self._update_weights(base_url, default_model) @@ -533,7 +510,7 @@ def test_update_weights_specific_modules(self, diffusion_server_no_offload): def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): """Nonexistent model path must fail (400). Server healthy, checksums == base disk.""" ctx, default_model, _, _ = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" self._update_weights(base_url, default_model) @@ -551,7 +528,7 @@ def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): def test_update_weights_missing_model_path(self, diffusion_server_no_offload): """Request without model_path must fail (400). Server healthy, checksums == base disk.""" ctx, default_model, _, _ = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" self._update_weights(base_url, default_model) @@ -562,12 +539,14 @@ def test_update_weights_missing_model_path(self, diffusion_server_no_offload): ) assert response.status_code == 400, f"Expected 400, got {response.status_code}" + result = response.json() + assert not result.get("success", True), "Should fail when model_path is missing" self._assert_server_matches_model(base_url, default_model) def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): """Nonexistent module must fail (400). 
Server healthy, checksums == base disk.""" ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" self._update_weights(base_url, default_model) @@ -598,12 +577,26 @@ def test_corrupted_weights_rollback(self, diffusion_server_no_offload): ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( diffusion_server_no_offload ) - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" # base → perturbed self._update_weights(base_url, default_model) + base_checksums = self._get_weights_checksum(base_url) + result, status_code = self._update_weights(base_url, perturbed_model_dir) assert status_code == 200 and result.get("success") + perturbed_checksums = self._get_weights_checksum(base_url) + + text_encoder_modules = sorted( + name + for name in perturbed_checksums + if _TEXT_ENCODER_MODULE_PREFIX in name + and perturbed_checksums.get(name) != "not_found" + and base_checksums.get(name) != "not_found" + ) + assert ( + text_encoder_modules + ), "Expected at least one text encoder module checksum" # perturbed → corrupted (should fail and rollback) result, status_code = self._update_weights( @@ -611,11 +604,35 @@ def test_corrupted_weights_rollback(self, diffusion_server_no_offload): corrupted_vae_model_dir, target_modules=_DIFFERING_MODULES, ) + assert ( + status_code == 400 + ), f"Expected 400 on corrupted weights, got {status_code}" assert not result.get("success", True) assert "rolled back" in result.get("message", "").lower() - - # Verify: server still on perturbed, not base - self._assert_server_matches_model(base_url, perturbed_model_dir) + rolled_back_checksums = self._get_weights_checksum(base_url) + + # 1) transformer: server == perturbed != base + transformer_base = base_checksums.get(_TRANSFORMER_MODULE) + transformer_perturbed = perturbed_checksums.get(_TRANSFORMER_MODULE) + transformer_rolled_back = 
rolled_back_checksums.get(_TRANSFORMER_MODULE) + assert transformer_rolled_back == transformer_perturbed + assert transformer_rolled_back != transformer_base + + # 2) vae: server == perturbed != base + vae_base = base_checksums.get(_VAE_MODULE) + vae_perturbed = perturbed_checksums.get(_VAE_MODULE) + vae_rolled_back = rolled_back_checksums.get(_VAE_MODULE) + assert vae_rolled_back == vae_perturbed + assert vae_rolled_back != vae_base + + # 3) text encoder(s): server == base == perturbed + for name in text_encoder_modules: + assert rolled_back_checksums.get(name) == perturbed_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to perturbed" + assert rolled_back_checksums.get(name) == base_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to base" class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): @@ -640,7 +657,7 @@ def diffusion_server_with_offload(self, request): args=( local_source, perturbed_vae_model_dir, - _Module.VAE, + _VAE_MODULE, _perturb_safetensor, ), ) @@ -665,7 +682,7 @@ def diffusion_server_with_offload(self, request): def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): """Offload: base→perturbed; no Shape mismatch; checksums == perturbed disk.""" ctx, _, perturbed_model_dir = diffusion_server_with_offload - base_url = self._get_base_url(ctx) + base_url = f"http://localhost:{ctx.port}" result, status_code = self._update_weights(base_url, perturbed_model_dir) assert status_code == 200, f"Expected 200, got {status_code}" From cde71fef5339182f74cbed831abbedb05fca1d15 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 15 Feb 2026 23:39:33 -0800 Subject: [PATCH 24/30] consolidate rollback tests --- .../server/test_update_weights_from_disk.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py 
b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index f96e23d7800a..18dcf6396119 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -568,10 +568,12 @@ def test_corrupted_weights_rollback(self, diffusion_server_no_offload): Steps: 1. base → perturbed (succeeds, server now on perturbed checkpoint). - 2. perturbed → corrupted with target_modules=_DIFFERING_MODULES. + 2. perturbed → corrupted with target_modules=[transformer, vae]. The corrupted checkpoint has a truncated vae safetensors file. - Transformer loads first (succeeds), then vae fails during - safetensors parsing, triggering rollback of both modules. + We explicitly assert the first failed module is vae from the API + error message (which reports the failing module name), proving + transformer was attempted before the vae parse failure and that + rollback then covered both modules. 3. Assert the server rolled back to the perturbed checkpoint, not base. """ ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( @@ -599,16 +601,24 @@ def test_corrupted_weights_rollback(self, diffusion_server_no_offload): ), "Expected at least one text encoder module checksum" # perturbed → corrupted (should fail and rollback) + rollback_targets = [_TRANSFORMER_MODULE, _VAE_MODULE] result, status_code = self._update_weights( base_url, corrupted_vae_model_dir, - target_modules=_DIFFERING_MODULES, + target_modules=rollback_targets, ) assert ( status_code == 400 ), f"Expected 400 on corrupted weights, got {status_code}" assert not result.get("success", True) - assert "rolled back" in result.get("message", "").lower() + message = result.get("message", "") + assert "rolled back" in message.lower() + # The updater reports the first failing module in the error message. 
+ # With ordered target_modules=[transformer, vae], this makes the + # failure point explicit: transformer is processed first, then vae fails. + assert ( + "Failed to update module 'vae'" in message + ), f"Expected vae to be the explicit failure point, got: {message}" rolled_back_checksums = self._get_weights_checksum(base_url) # 1) transformer: server == perturbed != base From 40bba8d5af4dc0efa7ebbf912c34d77418c611ac Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 15 Feb 2026 23:43:42 -0800 Subject: [PATCH 25/30] finalize the test --- .../server/test_update_weights_from_disk.py | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 18dcf6396119..42fc471025b8 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -218,12 +218,6 @@ def _clone_model_with_modified_module( target_module: str, transform_safetensor: Callable[[str, str], None], ) -> None: - """Clone a model directory via symlinks, applying transform to one module. - - Everything is symlinked except the target module's first .safetensors - file, which is transformed (causing a checksum difference or corruption); - remaining files are symlinked for speed. - """ # Symlink root-level files (model_index.json, etc.). 
for fname in os.listdir(src_model): src_path = os.path.join(src_model, fname) @@ -260,7 +254,6 @@ def _clone_model_with_modified_module( def _truncate_safetensor(src_file: str, dst_file: str) -> None: - """Copy then truncate — produces an invalid safetensors that triggers rollback.""" shutil.copy2(src_file, dst_file) size = os.path.getsize(dst_file) with open(dst_file, "r+b") as f: @@ -274,7 +267,6 @@ def _truncate_safetensor(src_file: str, dst_file: str) -> None: def _perturb_safetensor(src_file: str, dst_file: str) -> None: - """Load, add small perturbation to floating-point tensors, and save.""" tensors = load_file(src_file) perturbed = { @@ -327,16 +319,6 @@ def _assert_server_matches_model( base_url: str, expected_model: str, ) -> None: - """Assert the server's transformer checksum matches expected_model on disk. - - Only the transformer is verified because weight-name remapping and - QKV merge during model loading cause in-memory parameter names/shapes - to diverge from on-disk safetensors for other modules (e.g. vae), - making their checksums incomparable. - - TODO: Extend to verify all modules once these - discrepancies are resolved. - """ server_checksums = self._get_weights_checksum( base_url, module_names=[_TRANSFORMER_MODULE] ) @@ -350,11 +332,6 @@ def _assert_server_matches_model( class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): - """Test suite for update_weights_from_disk API and corrupted-weight rollback. - - Uses a class-scoped server fixture so the server is torn down at class end, - freeing the port and GPU memory before the offload class starts. - """ @pytest.fixture( scope="class", @@ -362,18 +339,6 @@ class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): ids=_PAIR_IDS, ) def diffusion_server_no_offload(self, request): - """Start a diffusion server (no offload) for this test class. 
- - Builds two perturbed checkpoints from the source model: - - perturbed_vae_model_dir: source model with perturbed vae (both - transformer and vae differ from base). - - corrupted_vae_model_dir: base model with truncated vae — triggers - load failure for rollback testing. - - Checksum cache warmup and perturbed checkpoints building run in background - threads while the server boots, so everything is ready by the time - tests start. - """ default_model, source_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) @@ -457,15 +422,6 @@ def test_update_weights_from_disk_default(self, diffusion_server_no_offload): self._assert_server_matches_model(base_url, perturbed_model_dir) def test_update_weights_specific_modules(self, diffusion_server_no_offload): - """Verify target_modules filtering: only the specified module is updated. - - The perturbed checkpoint has different weights for both transformer and - vae. This test randomly picks ONE of them as target_modules and loads - from the perturbed checkpoint. Assertions: - (1) the targeted module's in-memory checksum changed (before != after); - (2) every non-targeted module's in-memory checksum is unchanged, - proving the server only touched what was requested. - """ ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload base_url = f"http://localhost:{ctx.port}" @@ -564,18 +520,6 @@ def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): self._assert_server_matches_model(base_url, default_model) def test_corrupted_weights_rollback(self, diffusion_server_no_offload): - """Verify all-or-nothing rollback on corrupted weights. - - Steps: - 1. base → perturbed (succeeds, server now on perturbed checkpoint). - 2. perturbed → corrupted with target_modules=[transformer, vae]. - The corrupted checkpoint has a truncated vae safetensors file. 
- We explicitly assert the first failed module is vae from the API - error message (which reports the failing module name), proving - transformer was attempted before the vae parse failure and that - rollback then covered both modules. - 3. Assert the server rolled back to the perturbed checkpoint, not base. - """ ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( diffusion_server_no_offload ) @@ -650,11 +594,6 @@ class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): @pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS) def diffusion_server_with_offload(self, request): - """Start a diffusion server with layerwise offload enabled. - - Also builds perturbed_vae_model_dir in a background thread - while the server boots. - """ default_model, source_model = request.param port = get_dynamic_server_port() wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) @@ -690,7 +629,6 @@ def diffusion_server_with_offload(self, request): shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): - """Offload: base→perturbed; no Shape mismatch; checksums == perturbed disk.""" ctx, _, perturbed_model_dir = diffusion_server_with_offload base_url = f"http://localhost:{ctx.port}" From 74676826848df1e6f95fa0a5255147082c2dc187 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 16 Feb 2026 00:12:23 -0800 Subject: [PATCH 26/30] fix CI random choice --- .../sglang/multimodal_gen/test/run_suite.py | 30 +++++++++++++++++++ .../server/test_update_weights_from_disk.py | 26 ++++++++++++++-- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index c9b34ca0b0fe..fc52247749dc 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -10,6 +10,7 @@ import argparse import os +import 
random import subprocess import sys from pathlib import Path @@ -20,6 +21,13 @@ logger = init_logger(__name__) +_UPDATE_WEIGHTS_FROM_DISK_TEST_FILE = "test_update_weights_from_disk.py" +_UPDATE_WEIGHTS_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" +_UPDATE_WEIGHTS_MODEL_PAIR_IDS = ( + "FLUX.2-klein-base-4B", + "Qwen-Image", +) + SUITES = { "1-gpu": [ "test_server_a.py", @@ -226,6 +234,27 @@ def run_pytest(files, filter_expr=None): return returncode +def _is_in_ci() -> bool: + return os.environ.get("SGLANG_IS_IN_CI", "").lower() in ("1", "true", "yes", "on") + + +def _maybe_pin_update_weights_model_pair(suite_files_rel: list[str]) -> None: + if not _is_in_ci(): + return + if _UPDATE_WEIGHTS_FROM_DISK_TEST_FILE not in suite_files_rel: + return + if os.environ.get(_UPDATE_WEIGHTS_MODEL_PAIR_ENV): + print( + f"Using preset {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}=" + f"{os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV]}" + ) + return + + selected_pair = random.choice(_UPDATE_WEIGHTS_MODEL_PAIR_IDS) + os.environ[_UPDATE_WEIGHTS_MODEL_PAIR_ENV] = selected_pair + print(f"Selected {_UPDATE_WEIGHTS_MODEL_PAIR_ENV}={selected_pair} for this CI run") + + def main(): args = parse_args() @@ -240,6 +269,7 @@ def main(): # 2. 
get files from suite suite_files_rel = SUITES[args.suite] + _maybe_pin_update_weights_model_pair(suite_files_rel) suite_files_abs = [] for f_rel in suite_files_rel: diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 42fc471025b8..276411f583cb 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -186,9 +186,29 @@ ] -_ACTIVE_MODEL_PAIRS = ( - _ALL_MODEL_PAIRS if not is_in_ci() else [random.choice(_ALL_MODEL_PAIRS)] -) +_CI_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" + + +def _resolve_active_model_pairs() -> list[tuple[str, str]]: + if not is_in_ci(): + return _ALL_MODEL_PAIRS + + pair_by_id = {base.split("/")[-1]: pair for base, pair in _ALL_MODEL_PAIRS} + selected_pair_id = os.environ.get(_CI_MODEL_PAIR_ENV) + if selected_pair_id is None: + return [random.choice(_ALL_MODEL_PAIRS)] + + selected_pair = pair_by_id.get(selected_pair_id) + if selected_pair is None: + valid_ids = ", ".join(sorted(pair_by_id)) + raise ValueError( + f"Invalid {_CI_MODEL_PAIR_ENV}={selected_pair_id!r}. " + f"Expected one of: {valid_ids}." 
+ ) + return [selected_pair] + + +_ACTIVE_MODEL_PAIRS = _resolve_active_model_pairs() _PAIR_IDS = [p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS] From 3100b2f80eadbf9e16397160509b8f0816d9ebaf Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 16 Feb 2026 00:40:18 -0800 Subject: [PATCH 27/30] fix pairing issue --- .../multimodal_gen/test/server/test_update_weights_from_disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index 276411f583cb..c9e01eaa7116 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -193,7 +193,7 @@ def _resolve_active_model_pairs() -> list[tuple[str, str]]: if not is_in_ci(): return _ALL_MODEL_PAIRS - pair_by_id = {base.split("/")[-1]: pair for base, pair in _ALL_MODEL_PAIRS} + pair_by_id = {pair[0].split("/")[-1]: pair for pair in _ALL_MODEL_PAIRS} selected_pair_id = os.environ.get(_CI_MODEL_PAIR_ENV) if selected_pair_id is None: return [random.choice(_ALL_MODEL_PAIRS)] From fb87570021df92b6b176caecc4d24bdc9251118d Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 17 Feb 2026 11:42:29 -0800 Subject: [PATCH 28/30] inline path finding --- .../sglang/multimodal_gen/runtime/loader/utils.py | 13 ------------- .../runtime/loader/weights_updater.py | 14 +++++++------- .../test/server/test_update_weights_from_disk.py | 7 ++++--- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/utils.py b/python/sglang/multimodal_gen/runtime/loader/utils.py index 70981ca14ad2..3f031001167d 100644 --- a/python/sglang/multimodal_gen/runtime/loader/utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/utils.py @@ -9,7 +9,6 @@ import re from collections import defaultdict from collections.abc import Callable,
Iterator -from pathlib import Path from typing import Any, Dict, Type import torch @@ -149,18 +148,6 @@ def _list_safetensors_files(model_path: str) -> list[str]: return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors"))) -def find_weights_dir(local_path: str, module_name: str) -> Path | None: - """Locate the safetensors directory for module_name under local_path. - - Diffusion models store weights in per-module subdirectories (e.g. - transformer/, vae/, text_encoder/). - """ - dir_path = Path(local_path) / module_name - if dir_path.exists(): - return dir_path - return None - - def get_memory_usage_of_component(module) -> float | None: """ returned value is in GB, rounded to 2 decimal digits diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 440a96e32e5a..d951a7750287 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -50,6 +50,7 @@ from __future__ import annotations import gc +from pathlib import Path import torch from torch.distributed.tensor import DTensor, distribute_tensor @@ -57,7 +58,6 @@ from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, - find_weights_dir, ) from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, @@ -108,9 +108,9 @@ def _validate_weight_files( weights_map: dict[str, str] = {} missing: list[str] = [] for module_name, _ in modules_to_update: - weights_dir = find_weights_dir(local_model_path, module_name) - if weights_dir and _list_safetensors_files(weights_dir): - weights_map[module_name] = weights_dir + weights_dir = Path(local_model_path) / module_name + if weights_dir.exists() and _list_safetensors_files(str(weights_dir)): + weights_map[module_name] = str(weights_dir) else: 
missing.append(module_name) return weights_map, missing @@ -306,8 +306,8 @@ def _rollback(self, updated_modules: list[str]) -> None: module = self.pipeline.get_module(name) if module is None: continue - weights_dir = find_weights_dir(original_path, name) - if weights_dir is None: + weights_dir = Path(original_path) / name + if not weights_dir.exists(): continue - weights_iter = _get_weights_iter(weights_dir) + weights_iter = _get_weights_iter(str(weights_dir)) _load_weights_into_module(module, weights_iter) diff --git a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py index c9e01eaa7116..68700e93d016 100644 --- a/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +++ b/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -149,7 +149,6 @@ from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, - find_weights_dir, ) from sglang.multimodal_gen.runtime.loader.weight_utils import ( compute_weights_checksum, @@ -223,8 +222,10 @@ def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: same disk checksum is requested multiple times across tests. 
""" local_path = maybe_download_model(model_path) - weights_dir = find_weights_dir(local_path, module_name) - assert weights_dir is not None, f"No weights dir for {module_name} in {local_path}" + weights_dir = os.path.join(local_path, module_name) + assert os.path.exists( + weights_dir + ), f"No weights dir for {module_name} in {local_path}" safetensors_files = _list_safetensors_files(weights_dir) assert safetensors_files, f"No safetensors files in {weights_dir}" From fab939e10726457ab4a4b6a03b69bf091ab35466 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 17 Feb 2026 11:47:02 -0800 Subject: [PATCH 29/30] remove redundant comments --- .../runtime/loader/weights_updater.py | 28 +++---------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index d951a7750287..f170809a738e 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -5,18 +5,9 @@ without restarting the server. It is the diffusion-engine counterpart of the LLM engine's ModelRunner.update_weights_from_disk. -Typical usage (from GPUWorker.update_weights_from_disk): - - updater = WeightsUpdater(self.pipeline) - success, message = updater.update_weights_from_disk( - model_path, - flush_cache=flush_cache, - target_modules=target_modules, - ) - if success: - self.server_args.model_path = model_path - self.pipeline.model_path = model_path - return success, message +Detailed usage of higher level API can be found in + +/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py Key design decisions: @@ -176,18 +167,7 @@ def update_weights_from_disk( flush_cache: bool = True, target_modules: list[str] | None = None, ) -> tuple[bool, str]: - """Update model weights from disk without restarting the server. 
- - Args: - model_path: HF repo id or local path to the new weights. - flush_cache: If True, reset TeaCache state after a successful - update so that stale cached residuals are not reused. - target_modules: Explicit list of module names to update. None - updates every nn.Module in the pipeline. - - Returns: - (success, message) tuple where success is True on success. - """ + """Update model weights from disk without restarting the server.""" logger.info(f"Updating weights from disk: {model_path}") try: From f60f638a32bd95dd2f9edb002d5f001c477c934d Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 18 Feb 2026 08:54:53 -0800 Subject: [PATCH 30/30] fix isort --- .../multimodal_gen/runtime/entrypoints/http_server.py | 2 +- .../sglang/multimodal_gen/runtime/managers/scheduler.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py index c0313f21aacf..30a60b35adcf 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -16,8 +16,8 @@ from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( VertexGenerateReqInput, ) -from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api from sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params +from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api from sglang.multimodal_gen.runtime.entrypoints.utils import ( prepare_request, save_outputs, diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 1732ca90da22..e1fef1df2d83 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -15,6 +15,10 @@ _parse_size, save_image_to_path, ) +from 
sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( + GetWeightsChecksumReqInput, + UpdateWeightFromDiskReqInput, +) from sglang.multimodal_gen.runtime.entrypoints.utils import ( ListLorasReq, MergeLoraWeightsReq, @@ -22,10 +26,6 @@ ShutdownReq, UnmergeLoraWeightsReq, ) -from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( - GetWeightsChecksumReqInput, - UpdateWeightFromDiskReqInput, -) from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker from sglang.multimodal_gen.runtime.pipelines_core import Req from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch