diff --git a/docs/diffusion/api/cli.md b/docs/diffusion/api/cli.md index 6480687bfd33..d91d1281ade2 100644 --- a/docs/diffusion/api/cli.md +++ b/docs/diffusion/api/cli.md @@ -160,6 +160,21 @@ HTTP server-only arguments are ignored by `sglang generate`. For diffusers pipelines, Cache-DiT can be enabled with `SGLANG_CACHE_DIT_ENABLED=true` or `--cache-dit-config`. See [Cache-DiT](../performance/cache/cache_dit.md). +### Layerwise Offload + +Use layerwise offload when a large component does not fit comfortably in GPU memory. By default, `--dit-layerwise-offload` only applies to legacy DiT components. Use `--layerwise-offload-components` to select pipeline component names explicitly (`--layerwise-offload-modules` is accepted as an alias): + +```bash +sglang generate \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --dit-layerwise-offload \ + --layerwise-offload-components transformer text_encoder \ + --dit-offload-prefetch-size 0 \ + --prompt "A quiet city street after rain" +``` + +The values must match keys in the selected pipeline's `pipeline.modules`, such as `transformer`, `text_encoder`, `image_encoder`, `vae`, `condition_image_encoder`, `spatial_upsampler`, or `vocoder`. Use `all` to select every layerwise-offloadable component. Prefer the smallest component set that solves the memory issue because layerwise offload can increase latency. + ## Serve `sglang serve` starts the HTTP server and keeps the model loaded for repeated requests. diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py index f42e68cc3ee1..b9e6b9545d2e 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -230,6 +230,7 @@ class PipelineConfig: # Image encoder configuration image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig) image_encoder_precision: str = "fp32" + image_encoder_extra_args: dict = field(default_factory=lambda: {}) # Text encoder configuration DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32",) @@ -240,9 +241,6 @@ class PipelineConfig: text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}]) - # image encoding - image_encoder_extra_args: dict = field(default_factory=lambda: {}) - def get_model_deployment_config(self) -> ModelDeploymentConfig: return ModelDeploymentConfig() @@ -266,6 +264,10 @@ def postprocess_image(self, image): # DMD parameters dmd_denoising_steps: list[int] | None = field(default=None) + def get_model_deployment_config(self) -> ModelDeploymentConfig: + # return the model-specific config for optimal deployment setting + return ModelDeploymentConfig() + # Wan2.2 TI2V parameters boundary_ratio: float | None = None diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py index e5f33027db9f..f5ed8f011f75 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py @@ -88,8 +88,6 @@ class WanT2V480PConfig(PipelineConfig): vae_precision: str = "fp32" text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) - # WanConfig-specific added parameters - def __post_init__(self): self.vae_config.load_encoder = False self.vae_config.load_decoder = True diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py index 90e5e28b8481..ddb186cb1402 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py @@ -25,6 +25,15 @@ component_name_to_loader_cls, get_memory_usage_of_component, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + configure_layerwise_offload_modules, + is_layerwise_offloaded_module, +) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import ( + LAYERWISE_OFFLOAD_ALL_COMPONENTS, + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS, + layerwise_component_matches_selection, +) from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( @@ -96,6 +105,66 @@ def target_device(self, should_offload): else: return get_local_torch_device() + def customized_load_kwargs_for_component( + self, _server_args: ServerArgs, _component_name: str + ) -> dict[str, Any]: + return {} + + @staticmethod + def _is_component_set_as_layerwise_load( + server_args: ServerArgs, component_name: str + ) -> bool: + """if a component should be loaded in a layerwise-fashion""" + selected_component_names = server_args.layerwise_offload_components + if selected_component_names is None: + return False + selected_component_names = set(selected_component_names) + if LAYERWISE_OFFLOAD_ALL_COMPONENTS in selected_component_names: + return True + explicit_component_names = selected_component_names - { + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS + } + return any( + layerwise_component_matches_selection(component_name, selected_component) + for selected_component in explicit_component_names + ) + + def _maybe_configure_layerwise_after_startup_cpu_staging( + self, + component: AutoModel, + server_args: ServerArgs, + component_name: str, + load_kwargs: dict[str, Any], + ) -> AutoModel: + if not load_kwargs.get("cpu_offload_flag"): + return component + if not isinstance(component, nn.Module): + return component + + # try to configure layerwise-offload with the component + configured_components = configure_layerwise_offload_modules( + {component_name: component}, + server_args, + component_names=server_args.layerwise_offload_components, + warn_missing=False, + ) + if is_layerwise_offloaded_module(component): + logger.info( + "Configured layerwise offload for %s immediately after startup CPU staging", + component_name, + ) + return component + + logger.warning( + "Layerwise startup CPU staging was requested for %s, but the loaded " + "module did not enable layerwise offload. Moving it to GPU.", + component_name, + ) + # ensures the module is on GPU + if component_name in configured_components: + return component + return component.to(get_local_torch_device()) + def load( self, component_model_path: str, @@ -135,8 +204,15 @@ def load( with component_attn_backend_context_manager( attn_backend, component_name=component_attn_name ): + load_kwargs = self.customized_load_kwargs_for_component( + server_args, component_name + ) component = self.load_customized( - component_model_path, server_args, component_name + component_model_path, server_args, component_name, **load_kwargs + ) + # configure layerwise to make enough VRAM headroom + component = self._maybe_configure_layerwise_after_startup_cpu_staging( + component, server_args, component_name, load_kwargs ) source = "sgl-diffusion" except Exception as e: diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py index 18a33b3bbb54..04de08f20616 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py @@ -30,7 +30,11 @@ def should_offload(self, server_args, model_config: ModelConfig | None = None): return use_cpu_offload def load_customized( - self, component_model_path: str, server_args: ServerArgs, *args + self, + component_model_path: str, + server_args: ServerArgs, + component_name: str = "image_encoder", + cpu_offload_flag: bool | None = None, ): """Load the text encoders based on the model path, and inference args.""" # model_config: PretrainedConfig = get_hf_config( @@ -53,5 +57,9 @@ def load_customized( encoder_config, server_args, server_args.pipeline_config.image_encoder_precision, - cpu_offload_flag=server_args.image_encoder_cpu_offload, + cpu_offload_flag=( + cpu_offload_flag + if cpu_offload_flag is not None + else server_args.image_encoder_cpu_offload + ), ) diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py index 691eb0ca4986..4d635b6860b1 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py @@ -81,6 +81,19 @@ def should_offload(self, server_args, model_config: ModelConfig | None = None): use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0 return use_cpu_offload + def customized_load_kwargs_for_component( + self, server_args: ServerArgs, component_name: str + ) -> dict[str, bool]: + if ComponentLoader._is_component_set_as_layerwise_load( + server_args, component_name + ): + logger.info( + "Loading %s on CPU first because it is selected for layerwise offload", + component_name, + ) + return {"cpu_offload_flag": True} + return {} + def load_native( self, component_model_path: str, diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 1007e9b41476..1967fb2d2230 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -53,7 +53,9 @@ from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + is_layerwise_offloaded_module, +) from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -114,7 +116,7 @@ def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: update_cpu_weights(); non-offloaded parameters use in-place copy. """ offload_managers: list = [] - if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + if is_layerwise_offloaded_module(module): offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] if offload_managers: diff --git a/python/sglang/multimodal_gen/runtime/managers/dynamic_batch_admission.py b/python/sglang/multimodal_gen/runtime/managers/dynamic_batch_admission.py index 6d9153c75473..9772fe798300 100644 --- a/python/sglang/multimodal_gen/runtime/managers/dynamic_batch_admission.py +++ b/python/sglang/multimodal_gen/runtime/managers/dynamic_batch_admission.py @@ -160,7 +160,7 @@ def __init__(self, server_args: "ServerArgs", gpu_id: int): self._mode = getattr(server_args, "batching_mode", "dynamic") self._user_max_batch_size = max(1, int(server_args.batching_max_size)) self._model_path = server_args.model_path - self._offload = bool(server_args.dit_layerwise_offload) + self._offload = bool(server_args.layerwise_offload_components) self._device_memory_gb = self._get_device_memory_gb(gpu_id) self._rules = load_batching_config(server_args.batching_config) self._pipeline_config = server_args.pipeline_config diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 5982cd07ca1d..b07f210ee7d5 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -41,8 +41,8 @@ WeightsUpdater, get_updatable_modules, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import ( - OffloadableDiTMixin, +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + configure_layerwise_offload_modules, iter_materialized_weights, ) from sglang.multimodal_gen.runtime.pipelines_core import ( @@ -165,23 +165,12 @@ def init_device_and_model(self) -> None: # apply layerwise offload after lora is applied while building LoRAPipeline # otherwise empty offloaded weights could fail lora converting - if self.server_args.dit_layerwise_offload: - # enable layerwise offload if possible - for module_name in [ - "transformer", - "transformer_2", - "video_dit", - "video_dit_2", - "audio_dit", - ]: - dit = self.pipeline.get_module(module_name) - if dit: - if isinstance(dit, OffloadableDiTMixin): - dit.configure_layerwise_offload(self.server_args) - else: - logger.info( - f"Module {type(dit).__name__} does not support layerwise offload. Skipping." - ) + if self.server_args.layerwise_offload_components: + configure_layerwise_offload_modules( + self.pipeline.modules, + self.server_args, + component_names=self.server_args.layerwise_offload_components, + ) logger.info( f"Worker {self.rank}: Initialized device, model, and distributed environment." @@ -234,7 +223,7 @@ def _format_offload_disable_suggestions(self, components: List[str]) -> str: elif component in ("text_encoder", "text_encoder_2"): arg = "--text-encoder-cpu-offload" elif component == "transformer": - if self.server_args.dit_layerwise_offload: + if self.server_args.is_dit_layerwise_offload_selected: arg = "--dit-layerwise-offload" elif self.server_args.dit_cpu_offload: arg = "--dit-cpu-offload" @@ -740,7 +729,7 @@ def get_can_stay_resident_components( # If the flag is True, it is currently offloaded, so it is a candidate to stay resident. offload_flags = { "transformer": self.server_args.dit_cpu_offload - or self.server_args.dit_layerwise_offload, + or self.server_args.is_dit_layerwise_offload_selected, "vae": self.server_args.vae_cpu_offload, "text_encoder": self.server_args.text_encoder_cpu_offload, "text_encoder_2": self.server_args.text_encoder_cpu_offload, diff --git a/python/sglang/multimodal_gen/runtime/managers/component_manager.py b/python/sglang/multimodal_gen/runtime/managers/memory_managers/component_manager.py similarity index 93% rename from python/sglang/multimodal_gen/runtime/managers/component_manager.py rename to python/sglang/multimodal_gen/runtime/managers/memory_managers/component_manager.py index ae35a6449af0..256214438df2 100644 --- a/python/sglang/multimodal_gen/runtime/managers/component_manager.py +++ b/python/sglang/multimodal_gen/runtime/managers/memory_managers/component_manager.py @@ -7,13 +7,21 @@ import torch import torch.nn as nn -from sglang.multimodal_gen.runtime.managers.component_resident_strategies import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_resident_strategies import ( ComponentResidencyStrategy, LayerwiseOffloadStrategy, ResidentStrategy, VanillaD2HStrategy, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + is_layerwise_offloaded_module, +) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import ( + is_dit_component_name, + is_image_encoder_component_name, + is_text_encoder_component_name, + is_vae_component_name, +) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -85,70 +93,35 @@ class ComponentResidencyPipeline(Protocol): component_residency_strategies: MutableMapping[str, "ComponentResidencyStrategy"] -def build_dit_residency_strategy( - module: nn.Module, - server_args: ServerArgs, -) -> ComponentResidencyStrategy: - if ( - isinstance(module, OffloadableDiTMixin) - and module.layerwise_offload_managers - and any(manager.enabled for manager in module.layerwise_offload_managers) - ): - # only if dit_layerwise_offload is enabled - return LayerwiseOffloadStrategy() - if server_args.dit_cpu_offload and not server_args.use_fsdp_inference: - # handles offload by vanalla D2H - return VanillaD2HStrategy() - return ResidentStrategy() - - def is_fsdp_managed_module(module: nn.Module) -> bool: return module.__class__.__name__.startswith("FSDP") +def should_cpu_offload_component( + component_name: str, module: nn.Module, server_args: ServerArgs +) -> bool: + if server_args.use_fsdp_inference or is_fsdp_managed_module(module): + return False + if is_dit_component_name(component_name): + return bool(server_args.dit_cpu_offload) + if is_text_encoder_component_name(component_name): + return bool(server_args.text_encoder_cpu_offload) + if is_image_encoder_component_name(component_name): + return bool(server_args.image_encoder_cpu_offload) + if is_vae_component_name(component_name): + return bool(server_args.vae_cpu_offload) + return False + + def build_component_residency_strategy( component_name: str, module: nn.Module, server_args: ServerArgs, ) -> ComponentResidencyStrategy: - if component_name in { - "transformer", - "transformer_2", - "video_dit", - "video_dit_2", - "audio_dit", - "dual_tower_bridge", - }: - return build_dit_residency_strategy(module, server_args) - - if component_name.startswith("text_encoder") or component_name.endswith( - "text_encoder" - ): - if ( - server_args.text_encoder_cpu_offload - and not server_args.use_fsdp_inference - and not is_fsdp_managed_module(module) - ): - return VanillaD2HStrategy() - return ResidentStrategy() - - if component_name == "image_encoder": - if server_args.image_encoder_cpu_offload and not server_args.use_fsdp_inference: - return VanillaD2HStrategy() - return ResidentStrategy() - - if component_name in { - "vae", - "video_vae", - "audio_vae", - "vocoder", - "spatial_upsampler", - "condition_image_encoder", - }: - if server_args.vae_cpu_offload and not server_args.use_fsdp_inference: - return VanillaD2HStrategy() - return ResidentStrategy() - + if is_layerwise_offloaded_module(module): + return LayerwiseOffloadStrategy() + if should_cpu_offload_component(component_name, module, server_args): + return VanillaD2HStrategy() return ResidentStrategy() diff --git a/python/sglang/multimodal_gen/runtime/managers/component_resident_strategies.py b/python/sglang/multimodal_gen/runtime/managers/memory_managers/component_resident_strategies.py similarity index 97% rename from python/sglang/multimodal_gen/runtime/managers/component_resident_strategies.py rename to python/sglang/multimodal_gen/runtime/managers/memory_managers/component_resident_strategies.py index f3d3f1df27d5..eef4c9c5d9de 100644 --- a/python/sglang/multimodal_gen/runtime/managers/component_resident_strategies.py +++ b/python/sglang/multimodal_gen/runtime/managers/memory_managers/component_resident_strategies.py @@ -10,12 +10,14 @@ import torch.nn as nn from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger if TYPE_CHECKING: - from sglang.multimodal_gen.runtime.managers.component_manager import ( + from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( ComponentUse, ResidencyState, ) @@ -484,11 +486,11 @@ class LayerwiseOffloadStrategy(ComponentResidencyStrategy): name = "layerwise" def enter(self, module: nn.Module) -> None: - if isinstance(module, OffloadableDiTMixin): + if isinstance(module, LayerwiseOffloadableModuleMixin): module.prepare_for_next_req() def exit(self, module: nn.Module, next_module: nn.Module | None = None) -> None: - if not isinstance(module, OffloadableDiTMixin): + if not isinstance(module, LayerwiseOffloadableModuleMixin): return for manager in module.layerwise_offload_managers: manager.release_all() diff --git a/python/sglang/multimodal_gen/runtime/managers/layerwise_offload.py b/python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload.py similarity index 71% rename from python/sglang/multimodal_gen/runtime/managers/layerwise_offload.py rename to python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload.py index b52610c5c89c..646e04568f18 100644 --- a/python/sglang/multimodal_gen/runtime/managers/layerwise_offload.py +++ b/python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload.py @@ -1,9 +1,14 @@ import re -from itertools import chain +from collections.abc import Mapping, Sequence from typing import Any, Dict, List, Set, Tuple import torch +from torch.distributed.tensor import DTensor +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import ( + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS, + layerwise_component_matches_selection, +) from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -70,6 +75,7 @@ def __init__( self._named_parameters: Dict[str, torch.nn.Parameter] = {} self._named_buffers: Dict[str, torch.Tensor] = {} self._offload_placeholders: Dict[torch.dtype, torch.Tensor] = {} + self._has_dtensor_weights = False # Store forward hooks for removal self._forward_hooks: List[Any] = [] @@ -91,6 +97,26 @@ def _get_shared_empty_tensor(self, dtype: torch.dtype) -> torch.Tensor: self._offload_placeholders[dtype] = placeholder return placeholder + @staticmethod + def _to_local_tensor(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, DTensor): + return tensor.to_local() + return tensor + + def _wrap_for_target( + self, target: torch.Tensor, local_tensor: torch.Tensor + ) -> torch.Tensor: + if isinstance(target, DTensor): + return DTensor.from_local( + local_tensor, target.device_mesh, target.placements + ) + return local_tensor + + def _get_shared_empty_tensor_for_target( + self, target: torch.Tensor, dtype: torch.dtype + ) -> torch.Tensor: + return self._wrap_for_target(target, self._get_shared_empty_tensor(dtype)) + @staticmethod def _get_alignment_numel(dtype: torch.dtype, alignment_bytes: int = 32) -> int: element_size = torch.empty((), dtype=dtype).element_size() @@ -114,16 +140,20 @@ def _initialize(self) -> None: self._named_parameters = dict(self.model.named_parameters()) self._named_buffers = dict(self.model.named_buffers()) - # 1. collect and group tensors by layer and dtype + # 1. collect and group layer parameters by dtype. Keep buffers resident: + # shared buffers such as RoPE caches may be referenced by many layers. layer_groups: Dict[int, Dict[torch.dtype, List[Tuple[str, torch.Tensor]]]] = {} - all_tensors = chain(self._named_parameters.items(), self._named_buffers.items()) - for name, tensor in all_tensors: + for name, tensor in self._named_parameters.items(): layer_idx = self._match_layer_idx(name) if layer_idx is None or layer_idx >= self.num_layers: continue - layer_groups.setdefault(layer_idx, {}).setdefault(tensor.dtype, []).append( - (name, tensor) + self._has_dtensor_weights = self._has_dtensor_weights or isinstance( + tensor, DTensor ) + local_tensor = self._to_local_tensor(tensor) + layer_groups.setdefault(layer_idx, {}).setdefault( + local_tensor.dtype, [] + ).append((name, tensor)) # 2. concat and offload (in pinned memory) for layer_idx, dtype_to_params in layer_groups.items(): @@ -132,43 +162,46 @@ def _initialize(self) -> None: self._weight_metadata[layer_idx] = {} for dtype, weights in dtype_to_params.items(): - contiguous_weights: List[Tuple[str, torch.Tensor]] = [] + contiguous_weights: List[Tuple[str, torch.Tensor, torch.Tensor]] = [] for name, weight in weights: - if weight.is_contiguous(): - contiguous_weights.append((name, weight)) + local_weight = self._to_local_tensor(weight) + if local_weight.is_contiguous(): + contiguous_weights.append((name, weight, local_weight)) continue # Preserve non-contiguous layouts such as the transposed FP8 # weight views expected by CUTLASS kernels. cpu_tensor = torch.empty_strided( - size=weight.shape, - stride=weight.stride(), + size=local_weight.shape, + stride=local_weight.stride(), dtype=dtype, pin_memory=self.pin_cpu_memory, ) - cpu_tensor.copy_(weight) + cpu_tensor.copy_(local_weight) self._strided_cpu_weights[layer_idx][name] = cpu_tensor self._weight_metadata[layer_idx][name] = { "dtype": dtype, - "shape": weight.shape, - "stride": weight.stride(), + "shape": local_weight.shape, + "stride": local_weight.stride(), "preserve_strides": True, } - weight.data = self._get_shared_empty_tensor(dtype) + weight.data = self._get_shared_empty_tensor_for_target( + weight, dtype + ) if not contiguous_weights: continue current_offset = 0 aligned_offsets: Dict[str, int] = {} - for name, weight in contiguous_weights: + for name, weight, local_weight in contiguous_weights: # Some fused diffusion kernels require tensor base pointers to # satisfy a 32-byte alignment contract. Reusing one flat buffer # is still fine, but each logical tensor slice must start on an # aligned offset inside that buffer. current_offset = self._align_numel_offset(current_offset, dtype) aligned_offsets[name] = current_offset - current_offset += weight.numel() + current_offset += local_weight.numel() total_numel = current_offset @@ -178,22 +211,24 @@ def _initialize(self) -> None: ) # offload weights to the buffer - for name, weight in contiguous_weights: + for name, weight, local_weight in contiguous_weights: current_offset = aligned_offsets[name] - numel = weight.numel() + numel = local_weight.numel() cpu_buffer[current_offset : current_offset + numel].copy_( - weight.flatten() + local_weight.flatten() ) self._weight_metadata[layer_idx][name] = { "dtype": dtype, "offset": current_offset, "numel": numel, - "shape": weight.shape, - "stride": weight.stride(), + "shape": local_weight.shape, + "stride": local_weight.stride(), "preserve_strides": False, } - weight.data = self._get_shared_empty_tensor(dtype) + weight.data = self._get_shared_empty_tensor_for_target( + weight, dtype + ) current_offset += numel @@ -202,7 +237,8 @@ def _initialize(self) -> None: # Keep non-layer parameters resident on GPU. Layer tensors have already # been replaced by tiny device placeholders, so this does not reload the # offloaded layer weights. - self.model.to(self.device) + if not self._has_dtensor_weights: + self.model.to(self.device) # prefetch the first layer for warm-up self.prepare_for_next_req(non_blocking=False) @@ -271,16 +307,17 @@ def prefetch_layer(self, layer_idx: int, non_blocking: bool = True) -> None: device=self.device, ) gpu_tensor.copy_(cpu_tensor, non_blocking=non_blocking) - target.data = gpu_tensor + target.data = self._wrap_for_target(target, gpu_tensor) continue dtype = meta["dtype"] gpu_buffer = gpu_buffers[dtype] # map the parameter's data to the correct slice of the GPU buffer - target.data = gpu_buffer[ + local_tensor = gpu_buffer[ meta["offset"] : meta["offset"] + meta["numel"] ].view(meta["shape"]) + target.data = self._wrap_for_target(target, local_tensor) # record the prefetch event of this layer after all copies are enqueued event = torch.get_device_module().Event() @@ -307,7 +344,9 @@ def release_layer(self, layer_idx: int) -> None: for name, meta in self._weight_metadata.get(layer_idx, {}).items(): target = self.get_target_with_name(name) # Wraparound prefetch will reload the layer when it is needed again - target.data = self._get_shared_empty_tensor(meta["dtype"]) + target.data = self._get_shared_empty_tensor_for_target( + target, meta["dtype"] + ) self._gpu_layers.discard(layer_idx) @@ -347,11 +386,12 @@ def sync_layer_to_cpu(self, layer_idx: int) -> None: # Collect current GPU weights and write back to CPU buffer for name, meta in self._weight_metadata.get(layer_idx, {}).items(): target = self.get_target_with_name(name) + target_local = self._to_local_tensor(target) if meta.get("preserve_strides", False): - self._strided_cpu_weights[layer_idx][name].copy_(target.data.cpu()) + self._strided_cpu_weights[layer_idx][name].copy_(target_local.cpu()) continue - gpu_weight = target.data.flatten().cpu() + gpu_weight = target_local.flatten().cpu() dtype = meta["dtype"] cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] @@ -408,30 +448,32 @@ def update_cpu_weights( continue meta = meta_layer[name] - if tuple(meta["shape"]) != tuple(loaded_weight.shape): + local_loaded_weight = self._to_local_tensor(loaded_weight) + if tuple(meta["shape"]) != tuple(local_loaded_weight.shape): raise ValueError( f"Shape mismatch for {name}: " f"expected={tuple(meta['shape'])}, " - f"loaded={tuple(loaded_weight.shape)}" + f"loaded={tuple(local_loaded_weight.shape)}" ) dtype = meta["dtype"] if meta.get("preserve_strides", False): self._strided_cpu_weights[layer_idx][name].copy_( - loaded_weight.to(dtype=dtype) + local_loaded_weight.to(dtype=dtype) ) else: offset = meta["offset"] numel = meta["numel"] cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] cpu_buffer[offset : offset + numel].copy_( - loaded_weight.to(dtype=dtype).flatten() + local_loaded_weight.to(dtype=dtype).flatten() ) # If this layer is currently on GPU, update the live parameter. if layer_idx in self._gpu_layers: target = self.get_target_with_name(name) - target.data.copy_(loaded_weight.to(dtype=target.dtype)) + target_local = self._to_local_tensor(target) + target_local.copy_(local_loaded_weight.to(dtype=target_local.dtype)) updated_names.add(name) @@ -467,7 +509,7 @@ def register_forward_hooks(self) -> None: if not self.enabled: return - layers = getattr(self.model, self.layers_attr_str) + layers = dict(self.model.named_modules())[self.layers_attr_str] def make_pre_hook(i): def hook(module, input): @@ -509,21 +551,24 @@ def remove_forward_hooks(self) -> None: self._forward_hooks.clear() -class OffloadableDiTMixin: - """ - A mixin that registers forward hooks for a DiT to enable layerwise offload - """ +class LayerwiseOffloadableModuleMixin: + """A mixin that registers forward hooks to enable layerwise offload.""" - # the list of names of a DiT's layers/blocks - layer_names: List[str] + # Legacy --dit-layerwise-offload configures these modules when no component is named. + layerwise_offload_default_enabled: bool = True + # The list of names of this module's layer/block ModuleList or Sequential attributes. + layer_names: List[str] = [] layerwise_offload_managers: list[LayerwiseOffloadManager] = [] def configure_layerwise_offload(self, server_args: ServerArgs): self.layerwise_offload_managers = [] + named_modules = dict(self.named_modules()) + configured_layer_names = [] for layer_name in self.layer_names: - # a manager per layer-list - module_list = getattr(self, layer_name, None) - if module_list is None or not isinstance(module_list, torch.nn.ModuleList): + module_list = named_modules.get(layer_name) + if not isinstance(module_list, (torch.nn.ModuleList, torch.nn.Sequential)): + continue + if len(module_list) == 0: continue num_layers = len(module_list) @@ -543,10 +588,20 @@ def configure_layerwise_offload(self, server_args: ServerArgs): prefetch_size=prefetch_size, ) self.layerwise_offload_managers.append(manager) + configured_layer_names.append(layer_name) - logger.info( - f"Enabled layerwise offload for {self.__class__.__name__} on modules: {self.layer_names}" - ) + if configured_layer_names: + logger.info( + "Enabled layerwise offload for %s on modules: %s", + self.__class__.__name__, + configured_layer_names, + ) + else: + logger.info( + "No layerwise-offloadable ModuleList found for %s. Candidates: %s", + self.__class__.__name__, + self.layer_names, + ) def prepare_for_next_req(self): if self.layerwise_offload_managers is None: @@ -583,7 +638,7 @@ def iter_materialized_weights(module: torch.nn.Module): the non-offloaded parameters. """ offload_managers: list = [] - if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + if is_layerwise_offloaded_module(module): offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] if not offload_managers: @@ -601,3 +656,117 @@ def iter_materialized_weights(module: torch.nn.Module): for name, param in module.named_parameters(): if name not in offloaded_names: yield name, param + + +def is_layerwise_offloaded_module(module: torch.nn.Module) -> bool: + return isinstance(module, LayerwiseOffloadableModuleMixin) and any( + manager.enabled for manager in module.layerwise_offload_managers + ) + + +def configure_layerwise_offload_modules( + modules: Mapping[str, object], + server_args: ServerArgs, + component_names: Sequence[str] | None = None, + warn_missing: bool = True, +) -> list[str]: + """Configure layerwise offload for the given modules, from the given component_names + + Args: + modules: the dict of {component_name: component}, containing the components to be chosen from + component_names: list of component names. component with names not in this list shouldn't be configured + + Returns a list of component names of modules configured to be layerwise-offload + """ + + # components which has already been configured to be layerwise-offload + configured_component_names: list[str] = [] + configured_module_ids: set[int] = set() + selected_component_names = ( + set(component_names) if component_names is not None else None + ) + select_all = ( + selected_component_names is not None and "all" in selected_component_names + ) + select_default = ( + selected_component_names is not None + and LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS in selected_component_names + ) + + if warn_missing and selected_component_names is not None and not select_all: + explicit_component_names = selected_component_names - { + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS + } + missing_component_names = [ + selected_component_name + for selected_component_name in explicit_component_names + if not any( + layerwise_component_matches_selection( + component_name, selected_component_name + ) + for component_name in modules + ) + ] + if missing_component_names: + logger.warning( + "Layerwise offload components are not currently loaded: %s. " + "Available pipeline components: %s", + sorted(missing_component_names), + sorted(modules), + ) + + unsupported_component_names = [ + component_name + for component_name in modules + if any( + layerwise_component_matches_selection( + component_name, selected_component_name + ) + for selected_component_name in explicit_component_names + ) + if not isinstance(modules[component_name], LayerwiseOffloadableModuleMixin) + ] + if unsupported_component_names: + logger.warning( + "Layerwise offload components do not support layerwise offload: %s", + sorted(unsupported_component_names), + ) + + for component_name, module in modules.items(): + if not isinstance(module, LayerwiseOffloadableModuleMixin): + continue + if selected_component_names is None: + if not module.layerwise_offload_default_enabled: + continue + elif ( + not select_all + and not any( + layerwise_component_matches_selection( + component_name, selected_component_name + ) + for selected_component_name in selected_component_names + ) + and not (select_default and module.layerwise_offload_default_enabled) + ): + # if the current component is not selected to be layerwise-offload, skip + continue + module_id = id(module) + if module_id in configured_module_ids: + # avoid multiple configures on a same module + continue + + configured_module_ids.add(module_id) + + if not is_layerwise_offloaded_module(module): + module.configure_layerwise_offload(server_args) + if is_layerwise_offloaded_module(module): + configured_component_names.append(component_name) + + if configured_component_names: + logger.info( + "Enabled layerwise offload for pipeline components: %s", + configured_component_names, + ) + else: + logger.info("No pipeline component supports layerwise offload.") + return configured_component_names diff --git a/python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload_components.py b/python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload_components.py new file mode 100644 index 000000000000..48297d5476c0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/memory_managers/layerwise_offload_components.py @@ -0,0 +1,117 @@ +from collections.abc import Sequence + +LAYERWISE_OFFLOAD_ALL_COMPONENTS = "all" +LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS = "default" +DIT_COMPONENT_NAMES = frozenset( + { + "transformer", + "transformer_2", + "video_dit", + "video_dit_2", + "audio_dit", + "dual_tower_bridge", + } +) +VAE_COMPONENT_NAMES = frozenset( + { + "vae", + "video_vae", + "audio_vae", + "vocoder", + "spatial_upsampler", + "condition_image_encoder", + } +) +CPU_OFFLOAD_FLAG_NAMES = ( + "dit_cpu_offload", + "text_encoder_cpu_offload", + "image_encoder_cpu_offload", + "vae_cpu_offload", +) + + +def is_dit_component_name(component_name: str) -> bool: + return component_name in DIT_COMPONENT_NAMES + + +def is_text_encoder_component_name(component_name: str) -> bool: + return component_name.startswith("text_encoder") or component_name.endswith( + "text_encoder" + ) + + +def is_image_encoder_component_name(component_name: str) -> bool: + return component_name == "image_encoder" + + +def is_vae_component_name(component_name: str) -> bool: + return component_name in VAE_COMPONENT_NAMES + + +def layerwise_component_matches_selection( + component_name: str, + selected_component_name: str, +) -> bool: + """if the provided component_name (unnormalized, e.g., text_encoder_2) matches with the selected_component_name (normalized)""" + if selected_component_name == "text_encoder": + return is_text_encoder_component_name(component_name) + if selected_component_name == "vae": + return is_vae_component_name(component_name) + return component_name == selected_component_name + + +def cpu_offload_flags_for_layerwise_components( + component_names: Sequence[str], +) -> tuple[str, ...]: + if LAYERWISE_OFFLOAD_ALL_COMPONENTS in component_names: + return CPU_OFFLOAD_FLAG_NAMES + + flag_names: list[str] = [] + if LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS in component_names: + flag_names.append("dit_cpu_offload") + + for component_name in component_names: + if component_name == LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS: + continue + if is_dit_component_name(component_name): + flag_name = "dit_cpu_offload" + elif is_text_encoder_component_name(component_name): + flag_name = "text_encoder_cpu_offload" + elif is_image_encoder_component_name(component_name): + flag_name = "image_encoder_cpu_offload" + elif is_vae_component_name(component_name): + flag_name = "vae_cpu_offload" + else: + continue + + if flag_name not in flag_names: + flag_names.append(flag_name) + + return tuple(flag_names) + + +def normalize_layerwise_offload_components( + component_names: str | Sequence[str] | None, +) -> list[str] | None: + if component_names is None: + return None + + raw_components = ( + [component_names] if isinstance(component_names, str) else component_names + ) + normalized_components: list[str] = [] + for raw_component in raw_components: + if not isinstance(raw_component, str): + raise ValueError( + f"Invalid layerwise offload component name: {raw_component}." + ) + for component_name in raw_component.split(","): + component_name = component_name.strip().replace("-", "_").lower() + if not component_name: + continue + if component_name == LAYERWISE_OFFLOAD_ALL_COMPONENTS: + return [LAYERWISE_OFFLOAD_ALL_COMPONENTS] + if component_name not in normalized_components: + normalized_components.append(component_name) + + return normalized_components or None diff --git a/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py b/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py index d4d4222a9ca7..4277bd24bb9f 100644 --- a/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py +++ b/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py @@ -26,7 +26,9 @@ from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( apply_flashinfer_rope_qk_inplace, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -397,7 +399,7 @@ def forward( class DualTowerConditionalBridge( CachableDiT, - OffloadableDiTMixin, + LayerwiseOffloadableModuleMixin, ): """Dual-tower conditional bridge module v2 (SGLang optimized version). @@ -407,6 +409,8 @@ class DualTowerConditionalBridge( 3. Cross-attention interaction between the hidden states of the two DiTs. """ + layerwise_offload_default_enabled = False + _fsdp_shard_conditions = MOVADualTowerConfig()._fsdp_shard_conditions _compile_conditions = MOVADualTowerConfig()._compile_conditions _supported_attention_backends = MOVADualTowerConfig()._supported_attention_backends diff --git a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py index 75d79a4aee88..dea47f10fcdd 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py @@ -13,7 +13,9 @@ flex_attention, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) # wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention # see https://github.com/pytorch/pytorch/issues/133254 @@ -58,7 +60,6 @@ class CausalWanSelfAttention(nn.Module): - def __init__( self, dim: int, @@ -251,7 +252,6 @@ def forward( class CausalWanTransformerBlock(nn.Module): - def __init__( self, dim: int, @@ -429,7 +429,7 @@ def forward( return hidden_states -class CausalWanTransformer3DModel(BaseDiT, OffloadableDiTMixin): +class CausalWanTransformer3DModel(BaseDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions _compile_conditions = WanVideoConfig()._compile_conditions _supported_attention_backends = WanVideoConfig()._supported_attention_backends @@ -660,10 +660,13 @@ def _forward_inference( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( - self.condition_embedder( - timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + ) = self.condition_embedder( + timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten( dim=0, sizes=timestep.shape @@ -802,10 +805,13 @@ def _forward_train( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( - self.condition_embedder( - timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + ) = self.condition_embedder( + timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten( dim=0, sizes=timestep.shape diff --git a/python/sglang/multimodal_gen/runtime/models/dits/ernie_image.py b/python/sglang/multimodal_gen/runtime/models/dits/ernie_image.py index 295319027208..f858a2bead42 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/ernie_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/ernie_image.py @@ -33,7 +33,9 @@ RowParallelLinear, ) from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT @@ -172,7 +174,6 @@ def forward( class ErnieImageMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -266,7 +267,7 @@ def _apply_rotary_bshd(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: return torch.cat((x_rot, x_pass), dim=-1) -class ErnieImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class ErnieImageTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ErnieImage DiT: Single-stream transformer with Shared AdaLN.""" _supports_gradient_checkpointing = True diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py index 80fed17fbc0a..24cc25e06d29 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py @@ -52,7 +52,9 @@ CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -355,9 +357,14 @@ def forward( freqs_cis=None, num_replicated_prefix: int = 0, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - query, key, value, encoder_query, encoder_key, encoder_value = ( - _get_qkv_projections(self, x, encoder_hidden_states) - ) + ( + query, + key, + value, + encoder_query, + encoder_key, + encoder_value, + ) = _get_qkv_projections(self, x, encoder_hidden_states) query = query.unflatten(-1, (self.heads, -1)) key = key.unflatten(-1, (self.heads, -1)) @@ -658,9 +665,13 @@ def forward( hidden_states, emb=temb ) - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( - self.norm1_context(encoder_hidden_states, emb=temb) - ) + ( + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1_context(encoder_hidden_states, emb=temb) joint_attention_kwargs = joint_attention_kwargs or {} # Attention. @@ -745,7 +756,7 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() -class FluxTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class FluxTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ The Transformer model introduced in Flux. diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py index dfad883fb8fd..13f99cae5f1b 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py @@ -42,7 +42,9 @@ NDRotaryEmbedding, apply_flashinfer_rope_qk_inplace, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, @@ -294,9 +296,14 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = ( - _get_qkv_projections(self, hidden_states, encoder_hidden_states) - ) + ( + query, + key, + value, + encoder_query, + encoder_key, + encoder_value, + ) = _get_qkv_projections(self, hidden_states, encoder_hidden_states) query = query.unflatten(-1, (self.local_heads, -1)) key = key.unflatten(-1, (self.local_heads, -1)) @@ -850,7 +857,7 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() -class Flux2Transformer2DModel(CachableDiT, OffloadableDiTMixin): +class Flux2Transformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ The Transformer model introduced in Flux 2. diff --git a/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py b/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py index 3cd4454711a8..114094e7ab81 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py @@ -37,7 +37,9 @@ apply_flashinfer_rope_qk_inplace, ) from sglang.multimodal_gen.runtime.layers.visual_embedding import Timesteps -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, @@ -661,7 +663,7 @@ def forward( return x -class GlmImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class GlmImageTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): r""" Args: patch_size (`int`, defaults to `2`): diff --git a/python/sglang/multimodal_gen/runtime/models/dits/helios.py b/python/sglang/multimodal_gen/runtime/models/dits/helios.py index 09ef5d396799..7050e13c89df 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/helios.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/helios.py @@ -47,7 +47,9 @@ TimestepEmbedder, ) from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -525,7 +527,7 @@ def forward( # --------------------------------------------------------------------------- -class HeliosTransformer3DModel(CachableDiT, OffloadableDiTMixin): +class HeliosTransformer3DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ Helios Transformer 3D model for video generation. @@ -671,9 +673,13 @@ def forward( # 1. Patch embed the noisy latents hidden_states = self.patch_embedding(hidden_states) - _, _, post_patch_num_frames, post_patch_height, post_patch_width = ( - hidden_states.shape - ) + ( + _, + _, + post_patch_num_frames, + post_patch_height, + post_patch_width, + ) = hidden_states.shape if indices_hidden_states is None: indices_hidden_states = ( diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py index 91df7621d98a..3474e76e0ce4 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py @@ -27,7 +27,9 @@ RowParallelLinear, ) from sglang.multimodal_gen.runtime.layers.mlp import MLP -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -453,7 +455,7 @@ def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: return x -class Hunyuan3D2DiT(CachableDiT, OffloadableDiTMixin): +class Hunyuan3D2DiT(CachableDiT, LayerwiseOffloadableModuleMixin): """Hunyuan3D DiT model (Flux-style architecture for Hunyuan3D-2.0).""" _aliases = ["hy3dgen.shapegen.models.Hunyuan3DDiT"] @@ -560,7 +562,7 @@ def __init__( self.final_layer = _FluxLastLayer(self.hidden_size, 1, self.out_channels) - # OffloadableDiTMixin + # LayerwiseOffloadableModuleMixin self.layer_names = ["double_blocks", "single_blocks"] def forward( diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index 1750f7e7b361..ba616e1777b9 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -37,7 +37,9 @@ unpatchify, ) from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.models.utils import modulate from sglang.multimodal_gen.runtime.platforms import ( @@ -418,7 +420,7 @@ def forward( return self.output_residual(output, mod_gate, x) -class HunyuanVideoTransformer3DModel(CachableDiT, OffloadableDiTMixin): +class HunyuanVideoTransformer3DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ HunyuanVideo Transformer backbone adapted for distributed training. diff --git a/python/sglang/multimodal_gen/runtime/models/dits/joy_image.py b/python/sglang/multimodal_gen/runtime/models/dits/joy_image.py index 545990a5cf54..c988fd925565 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/joy_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/joy_image.py @@ -27,7 +27,9 @@ ) from sglang.multimodal_gen.runtime.layers.rotary_embedding import NDRotaryEmbedding from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.models.dits.wanvideo import WanTimeTextImageEmbedding from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs @@ -89,7 +91,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MMDoubleStreamBlock(nn.Module): - def __init__( self, hidden_size: int, @@ -328,7 +329,7 @@ def forward( return img, txt -class JoyTransformer3DModel(CachableDiT, OffloadableDiTMixin): +class JoyTransformer3DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ JoyImage Transformer 3D Model for image generation. diff --git a/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py b/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py index 178f362ab134..a4874179ef28 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py @@ -31,7 +31,9 @@ QuantizationConfig, ) from sglang.multimodal_gen.runtime.layers.visual_embedding import timestep_embedding -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -1235,7 +1237,7 @@ def forward( return hidden_states, audio_hidden_states -class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin): +class LTX2VideoTransformer3DModel(CachableDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = LTX2ArchConfig()._fsdp_shard_conditions _compile_conditions = LTX2ArchConfig()._compile_conditions _supported_attention_backends = LTX2ArchConfig()._supported_attention_backends diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py index cbe8489fc9bc..2cec6c3ebb8a 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py @@ -18,7 +18,9 @@ from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( QuantizationConfig, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT # Reuse common functions and classes from mova_video_dit @@ -101,7 +103,7 @@ def forward(self, input): return super().forward(input) -class WanAudioModel(CachableDiT, OffloadableDiTMixin): +class WanAudioModel(CachableDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = MOVAAudioConfig()._fsdp_shard_conditions _compile_conditions = MOVAAudioConfig()._compile_conditions _supported_attention_backends = MOVAAudioConfig()._supported_attention_backends diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index 7e10767f5b53..5a6a2ded92f9 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -33,7 +33,9 @@ from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( QuantizationConfig, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -419,7 +421,7 @@ def forward(self, input): return super().forward(input) -class WanModel(CachableDiT, OffloadableDiTMixin): +class WanModel(CachableDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = MOVAVideoConfig()._fsdp_shard_conditions _compile_conditions = MOVAVideoConfig()._compile_conditions _supported_attention_backends = MOVAVideoConfig()._supported_attention_backends diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index 048faf54553b..694d7ea953cc 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -45,7 +45,9 @@ from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( apply_flashinfer_rope_qk_inplace, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -652,9 +654,14 @@ def forward( "encoder_hidden_states_mask" ) - img_query, img_key, img_value, txt_query, txt_key, txt_value = ( - _get_qkv_projections(self, hidden_states, encoder_hidden_states) - ) + ( + img_query, + img_key, + img_value, + txt_query, + txt_key, + txt_value, + ) = _get_qkv_projections(self, hidden_states, encoder_hidden_states) # Reshape for multi-head attention img_query = img_query.unflatten(-1, (self.num_heads, -1)) @@ -1118,7 +1125,7 @@ def to_hashable(obj): return obj -class QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class QwenImageTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): """ The Transformer model introduced in Qwen. diff --git a/python/sglang/multimodal_gen/runtime/models/dits/sana.py b/python/sglang/multimodal_gen/runtime/models/dits/sana.py index 56fb0c3dabb0..0a375f8c7460 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/sana.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/sana.py @@ -8,7 +8,9 @@ from sglang.multimodal_gen.configs.models.dits.sana import SanaConfig from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm from sglang.multimodal_gen.runtime.layers.visual_embedding import Timesteps -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -239,7 +241,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class SanaTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = [ lambda n, m: isinstance(m, SanaTransformerBlock), diff --git a/python/sglang/multimodal_gen/runtime/models/dits/stablediffusion3.py b/python/sglang/multimodal_gen/runtime/models/dits/stablediffusion3.py index 1ceeec2d9092..419d41af8d55 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/stablediffusion3.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/stablediffusion3.py @@ -17,16 +17,20 @@ from sglang.multimodal_gen.configs.models.dits.stablediffusion3 import ( StableDiffusion3TransformerConfig, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) -class SD3Transformer2DModel(CachableDiT): +class SD3Transformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): _supports_gradient_checkpointing = True _no_split_modules = ["JointTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + layer_names = ["transformer_blocks"] def __init__( self, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index c38ac2356412..427ab3bc4485 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -49,7 +49,9 @@ TimestepEmbedder, ) from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.models.utils import ( _use_aiter, @@ -70,7 +72,6 @@ class WanImageEmbedding(torch.nn.Module): - def __init__(self, in_features: int, out_features: int): super().__init__() @@ -87,7 +88,6 @@ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: class WanTimeTextImageEmbedding(nn.Module): - def __init__( self, dim: int, @@ -130,7 +130,6 @@ def forward( class WanSelfAttention(nn.Module): - def __init__( self, dim: int, @@ -247,7 +246,6 @@ def forward(self, x, context, context_lens): class WanI2VCrossAttention(WanSelfAttention): - def __init__( self, dim: int, @@ -335,7 +333,6 @@ def forward(self, x, context, context_lens): class WanTransformerBlock(nn.Module): - def __init__( self, dim: int, @@ -514,9 +511,14 @@ def forward( else: # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) e = self.scale_shift_table + temb.float() - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - e.chunk(6, dim=1) - ) + ( + shift_msa, + scale_msa, + gate_msa, + c_shift_msa, + c_scale_msa, + c_gate_msa, + ) = e.chunk(6, dim=1) assert shift_msa.dtype == torch.float32 @@ -611,7 +613,6 @@ def forward( class WanTransformerBlock_VSA(nn.Module): - def __init__( self, dim: int, @@ -855,7 +856,7 @@ def forward( return hidden_states -class WanTransformer3DModel(CachableDiT, OffloadableDiTMixin): +class WanTransformer3DModel(CachableDiT, LayerwiseOffloadableModuleMixin): _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions _compile_conditions = WanVideoConfig()._compile_conditions _supported_attention_backends = WanVideoConfig()._supported_attention_backends @@ -1088,13 +1089,16 @@ def forward( else: ts_seq_len = None - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( - self.condition_embedder( - timestep, - encoder_hidden_states, - encoder_hidden_states_image, - timestep_seq_len=ts_seq_len, - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + ) = self.condition_embedder( + timestep, + encoder_hidden_states, + encoder_hidden_states_image, + timestep_seq_len=ts_seq_len, ) if ts_seq_len is not None: # batch_size, seq_len, 6, inner_dim diff --git a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py index aa3792a3c547..6dcac6fc1066 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py @@ -41,7 +41,9 @@ _apply_rotary_emb, apply_flashinfer_rope_qk_inplace, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -612,7 +614,7 @@ def __call__(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return torch.cat(cos_out, dim=-1), torch.cat(sin_out, dim=-1) -class ZImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): +class ZImageTransformer2DModel(CachableDiT, LayerwiseOffloadableModuleMixin): _supports_gradient_checkpointing = True _no_split_modules = ["ZImageTransformerBlock"] _fsdp_shard_conditions = ZImageDitConfig().arch_config._fsdp_shard_conditions diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/base.py b/python/sglang/multimodal_gen/runtime/models/encoders/base.py index 7b3eece96bad..ebaa924ecc5a 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/base.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/base.py @@ -12,10 +12,20 @@ ImageEncoderConfig, TextEncoderConfig, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum -class TextEncoder(nn.Module, ABC): +class TextEncoder(nn.Module, ABC, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "layers", + "encoder.block", + "text_model.encoder.layers", + "model.language_model.layers", + ] _fsdp_shard_conditions: list = field(default_factory=lambda: []) _stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list) _supported_attention_backends: set[AttentionBackendEnum] = ( @@ -49,7 +59,13 @@ def supported_attention_backends(self) -> set[AttentionBackendEnum]: return self._supported_attention_backends -class ImageEncoder(nn.Module, ABC): +class ImageEncoder(nn.Module, ABC, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "layers", + "vision_model.encoder.layers", + "model.visual.blocks", + ] _supported_attention_backends: set[AttentionBackendEnum] = ( ImageEncoderConfig()._supported_attention_backends ) diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/gemma2.py b/python/sglang/multimodal_gen/runtime/models/encoders/gemma2.py index 1832c709eb37..2d0eaa27a30c 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/gemma2.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/gemma2.py @@ -35,6 +35,9 @@ VocabParallelEmbedding, ) from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) logger = logging.getLogger(__name__) @@ -280,10 +283,12 @@ def forward( return hidden_states -class Gemma2Model(nn.Module): +class Gemma2Model(nn.Module, LayerwiseOffloadableModuleMixin): """Gemma2 text encoder model for SANA pipeline.""" _fsdp_shard_conditions = [] + layerwise_offload_default_enabled = False + layer_names = ["layers"] def __init__(self, config: Gemma2Config, **kwargs): super().__init__() diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py b/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py index 16ed42134b8a..0bc0ae9643b5 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py @@ -24,6 +24,9 @@ from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig from sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.utils.common import add_prefix logger = logging.getLogger(__name__) @@ -934,10 +937,13 @@ def _load_with_shard_id( return loaded_params -class Gemma3ForConditionalGeneration(nn.Module): +class Gemma3ForConditionalGeneration(nn.Module, LayerwiseOffloadableModuleMixin): # transformers 5.6.0 flattened SiglipVisionModel, dropping the # `vision_model` intermediate wrapper. Our reimpl keeps it, so remap # HF source keys back into our nested namespace when transferring weights. + layerwise_offload_default_enabled = False + layer_names = ["language_model.layers"] + param_names_mapping = { r"^(vision_tower\.)(embeddings|encoder|post_layernorm|head)\.": r"\1vision_model.\2.", } diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py b/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py index 448ffaa8bdc3..2f4347009150 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py @@ -11,6 +11,10 @@ Dinov2Model, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): @@ -28,7 +32,12 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): return np.concatenate([emb_sin, emb_cos], axis=1) -class ImageEncoder(nn.Module): +class ImageEncoder(nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "model.encoder.layer", + "model.vision_model.encoder.layers", + ] MODEL_CLASS = None MODEL_CONFIG_CLASS = None mean = [] @@ -203,7 +212,15 @@ def build_image_encoder(config): raise ValueError(f'Unknown image encoder type: {config["type"]}') -class DualImageEncoder(nn.Module): +class DualImageEncoder(nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "main_image_encoder.model.encoder.layer", + "main_image_encoder.model.vision_model.encoder.layers", + "additional_image_encoder.model.encoder.layer", + "additional_image_encoder.model.vision_model.encoder.layers", + ] + def __init__( self, main_image_encoder, @@ -232,7 +249,13 @@ def unconditional_embedding(self, batch_size, **kwargs): return outputs -class SingleImageEncoder(nn.Module): +class SingleImageEncoder(nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "main_image_encoder.model.encoder.layer", + "main_image_encoder.model.vision_model.encoder.layers", + ] + def __init__( self, main_image_encoder, diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py b/python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py index 8d3ca5f1442e..b6b5066ccd92 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py @@ -41,6 +41,9 @@ ) from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -370,7 +373,7 @@ def forward( ) -class Mistral3ForConditionalGeneration(nn.Module): +class Mistral3ForConditionalGeneration(nn.Module, LayerwiseOffloadableModuleMixin): _checkpoint_conversion_mapping = { "^language_model.model": "model.language_model", "^multi_modal_projector": "model.multi_modal_projector", @@ -378,6 +381,8 @@ class Mistral3ForConditionalGeneration(nn.Module): } _tied_weights_keys = ["lm_head.weight"] uses_sglang_forward_context = False + layerwise_offload_default_enabled = False + layer_names = ["model.language_model.layers"] def __init__(self, config: LlavaConfig): super().__init__() diff --git a/python/sglang/multimodal_gen/runtime/models/upsampler/latent_upsampler.py b/python/sglang/multimodal_gen/runtime/models/upsampler/latent_upsampler.py index c62ec85f3a33..ab1cbd69326c 100644 --- a/python/sglang/multimodal_gen/runtime/models/upsampler/latent_upsampler.py +++ b/python/sglang/multimodal_gen/runtime/models/upsampler/latent_upsampler.py @@ -8,6 +8,10 @@ import torch.nn.functional as F from einops import rearrange +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) + class BlurDownsample(torch.nn.Module): """Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.""" @@ -146,7 +150,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class LatentUpsampler(torch.nn.Module): +class LatentUpsampler(torch.nn.Module, LayerwiseOffloadableModuleMixin): """ Upsample VAE latents spatially and/or temporally. @@ -161,6 +165,9 @@ class LatentUpsampler(torch.nn.Module): rational_resampler: Whether to use rational resampler for spatial upsampling. """ + layerwise_offload_default_enabled = False + layer_names = ["res_blocks", "post_upsample_res_blocks"] + def __init__( self, in_channels: int = 128, diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py index c6a49b95c9ac..ec028423df9d 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py @@ -22,9 +22,12 @@ from torch import nn from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) -class AutoencoderKL(nn.Module): +class AutoencoderKL(nn.Module, LayerwiseOffloadableModuleMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -59,8 +62,10 @@ class AutoencoderKL(nn.Module): mid_block will only have resnet blocks """ + layerwise_offload_default_enabled = False _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + layer_names = ["encoder.down_blocks", "decoder.up_blocks"] def __init__( self, diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_dc.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_dc.py index f8434e949834..ac913e62d127 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_dc.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_dc.py @@ -6,14 +6,20 @@ from torch import nn from sglang.multimodal_gen.configs.models.vaes.sana import SanaVAEConfig +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) -class AutoencoderDC(nn.Module): +class AutoencoderDC(nn.Module, LayerwiseOffloadableModuleMixin): """Deep Compression Autoencoder wrapper with 32x spatial compression.""" + layerwise_offload_default_enabled = False + layer_names = ["_inner_model.encoder.down_blocks", "_inner_model.decoder.up_blocks"] + def __init__(self, config: SanaVAEConfig = None, **kwargs): super().__init__() self._config = config diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/common.py b/python/sglang/multimodal_gen/runtime/models/vaes/common.py index 6631a00d692c..57fc5b3b3e35 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/common.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/common.py @@ -18,9 +18,19 @@ get_sp_parallel_rank, get_sp_world_size, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) -class ParallelTiledVAE(ABC, nn.Module): +class ParallelTiledVAE(ABC, nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = [ + "encoder.down_blocks", + "decoder.up_blocks", + "encoder.down", + "decoder.up", + ] tile_sample_min_height: int tile_sample_min_width: int tile_sample_min_num_frames: int diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/dac.py b/python/sglang/multimodal_gen/runtime/models/vaes/dac.py index 3d6d821ab830..cd3797f2b8de 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/dac.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/dac.py @@ -12,6 +12,9 @@ from torch import nn from sglang.multimodal_gen.configs.models.vaes.dac import DacVAEConfig +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.vaes.common import ( DiagonalGaussianDistribution, ) @@ -413,7 +416,10 @@ def forward(self, x): return self.model(x) -class DAC(nn.Module): +class DAC(nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = ["encoder.block", "decoder.model"] + def __init__( self, config: DacVAEConfig, diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py b/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py index c2792ff18018..692998f162a0 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py @@ -12,6 +12,9 @@ from einops import rearrange, repeat from tqdm import tqdm +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) @@ -1099,9 +1102,12 @@ def run(self, grid_logit, *, octree_resolution, **kwargs): } -class VectsetVAE(nn.Module): +class VectsetVAE(nn.Module, LayerwiseOffloadableModuleMixin): """Base VAE class for vector set encoding.""" + layerwise_offload_default_enabled = False + layer_names = ["transformer.resblocks"] + def __init__(self, volume_decoder=None, surface_extractor=None): super().__init__() if volume_decoder is None: diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_3_condition_encoder.py b/python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_3_condition_encoder.py index 587d0e573908..cba5764b98c7 100644 --- a/python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_3_condition_encoder.py +++ b/python/sglang/multimodal_gen/runtime/models/vaes/ltx_2_3_condition_encoder.py @@ -3,6 +3,9 @@ import torch import torch.nn as nn +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) from sglang.multimodal_gen.runtime.models.vaes.ltx_2_vae import ( LTX2VideoCausalConv3d, LTX2VideoResnetBlock3d, @@ -110,7 +113,10 @@ def _make_ltx23_encoder_block( ) -class LTX23VideoConditionEncoder(nn.Module): +class LTX23VideoConditionEncoder(nn.Module, LayerwiseOffloadableModuleMixin): + layerwise_offload_default_enabled = False + layer_names = ["down_blocks"] + def __init__(self, config: dict[str, Any]) -> None: super().__init__() diff --git a/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py b/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py index efd5e56aca92..48985e8bdf8c 100644 --- a/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py +++ b/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py @@ -9,6 +9,9 @@ import torch.nn.functional as F from sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import LTXVocoderConfig +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, +) LRELU_SLOPE = 0.1 @@ -531,11 +534,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class LTX2Vocoder(ABC, nn.Module): +class LTX2Vocoder(ABC, nn.Module, LayerwiseOffloadableModuleMixin): r""" LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. """ + layerwise_offload_default_enabled = False + layer_names = [ + "upsamplers", + "resnets", + "vocoder.ups", + "vocoder.resblocks", + "bwe_generator.ups", + "bwe_generator.resblocks", + ] + def __init__( self, config: LTXVocoderConfig, diff --git a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py index 7781346fba6f..164a7aef62ec 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py @@ -20,7 +20,7 @@ from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( ComponentResidencyStrategy, get_global_component_residency_manager, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py index e2e26de70a13..bb8bfb5a4448 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py @@ -16,12 +16,12 @@ PipelineComponentLoader, ) from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB -from sglang.multimodal_gen.runtime.managers.component_manager import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( ComponentResidencyStrategy, ComponentUse, ResidencyState, ) -from sglang.multimodal_gen.runtime.managers.component_resident_strategies import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_resident_strategies import ( SnapshotModuleResidency, SnapshotStrategy, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py b/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py index e9778b22903b..316896637575 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py @@ -24,7 +24,7 @@ from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( PipelineComponentLoader, ) -from sglang.multimodal_gen.runtime.managers.component_manager import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( ComponentResidencyManager, ComponentResidencyStrategy, get_global_component_residency_manager, diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py index 82fcc5b31112..d7979cff3981 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py @@ -170,8 +170,8 @@ def _temporarily_disable_offload( Yields: List of modules that had offload disabled. """ - from sglang.multimodal_gen.runtime.managers.layerwise_offload import ( - OffloadableDiTMixin, + from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + is_layerwise_offloaded_module, ) module_names = [] @@ -204,10 +204,9 @@ def _temporarily_disable_offload( offload_disabled_modules = [] for module_name in module_names: module = self.modules.get(module_name) - if module is not None and isinstance(module, OffloadableDiTMixin): - if module.layerwise_offload_managers is not None: - module.disable_offload() - offload_disabled_modules.append(module) + if module is not None and is_layerwise_offloaded_module(module): + module.disable_offload() + offload_disabled_modules.append(module) try: yield offload_disabled_modules diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py index a62cd2923f2e..5c6d4109b4ac 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py @@ -17,7 +17,9 @@ import torch from sglang.multimodal_gen.runtime.disaggregation.roles import RoleType -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.dedup import StageDedupMixin from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py index a6010a5240d0..a14d173c0b6e 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py @@ -11,7 +11,9 @@ from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.loader.component_loaders.vae_loader import VAELoader -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py index 6f82300cc2aa..9383326c8d3b 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py @@ -1,7 +1,9 @@ import torch from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req from sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import DecodingStage from sglang.multimodal_gen.runtime.platforms import current_platform diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index ef56c3b8a4c1..eab07f26d439 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -67,8 +67,10 @@ from sglang.multimodal_gen.runtime.loader.component_loaders.transformer_loader import ( TransformerLoader, ) -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( PipelineStage, diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py index aeac9cd73bbc..7b37a0d0dc64 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py @@ -2,7 +2,9 @@ from diffusers.utils.torch_utils import randn_tensor from sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import is_ltx23_native_variant -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import ( clone_scheduler_runtime, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py index 8ac73ccb30db..c09ab36e36ed 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py @@ -8,7 +8,9 @@ import torch from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py index fdd9ad0d6c89..dedacb7d70fd 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py @@ -20,8 +20,13 @@ from sglang.multimodal_gen.configs.pipeline_configs.base import TextConditioningOutput from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + configure_layerwise_offload_modules, +) from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE from sglang.multimodal_gen.runtime.models.vision_utils import ( normalize, @@ -485,6 +490,14 @@ def _ensure_condition_image_encoder(self, server_args: ServerArgs) -> bool: safetensors_load_file(weights_path), strict=True ) self._condition_image_encoder_dir = encoder_dir + if server_args.should_configure_layerwise_offload_for_lazy_component(): + modules = {"condition_image_encoder": self._condition_image_encoder} + configure_layerwise_offload_modules( + modules, + server_args, + component_names=server_args.layerwise_offload_components, + warn_missing=False, + ) return True # -- image preprocessing --------------------------------------------- diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py index 84c70bbe0a4a..7fed04ffe215 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py @@ -13,8 +13,10 @@ import torch import torch.nn.functional as F -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import ( get_or_create_request_scheduler, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py index 9ab2d6f8652e..8c0e14e629c6 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py @@ -46,7 +46,9 @@ # Create aliases for backward compatibility video_sinusoidal_embedding_1d = sinusoidal_embedding_1d audio_sinusoidal_embedding_1d = sinusoidal_embedding_1d -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( PipelineStage, diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py index 9f723d4953f6..73e3d3b513c9 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py @@ -8,8 +8,10 @@ from diffusers.utils.torch_utils import randn_tensor from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.models.vision_utils import load_image from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py index 8396f344819a..09f60d2b7e83 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py @@ -16,8 +16,10 @@ from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput from sglang.multimodal_gen.configs.pipeline_configs.base import TextConditioningOutput from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/upsampling.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/upsampling.py index b88a16e06833..f4dcddeffd9c 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/upsampling.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/upsampling.py @@ -1,7 +1,9 @@ import torch from sglang.multimodal_gen.runtime.distributed import get_local_torch_device -from sglang.multimodal_gen.runtime.managers.component_manager import ComponentUse +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + ComponentUse, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage from sglang.multimodal_gen.runtime.server_args import ServerArgs diff --git a/python/sglang/multimodal_gen/runtime/postprocess/realesrgan_upscaler.py b/python/sglang/multimodal_gen/runtime/postprocess/realesrgan_upscaler.py index e01dcfc53de3..9b6a6a056714 100644 --- a/python/sglang/multimodal_gen/runtime/postprocess/realesrgan_upscaler.py +++ b/python/sglang/multimodal_gen/runtime/postprocess/realesrgan_upscaler.py @@ -26,6 +26,9 @@ # Default HuggingFace repo and filename for Real-ESRGAN weights _DEFAULT_REALESRGAN_HF_REPO = "ai-forever/Real-ESRGAN" _DEFAULT_REALESRGAN_FILENAME = "RealESRGAN_x4.pth" +_LOW_MEMORY_TILED_UPSCALE_FREE_BYTES = 2 * 1024**3 +_REALESRGAN_TILE_SIZE = 256 +_REALESRGAN_TILE_PAD = 32 # Module-level cache: model_path -> UpscalerModel instance _MODEL_CACHE: dict[str, "UpscalerModel"] = {} @@ -263,6 +266,60 @@ def __init__(self, net: nn.Module, scale: int): def device(self) -> torch.device: return next(self.net.parameters()).device + @property + def dtype(self) -> torch.dtype: + return next(self.net.parameters()).dtype + + def _should_use_tiled_upscale(self, h: int, w: int) -> bool: + if self.device.type != "cuda": + return False + free_bytes, _ = torch.cuda.mem_get_info(self.device) + output_bytes = h * w * self.scale * self.scale * 3 * 4 + required_free_bytes = max( + _LOW_MEMORY_TILED_UPSCALE_FREE_BYTES, + output_bytes * 4, + ) + return free_bytes < required_free_bytes + + def _upscale_tiled_to_cpu( + self, + img_t: torch.Tensor, + tile_size: int = _REALESRGAN_TILE_SIZE, + tile_pad: int = _REALESRGAN_TILE_PAD, + ) -> torch.Tensor: + _, channels, h, w = img_t.shape + scale = self.scale + output = torch.empty( + (1, channels, h * scale, w * scale), + dtype=torch.float32, + device="cpu", + ) + + for y in range(0, h, tile_size): + tile_h = min(tile_size, h - y) + in_y0 = max(y - tile_pad, 0) + in_y1 = min(y + tile_h + tile_pad, h) + out_y0 = y * scale + out_y1 = (y + tile_h) * scale + crop_y0 = (y - in_y0) * scale + crop_y1 = crop_y0 + tile_h * scale + + for x in range(0, w, tile_size): + tile_w = min(tile_size, w - x) + in_x0 = max(x - tile_pad, 0) + in_x1 = min(x + tile_w + tile_pad, w) + out_x0 = x * scale + out_x1 = (x + tile_w) * scale + crop_x0 = (x - in_x0) * scale + crop_x1 = crop_x0 + tile_w * scale + + tile = img_t[..., in_y0:in_y1, in_x0:in_x1] + out_tile = self.net(tile) + out_tile = out_tile[..., crop_y0:crop_y1, crop_x0:crop_x1].float() + output[..., out_y0:out_y1, out_x0:out_x1].copy_(out_tile.cpu()) + + return output + def upscale(self, frame: np.ndarray, outscale: float | None = None) -> np.ndarray: """Upscale a single HWC uint8 frame → HWC uint8 frame. @@ -276,9 +333,34 @@ def upscale(self, frame: np.ndarray, outscale: float | None = None) -> np.ndarra """ h, w = frame.shape[:2] img = frame.astype(np.float32) / 255.0 - img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device) + img_t = ( + torch.from_numpy(img) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=self.device, dtype=self.dtype) + ) with torch.no_grad(): - out = self.net(img_t) + if self._should_use_tiled_upscale(h, w): + logger.info( + "Using tiled Real-ESRGAN upscale for low GPU memory: " + "frame=%dx%d, tile_size=%d, tile_pad=%d", + w, + h, + _REALESRGAN_TILE_SIZE, + _REALESRGAN_TILE_PAD, + ) + out = self._upscale_tiled_to_cpu(img_t) + else: + try: + out = self.net(img_t) + except torch.cuda.OutOfMemoryError: + if self.device.type != "cuda": + raise + torch.cuda.empty_cache() + logger.warning( + "Real-ESRGAN full-frame upscale OOM; retrying with tiled upscale" + ) + out = self._upscale_tiled_to_cpu(img_t) # If the desired outscale differs from the model's native scale, # resize to (h * outscale, w * outscale). diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 9a57bf57ff01..295a52779bc9 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -37,6 +37,11 @@ NunchakuConfig, ) from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import ( + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS, + cpu_offload_flags_for_layerwise_components, + normalize_layerwise_offload_components, +) from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, current_platform, @@ -194,7 +199,9 @@ class ServerArgs(DisaggArgsMixin): # CPU offload parameters dit_cpu_offload: bool | None = None + # if true, add the legacy default DiT components dit_layerwise_offload: bool | None = None + layerwise_offload_components: list[str] | None = None dit_offload_prefetch_size: float = 0.0 text_encoder_cpu_offload: bool | None = None image_encoder_cpu_offload: bool | None = None @@ -323,14 +330,15 @@ def _adjust_path(self): def _adjust_parameters(self): """set defaults and normalize values.""" auto_tuner = ServerArgsAutoTuner(self) - auto_tuner.adjust() + auto_tuner.adjust_based_on_performance_mode() if auto_tuner.could_override_server_args(): self._adjust_offload() - auto_tuner.maybe_adjust_auto_dit_layerwise_offload() + auto_tuner.maybe_adjust_auto_default_layerwise_offload() self._adjust_ltx2_two_stage_device_mode() if auto_tuner.could_override_server_args(): auto_tuner.maybe_adjust_auto_component_residency_after_offload() auto_tuner.maybe_adjust_auto_fsdp_with_offload_enabled() + auto_tuner.maybe_replace_cpu_offloaded_components_with_layerwise() self._adjust_path() self._adjust_quant_config() self._adjust_warmup() @@ -339,6 +347,7 @@ def _adjust_parameters(self): self._adjust_parallelism() self._adjust_attention_backend() self._adjust_platform_specific() + self._adjust_layerwise_offload_components() self._adjust_autocast() auto_tuner.finalize_auto_flags() self.adjust_pipeline_config() @@ -802,6 +811,82 @@ def _adjust_platform_specific(self): if current_platform.is_mps(): self.use_fsdp_inference = False self.dit_layerwise_offload = False + self.layerwise_offload_components = None + + def should_configure_layerwise_offload_for_lazy_component(self) -> bool: + """Return whether a lazy-loaded component should try layerwise offload. + + Lazy components are loaded after the normal pipeline-wide configuration + pass, so they should only attempt layerwise configuration when a + component scope is present. + """ + return bool(self.layerwise_offload_components) + + @property + def is_dit_layerwise_offload_selected(self) -> bool: + """returns if dit is selected to be layerwise-offload""" + component_names = self.layerwise_offload_components + return bool( + component_names + and "dit_cpu_offload" + in cpu_offload_flags_for_layerwise_components(component_names) + ) + + def _adjust_layerwise_offload_components(self): + explicitly_set_component_names = normalize_layerwise_offload_components( + self.layerwise_offload_components + ) + if self.dit_layerwise_offload: + if explicitly_set_component_names is None: + explicitly_set_component_names = [LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS] + elif ( + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS + not in explicitly_set_component_names + ): + explicitly_set_component_names = [ + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS, + *explicitly_set_component_names, + ] + + if explicitly_set_component_names is not None: + self.layerwise_offload_components = explicitly_set_component_names + self._disable_cpu_offload_for_layerwise_components( + explicitly_set_component_names + ) + return + + def _disable_cpu_offload_for_layerwise_components( + self, component_names: list[str] + ) -> None: + # Layerwise offload owns H2D/D2H for selected component weights. + flag_names = cpu_offload_flags_for_layerwise_components(component_names) + disabled_flag_names: list[str] = [] + + if "dit_cpu_offload" in flag_names and self.dit_cpu_offload is not False: + self.dit_cpu_offload = False + disabled_flag_names.append("dit_cpu_offload") + if ( + "text_encoder_cpu_offload" in flag_names + and self.text_encoder_cpu_offload is not False + ): + self.text_encoder_cpu_offload = False + disabled_flag_names.append("text_encoder_cpu_offload") + if ( + "image_encoder_cpu_offload" in flag_names + and self.image_encoder_cpu_offload is not False + ): + self.image_encoder_cpu_offload = False + disabled_flag_names.append("image_encoder_cpu_offload") + if "vae_cpu_offload" in flag_names and self.vae_cpu_offload is not False: + self.vae_cpu_offload = False + disabled_flag_names.append("vae_cpu_offload") + + if disabled_flag_names: + logger.info( + "Disabling %s because the selected layerwise offload components " + "manage the same weights.", + ", ".join(disabled_flag_names), + ) def _adjust_autocast(self): if self.disable_autocast is None: @@ -944,13 +1029,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="The specific model version to use (can be a branch name, tag name, or commit id)", ) - # Parallelism - parser.add_argument( - "--num-gpus", - type=int, - default=ServerArgs.num_gpus, - help="The number of GPUs to use.", - ) parser.add_argument( "--performance-mode", "--mode", @@ -959,7 +1037,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=ServerArgs.performance_mode, help=( "Preset for performance and memory defaults. " - "'manual' keeps performance-related server args under explicit user control; " + "'manual' keeps performance-related server args under explicit user control, no adjustment is made; " "'auto' keeps safe defaults and applies high-confidence FSDP/CFG improvements; " "'speed' favors GPU-resident execution for lower latency and higher throughput, and may OOM; " "'memory' favors lower GPU memory usage; " @@ -967,6 +1045,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ), ) + # Parallelism + parser.add_argument( + "--num-gpus", + type=int, + default=ServerArgs.num_gpus, + help="The number of GPUs to use.", + ) parser.add_argument( "--tp-size", type=int, @@ -1093,8 +1178,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--dit-layerwise-offload", action=StoreBoolean, default=ServerArgs.dit_layerwise_offload, - help="Enable layerwise CPU offload with async H2D prefetch overlap for supported DiT models (e.g., Wan, MOVA). " - "Cannot be used together with cache-dit (SGLANG_CACHE_DIT_ENABLED), dit_cpu_offload, or use_fsdp_inference.", + help="Enable layerwise CPU offload with async H2D prefetch overlap for DiTs. " + "It only selects the legacy default DiT components. Cannot be used together with cache-dit " + "(SGLANG_CACHE_DIT_ENABLED), dit_cpu_offload, or use_fsdp_inference.", + ) + parser.add_argument( + "--layerwise-offload-components", + "--layerwise-offload-modules", + type=str, + nargs="+", + default=ServerArgs.layerwise_offload_components, + help="Select pipeline components for layerwise offload. " + "Use default to select the legacy default DiT components, " + "or all to select every layerwise-offloadable component. " + "This option does not imply --dit-layerwise-offload. Example: " + "--layerwise-offload-components text_encoder image_encoder.", ) parser.add_argument( "--dit-offload-prefetch-size", @@ -1520,7 +1618,7 @@ def from_cli_args( @classmethod def from_dict(cls, kwargs: dict[str, Any]) -> "ServerArgs": """Create a ServerArgs object from a dictionary.""" - attrs = [attr.name for attr in dataclasses.fields(cls)] + attrs = [attr.name for attr in dataclasses.fields(cls) if attr.init] server_args_kwargs: dict[str, Any] = {} component_paths = dict(kwargs.get("component_paths") or {}) @@ -1623,34 +1721,36 @@ def _validate_offload(self): "We do not recommend --dit-offload-prefetch-size to be between 0.5 and 1.0" ) - # validate dit_layerwise_offload conflicts - if self.dit_layerwise_offload: + # validate layerwise offload conflicts + if self.layerwise_offload_components: if self.dit_offload_prefetch_size < 0.0: raise ValueError("dit_offload_prefetch_size must be non-negative") if self.use_fsdp_inference: logger.warning( - "dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference." + "layerwise offload components are selected, automatically disabling use_fsdp_inference." ) self.use_fsdp_inference = False - if self.dit_cpu_offload is None: + should_disable_dit_cpu_offload = self.is_dit_layerwise_offload_selected + if should_disable_dit_cpu_offload and self.dit_cpu_offload is not False: logger.warning( - "dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload." + "layerwise offload is selected for DiT components, automatically disabling dit_cpu_offload." ) self.dit_cpu_offload = False - if envs.SGLANG_CACHE_DIT_ENABLED: + if envs.SGLANG_CACHE_DIT_ENABLED and should_disable_dit_cpu_offload: raise ValueError( - "dit_layerwise_offload cannot be enabled together with cache-dit. " + "DiT layerwise offload cannot be enabled together with cache-dit. " "cache-dit may reuse skipped blocks whose weights have been released by layerwise offload, " "causing shape mismatch errors. " - "Please disable either --dit-layerwise-offload or SGLANG_CACHE_DIT_ENABLED." + "Please disable --dit-layerwise-offload, remove DiT from --layerwise-offload-components, " + "or disable SGLANG_CACHE_DIT_ENABLED." ) logger.warning( - "dit_layerwise_offload is enabled: %slower GPU memory usage%s, but %smay reduce throughput or increase latency%s. " - "%sIf you are using multi-GPU deployment and already have enough memory headroom, prefer keeping dit_layerwise_offload disabled.%s " + "layerwise offload components are selected: %slower GPU memory usage%s, but %smay reduce throughput or increase latency%s. " + "%sIf you are using multi-GPU deployment and already have enough memory headroom, prefer keeping layerwise offload disabled.%s " "Please tune this based on your memory headroom and performance target.", GREEN, RESET, diff --git a/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py b/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py index d9b67e235990..d9e3811d9785 100644 --- a/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py +++ b/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py @@ -10,6 +10,9 @@ from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( ModelDeploymentConfig, ) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import ( + LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS, +) from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -27,11 +30,14 @@ class ServerArgsAutoTuner: def __init__(self, server_args: "ServerArgs"): self.server_args = server_args self._explicit_memory_policy = self._has_explicit_memory_policy() + self._explicit_layerwise_replacement_policy = ( + self._has_explicit_layerwise_replacement_policy() + ) def _deployment_config(self) -> ModelDeploymentConfig: return self.server_args.pipeline_config.get_model_deployment_config() - def adjust(self) -> None: + def adjust_based_on_performance_mode(self) -> None: """Adjust the server args based on the performance mode""" args = self.server_args args.performance_mode = self._normalize_performance_mode() @@ -56,8 +62,8 @@ def adjust(self) -> None: self._set_gpu_resident_defaults(use_fsdp=True) return args.use_fsdp_inference = False - if self._can_apply_dit_layerwise_offload_policy(): - # apply dit layerwise offload to save VRAM during denoising stage + if self._can_apply_default_layerwise_offload_policy(): + # apply default layerwise offload to save VRAM during denoising stage self._set_layerwise_offload_defaults() else: self._set_component_offload_defaults() @@ -127,7 +133,8 @@ def maybe_adjust_auto_fsdp_with_offload_enabled(self) -> None: args.dit_layerwise_offload = False self._enable_cfg_parallel_if_supported() - def maybe_adjust_auto_dit_layerwise_offload(self) -> None: + def maybe_adjust_auto_default_layerwise_offload(self) -> None: + """adjust the default layerwise offload policy""" args = self.server_args if not self.could_override_server_args(): return @@ -178,6 +185,43 @@ def maybe_adjust_auto_dit_layerwise_offload(self) -> None: args.dit_layerwise_offload = True args.dit_cpu_offload = False + def maybe_replace_cpu_offloaded_components_with_layerwise(self) -> None: + args = self.server_args + if ( + not self.could_override_server_args() + or self._explicit_layerwise_replacement_policy + or current_platform.is_cpu() + or not current_platform.is_cuda() + or envs.SGLANG_CACHE_DIT_ENABLED + or args.use_fsdp_inference + or args.layerwise_offload_components is not None + ): + return + + layerwise_components: list[str] = [] + if args.dit_layerwise_offload: + layerwise_components.append(LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS) + + changed: list[str] = [] + if args.text_encoder_cpu_offload: + layerwise_components.append("text_encoder") + changed.append("text_encoder") + if args.image_encoder_cpu_offload: + layerwise_components.append("image_encoder") + changed.append("image_encoder") + if args.vae_cpu_offload: + layerwise_components.append("vae") + changed.append("vae") + + if not changed: + return + + args.layerwise_offload_components = layerwise_components + logger.info( + "Automatically replacing CPU offload with layerwise offload for components: %s", + ", ".join(changed), + ) + def finalize_auto_flags(self) -> None: """if some args are unset after all the adjustment, set them to defaults""" if not self.could_override_server_args(): @@ -266,7 +310,7 @@ def _set_layerwise_offload_defaults(self) -> None: if args.image_encoder_cpu_offload is None: args.image_encoder_cpu_offload = True - def _can_apply_dit_layerwise_offload_policy(self) -> bool: + def _can_apply_default_layerwise_offload_policy(self) -> bool: return ( self._deployment_config().auto_dit_layerwise_offload and not envs.SGLANG_CACHE_DIT_ENABLED @@ -299,8 +343,19 @@ def _has_explicit_memory_policy(self) -> bool: args.use_fsdp_inference is not None or args.dit_cpu_offload is not None or args.dit_layerwise_offload is not None + or args.layerwise_offload_components is not None + or args.text_encoder_cpu_offload is not None + or args.image_encoder_cpu_offload is not None + ) + + def _has_explicit_layerwise_replacement_policy(self) -> bool: + args = self.server_args + return ( + args.dit_layerwise_offload is not None + or args.layerwise_offload_components is not None or args.text_encoder_cpu_offload is not None or args.image_encoder_cpu_offload is not None + or args.vae_cpu_offload is True ) def _has_explicit_parallel_policy(self) -> bool: diff --git a/python/sglang/multimodal_gen/test/test_utils.py b/python/sglang/multimodal_gen/test/test_utils.py index 6c64ce70923d..de63df221fb6 100644 --- a/python/sglang/multimodal_gen/test/test_utils.py +++ b/python/sglang/multimodal_gen/test/test_utils.py @@ -801,7 +801,10 @@ def get_clip_model() -> tuple[Any, Any]: ) model = CLIPModel.from_pretrained(CLIP_MODEL_NAME) - device = "cuda" if torch.cuda.is_available() else "cpu" + # ci server tests keep the generation server alive while consistency runs + device = ( + "cpu" if is_in_ci() else ("cuda" if torch.cuda.is_available() else "cpu") + ) model = model.to(device) model.eval() diff --git a/python/sglang/multimodal_gen/test/unit/test_layerwise_offload.py b/python/sglang/multimodal_gen/test/unit/test_layerwise_offload.py index c8a98fd0a4b5..e9d49a33b091 100644 --- a/python/sglang/multimodal_gen/test/unit/test_layerwise_offload.py +++ b/python/sglang/multimodal_gen/test/unit/test_layerwise_offload.py @@ -9,11 +9,22 @@ from sglang.multimodal_gen.runtime.loader.transformer_load_utils import ( _ModelOptFp8OffloadAdapter, ) -from sglang.multimodal_gen.runtime.managers import ( +from sglang.multimodal_gen.runtime.managers.memory_managers import ( layerwise_offload as layerwise_offload_mod, ) -from sglang.multimodal_gen.runtime.managers.layerwise_offload import ( +from sglang.multimodal_gen.runtime.managers.memory_managers.component_manager import ( + build_component_residency_strategy, +) +from sglang.multimodal_gen.runtime.managers.memory_managers.component_resident_strategies import ( + LayerwiseOffloadStrategy, + ResidentStrategy, + VanillaD2HStrategy, +) +from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import ( + LayerwiseOffloadableModuleMixin, LayerwiseOffloadManager, + configure_layerwise_offload_modules, + is_layerwise_offloaded_module, ) @@ -65,6 +76,65 @@ def __init__(self) -> None: self.blocks = torch.nn.ModuleList([_DummyBlock()]) +class _NestedDummyModel(torch.nn.Module, LayerwiseOffloadableModuleMixin): + layer_names = ["encoder.blocks"] + + def __init__(self) -> None: + super().__init__() + self.encoder = _DummyModel() + + +class _SharedBuffer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer( + "cache", torch.arange(12, dtype=torch.float32).reshape(6, 2) + ) + + +class _SharedBufferLayer(torch.nn.Module): + def __init__(self, shared: _SharedBuffer) -> None: + super().__init__() + self.shared = shared + self.weight = torch.nn.Parameter(torch.ones(2, 2, dtype=torch.float32)) + + +class _SharedBufferModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + shared = _SharedBuffer() + self.blocks = torch.nn.ModuleList( + [_SharedBufferLayer(shared), _SharedBufferLayer(shared)] + ) + + +class _NestedEncoderDummyModel(_NestedDummyModel): + layerwise_offload_default_enabled = False + + +class _LayerwiseComponent(torch.nn.Module, LayerwiseOffloadableModuleMixin): + layer_names = ["blocks"] + + def __init__(self, enabled: bool) -> None: + super().__init__() + self.blocks = torch.nn.ModuleList([_DummyBlock()]) + self.layerwise_offload_managers = [SimpleNamespace(enabled=enabled)] + + +def _server_args(**kwargs): + defaults = dict( + use_fsdp_inference=False, + dit_cpu_offload=False, + text_encoder_cpu_offload=False, + image_encoder_cpu_offload=False, + vae_cpu_offload=False, + dit_offload_prefetch_size=1, + pin_cpu_memory=False, + ) + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + def test_layerwise_offload_preserves_non_contiguous_stride(monkeypatch): monkeypatch.setattr( layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule @@ -103,6 +173,36 @@ def test_layerwise_offload_preserves_non_contiguous_stride(monkeypatch): assert torch.equal(reloaded_weight, original_weight) +def test_layerwise_offload_keeps_shared_buffers_resident(monkeypatch): + monkeypatch.setattr( + layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule + ) + monkeypatch.setattr(layerwise_offload_mod.current_platform, "device_type", "cpu") + + model = _SharedBufferModel() + original_cache = model.blocks[0].shared.cache.detach().clone() + + manager = LayerwiseOffloadManager( + model=model, + layers_attr_str="blocks", + num_layers=2, + enabled=True, + pin_cpu_memory=False, + prefetch_size=1, + ) + + assert not any( + "cache" in name + for metadata in manager._weight_metadata.values() + for name in metadata + ) + manager.release_layer(0) + + cache = model.blocks[1].shared.cache + assert torch.equal(cache, original_cache) + assert torch.equal(cache.index_select(0, torch.tensor([2])), original_cache[2:3]) + + def test_modelopt_fp8_adapter_keeps_layerwise_offload_enabled(): server_args = SimpleNamespace( dit_cpu_offload=True, @@ -119,6 +219,133 @@ def test_modelopt_fp8_adapter_keeps_layerwise_offload_enabled(): assert server_args.dit_layerwise_offload is True +def test_layerwise_capability_selects_layerwise_strategy_for_any_component(): + module = _LayerwiseComponent(enabled=True) + + assert is_layerwise_offloaded_module(module) + strategy = build_component_residency_strategy( + "text_encoder", module, _server_args(text_encoder_cpu_offload=True) + ) + + assert isinstance(strategy, LayerwiseOffloadStrategy) + + +def test_layerwise_configuration_uses_legacy_default_components(monkeypatch): + monkeypatch.setattr( + layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule + ) + monkeypatch.setattr(layerwise_offload_mod.current_platform, "device_type", "cpu") + layerwise_module = _NestedDummyModel() + modules = { + "text_encoder": layerwise_module, + "text_encoder_alias": layerwise_module, + "scheduler": object(), + } + + configured = configure_layerwise_offload_modules(modules, _server_args()) + + assert configured == ["text_encoder"] + assert is_layerwise_offloaded_module(layerwise_module) + + +def test_layerwise_configuration_filters_by_component_name(monkeypatch): + monkeypatch.setattr( + layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule + ) + monkeypatch.setattr(layerwise_offload_mod.current_platform, "device_type", "cpu") + text_encoder = _NestedEncoderDummyModel() + transformer = _NestedDummyModel() + vae = _NestedDummyModel() + modules = { + "custom_encoder_name": text_encoder, + "custom_transformer_name": transformer, + "custom_vae_name": vae, + } + + configured = configure_layerwise_offload_modules( + modules, _server_args(), component_names=["custom_encoder_name"] + ) + + assert configured == ["custom_encoder_name"] + assert is_layerwise_offloaded_module(text_encoder) + assert not is_layerwise_offloaded_module(transformer) + assert not is_layerwise_offloaded_module(vae) + + +def test_layerwise_configuration_default_marker_extends_legacy_defaults(monkeypatch): + monkeypatch.setattr( + layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule + ) + monkeypatch.setattr(layerwise_offload_mod.current_platform, "device_type", "cpu") + text_encoder = _NestedEncoderDummyModel() + text_encoder_2 = _NestedEncoderDummyModel() + transformer = _NestedDummyModel() + vae = _NestedEncoderDummyModel() + audio_vae = _NestedEncoderDummyModel() + condition_image_encoder = _NestedEncoderDummyModel() + modules = { + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "condition_image_encoder": condition_image_encoder, + } + + configured = configure_layerwise_offload_modules( + modules, _server_args(), component_names=["default", "text_encoder", "vae"] + ) + + assert configured == [ + "text_encoder", + "text_encoder_2", + "transformer", + "vae", + "audio_vae", + "condition_image_encoder", + ] + assert is_layerwise_offloaded_module(text_encoder) + assert is_layerwise_offloaded_module(text_encoder_2) + assert is_layerwise_offloaded_module(transformer) + assert is_layerwise_offloaded_module(vae) + assert is_layerwise_offloaded_module(audio_vae) + assert is_layerwise_offloaded_module(condition_image_encoder) + + +def test_layerwise_configuration_all_selects_every_capable_component(monkeypatch): + monkeypatch.setattr( + layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule + ) + monkeypatch.setattr(layerwise_offload_mod.current_platform, "device_type", "cpu") + text_encoder = _NestedEncoderDummyModel() + transformer = _NestedDummyModel() + modules = { + "custom_encoder_name": text_encoder, + "custom_transformer_name": transformer, + "scheduler": object(), + } + + configured = configure_layerwise_offload_modules( + modules, _server_args(), component_names=["all"] + ) + + assert configured == ["custom_encoder_name", "custom_transformer_name"] + assert is_layerwise_offloaded_module(text_encoder) + assert is_layerwise_offloaded_module(transformer) + + +def test_component_cpu_offload_strategy_remains_flag_driven(): + strategy = build_component_residency_strategy( + "text_encoder", _DummyModel(), _server_args(text_encoder_cpu_offload=True) + ) + assert isinstance(strategy, VanillaD2HStrategy) + + strategy = build_component_residency_strategy( + "unknown_component", _DummyModel(), _server_args(text_encoder_cpu_offload=True) + ) + assert isinstance(strategy, ResidentStrategy) + + def test_layerwise_offload_aligns_contiguous_tensor_offsets(monkeypatch): monkeypatch.setattr( layerwise_offload_mod.torch, "get_device_module", lambda: _FakeDeviceModule diff --git a/python/sglang/multimodal_gen/test/unit/test_server_args.py b/python/sglang/multimodal_gen/test/unit/test_server_args.py index 463577b66148..40dcbb4553ca 100644 --- a/python/sglang/multimodal_gen/test/unit/test_server_args.py +++ b/python/sglang/multimodal_gen/test/unit/test_server_args.py @@ -147,6 +147,106 @@ def test_dynamic_component_attention_backend_cli_args(self): server_args.component_attention_backends, {"text_encoder": "torch_sdpa"} ) + def test_layerwise_offload_components_imply_layerwise(self): + args = self._from_dict_without_model_resolution( + { + "model_path": "/data/my-model", + "performance_mode": "manual", + } + ) + args.layerwise_offload_components = ["text_encoder", "transformer"] + args._adjust_layerwise_offload_components() + + self.assertTrue(args.layerwise_offload_components) + self.assertEqual( + args.layerwise_offload_components, ["text_encoder", "transformer"] + ) + + def test_dit_layerwise_offload_extends_default_components(self): + args = self._from_dict_without_model_resolution( + { + "model_path": "/data/my-model", + "performance_mode": "manual", + "dit_layerwise_offload": True, + } + ) + + self.assertTrue(args.layerwise_offload_components) + self.assertEqual(args.layerwise_offload_components, ["default"]) + + def test_dit_layerwise_offload_from_kwargs(self): + with patch.object( + PipelineConfig, "from_kwargs", return_value=QwenImagePipelineConfig() + ): + args = ServerArgs.from_kwargs( + model_path="/data/my-model", + performance_mode="manual", + dit_layerwise_offload=True, + ) + + self.assertTrue(args.layerwise_offload_components) + self.assertEqual(args.layerwise_offload_components, ["default"]) + + def test_layerwise_offload_components_normalize_commas(self): + args = self._from_dict_without_model_resolution( + { + "model_path": "/data/my-model", + "performance_mode": "manual", + } + ) + args.layerwise_offload_components = ["text-encoder,transformer"] + args._adjust_layerwise_offload_components() + + self.assertEqual( + args.layerwise_offload_components, ["text_encoder", "transformer"] + ) + + def test_dit_layerwise_offload_cli_arg(self): + parser = FlexibleArgumentParser() + ServerArgs.add_cli_args(parser) + argv = [ + "--model-path", + "/fake", + "--performance-mode", + "manual", + "--dit-layerwise-offload", + "true", + ] + + with patch.object(sys, "argv", ["sglang"] + argv): + args, unknown_args = parser.parse_known_args(argv) + with patch.object( + PipelineConfig, "from_kwargs", return_value=QwenImagePipelineConfig() + ): + server_args = ServerArgs.from_cli_args(args, unknown_args) + + self.assertTrue(server_args.layerwise_offload_components) + self.assertEqual(server_args.layerwise_offload_components, ["default"]) + + def test_layerwise_offload_components_cli_args(self): + parser = FlexibleArgumentParser() + ServerArgs.add_cli_args(parser) + argv = [ + "--model-path", + "/fake", + "--performance-mode", + "manual", + "--layerwise-offload-components", + "transformer", + "text_encoder", + ] + + with patch.object(sys, "argv", ["sglang"] + argv): + args, unknown_args = parser.parse_known_args(argv) + with patch.object( + PipelineConfig, "from_kwargs", return_value=QwenImagePipelineConfig() + ): + server_args = ServerArgs.from_cli_args(args, unknown_args) + + self.assertEqual( + server_args.layerwise_offload_components, ["transformer", "text_encoder"] + ) + class TestOffloadDefaults(unittest.TestCase): def _from_dict_with_pipeline_config( @@ -208,6 +308,10 @@ def _from_dict_with_task_type( "sglang.multimodal_gen.runtime.server_args.current_platform.is_cpu", return_value=False, ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cuda", + return_value=True, + ), patch( "sglang.multimodal_gen.runtime.server_args.current_platform.get_device_total_memory", return_value=memory_gb * 1024**3, @@ -233,16 +337,48 @@ def test_vae_cpu_offload_defaults_false_on_low_memory_gpu(self): self.assertFalse(args.vae_cpu_offload) self.assertTrue(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) - self.assertTrue(args.image_encoder_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, ["text_encoder", "image_encoder"] + ) - def test_explicit_vae_cpu_offload_true_is_preserved(self): + def test_explicit_vae_cpu_offload_true_is_preserved_without_component_selection( + self, + ): args = self._from_dict_with_task_type( ModelTaskType.T2V, kwargs={"vae_cpu_offload": True}, ) self.assertTrue(args.vae_cpu_offload) + self.assertFalse(args.layerwise_offload_components) + + def test_layerwise_components_disable_matching_cpu_offloads(self): + args = self._from_dict_with_task_type( + ModelTaskType.T2V, + memory_gb=16, + kwargs={ + "performance_mode": "manual", + "dit_cpu_offload": True, + "text_encoder_cpu_offload": True, + "image_encoder_cpu_offload": True, + "vae_cpu_offload": True, + }, + ) + args.layerwise_offload_components = [ + "text_encoder", + "image_encoder", + "video_dit", + "vae", + ] + args._adjust_layerwise_offload_components() + + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.dit_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertFalse(args.vae_cpu_offload) def test_pipeline_configs_declare_auto_tune_hints(self): qwen_deployment = QwenImagePipelineConfig().get_model_deployment_config() @@ -285,11 +421,12 @@ def test_manual_mode_preserves_unset_performance_args(self): self.assertIsNone(args.use_fsdp_inference) self.assertIsNone(args.dit_cpu_offload) self.assertIsNone(args.dit_layerwise_offload) + self.assertIsNone(args.layerwise_offload_components) self.assertIsNone(args.text_encoder_cpu_offload) self.assertIsNone(args.image_encoder_cpu_offload) self.assertFalse(args.enable_cfg_parallel) - def test_default_auto_keeps_legacy_single_gpu_offload_defaults(self): + def test_default_auto_replaces_text_encoder_cpu_offload_with_layerwise(self): args = self._from_dict_with_pipeline_config( QwenImagePipelineConfig(), kwargs={"model_path": "Qwen/Qwen-Image"}, @@ -298,11 +435,14 @@ def test_default_auto_keeps_legacy_single_gpu_offload_defaults(self): self.assertEqual(args.performance_mode, "auto") self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.dit_cpu_offload) - self.assertFalse(args.dit_layerwise_offload) - self.assertTrue(args.text_encoder_cpu_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["text_encoder"]) - def test_auto_ltx_snapshot_keeps_dit_offload_with_headroom(self): + def test_auto_ltx_snapshot_keeps_dit_offload_and_replaces_encoder_cpu_offload( + self, + ): args = self._from_dict_with_pipeline_config( LTX2PipelineConfig(), available_memory_gb=76, @@ -316,8 +456,12 @@ def test_auto_ltx_snapshot_keeps_dit_offload_with_headroom(self): self.assertEqual(args.ltx2_two_stage_device_mode, "snapshot") self.assertTrue(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) - self.assertTrue(args.image_encoder_cpu_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, ["text_encoder", "image_encoder"] + ) def test_auto_wan_layerwise_offload_is_enabled_without_fsdp(self): args = self._from_dict_with_pipeline_config( @@ -325,8 +469,14 @@ def test_auto_wan_layerwise_offload_is_enabled_without_fsdp(self): kwargs={"performance_mode": "auto"}, ) - self.assertTrue(args.dit_layerwise_offload) + self.assertTrue(args.layerwise_offload_components) self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, + ["default", "text_encoder", "image_encoder"], + ) def test_memory_wan_layerwise_offload_is_enabled_without_fsdp(self): args = self._from_dict_with_pipeline_config( @@ -334,8 +484,14 @@ def test_memory_wan_layerwise_offload_is_enabled_without_fsdp(self): kwargs={"performance_mode": "memory"}, ) - self.assertTrue(args.dit_layerwise_offload) + self.assertTrue(args.layerwise_offload_components) self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, + ["default", "text_encoder", "image_encoder"], + ) def test_auto_wan_layerwise_offload_does_not_disable_explicit_fsdp(self): args = self._from_dict_with_pipeline_config( @@ -348,7 +504,7 @@ def test_auto_wan_layerwise_offload_does_not_disable_explicit_fsdp(self): }, ) - self.assertFalse(args.dit_layerwise_offload) + self.assertFalse(args.layerwise_offload_components) self.assertTrue(args.use_fsdp_inference) def test_auto_multi_gpu_wan_uses_layerwise_offload_without_cfg(self): @@ -365,9 +521,52 @@ def test_auto_multi_gpu_wan_uses_layerwise_offload_without_cfg(self): self.assertFalse(args.use_fsdp_inference) self.assertFalse(args.enable_cfg_parallel) self.assertFalse(args.dit_cpu_offload) - self.assertTrue(args.dit_layerwise_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, + ["default", "text_encoder", "image_encoder"], + ) + + def test_explicit_multi_gpu_dit_layerwise_only_selects_default_component(self): + args = self._from_dict_with_pipeline_config( + MOVAPipelineConfig(), + kwargs={ + "model_path": "OpenMOSS-Team/MOVA-360p", + "num_gpus": 2, + "dit_layerwise_offload": True, + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.dit_cpu_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertTrue(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["default"]) + + def test_auto_multi_gpu_ltx_replaces_component_cpu_offload_with_resident_dit(self): + args = self._from_dict_with_pipeline_config( + LTX2PipelineConfig(), + available_memory_gb=76, + kwargs={ + "model_path": "Lightricks/LTX-2", + "num_gpus": 2, + "pipeline_class_name": "LTX2TwoStagePipeline", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.dit_cpu_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, ["text_encoder", "image_encoder"] + ) - def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_cfg(self): + def test_auto_multi_gpu_qwen_replaces_text_encoder_offload_with_cfg(self): args = self._from_dict_with_pipeline_config( QwenImagePipelineConfig(), kwargs={ @@ -380,9 +579,10 @@ def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_cfg(self): self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.enable_cfg_parallel) self.assertTrue(args.dit_cpu_offload) - self.assertFalse(args.dit_layerwise_offload) - self.assertTrue(args.text_encoder_cpu_offload) + self.assertTrue(args.layerwise_offload_components) + self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["text_encoder"]) def test_auto_multi_gpu_zimage_base_prefers_fsdp(self): args = self._from_dict_with_pipeline_config( @@ -424,8 +624,9 @@ def test_auto_multi_gpu_qwen_preserves_explicit_fsdp_false(self): self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.enable_cfg_parallel) self.assertTrue(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["text_encoder"]) def test_auto_multi_gpu_qwen_skips_fsdp_when_available_memory_is_low(self): args = self._from_dict_with_pipeline_config( @@ -441,8 +642,9 @@ def test_auto_multi_gpu_qwen_skips_fsdp_when_available_memory_is_low(self): self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.enable_cfg_parallel) self.assertTrue(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["text_encoder"]) def test_auto_multi_gpu_qwen_uses_selected_gpu_min_available_memory(self): args = self._from_dict_with_pipeline_config( @@ -459,7 +661,7 @@ def test_auto_multi_gpu_qwen_uses_selected_gpu_min_available_memory(self): self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.enable_cfg_parallel) - def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_headroom(self): + def test_auto_multi_gpu_qwen_replaces_text_encoder_offload_with_headroom(self): args = self._from_dict_with_pipeline_config( QwenImagePipelineConfig(), available_memory_gb={1: 72, 2: 80}, @@ -474,8 +676,9 @@ def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_headroom(self): self.assertFalse(args.use_fsdp_inference) self.assertTrue(args.enable_cfg_parallel) self.assertTrue(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual(args.layerwise_offload_components, ["text_encoder"]) def test_speed_mode_single_gpu_disables_offload(self): args = self._from_dict_with_pipeline_config( @@ -489,7 +692,7 @@ def test_speed_mode_single_gpu_disables_offload(self): self.assertEqual(args.performance_mode, "speed") self.assertFalse(args.use_fsdp_inference) self.assertFalse(args.dit_cpu_offload) - self.assertFalse(args.dit_layerwise_offload) + self.assertFalse(args.layerwise_offload_components) self.assertFalse(args.text_encoder_cpu_offload) self.assertFalse(args.image_encoder_cpu_offload) @@ -518,10 +721,14 @@ def test_memory_mode_wan_uses_layerwise_offload(self): ) self.assertFalse(args.use_fsdp_inference) - self.assertTrue(args.dit_layerwise_offload) + self.assertTrue(args.layerwise_offload_components) self.assertFalse(args.dit_cpu_offload) - self.assertTrue(args.text_encoder_cpu_offload) - self.assertTrue(args.image_encoder_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + self.assertEqual( + args.layerwise_offload_components, + ["default", "text_encoder", "image_encoder"], + ) def test_memory_mode_preserves_explicit_fsdp(self): args = self._from_dict_with_pipeline_config( @@ -535,7 +742,7 @@ def test_memory_mode_preserves_explicit_fsdp(self): ) self.assertTrue(args.use_fsdp_inference) - self.assertFalse(args.dit_layerwise_offload) + self.assertFalse(args.layerwise_offload_components) self.assertFalse(args.dit_cpu_offload) def test_invalid_performance_mode_raises(self):