Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
1aba730
Generalize layerwise offload residency mixin
mickqian May 7, 2026
c7bb890
Apply residency lint formatting
mickqian May 7, 2026
955f019
Enable component-wide layerwise offload setup
mickqian May 7, 2026
f480946
Fix layerwise offload formatting
mickqian May 7, 2026
93a121e
Add layerwise offload module group selector
mickqian May 7, 2026
04a4b88
Fix layerwise module selector formatting
mickqian May 7, 2026
c17c9d8
Clarify layerwise offload module groups
mickqian May 7, 2026
d7b90be
Select layerwise offload by component name
mickqian May 7, 2026
ba9aeff
Handle DTensor weights in layerwise offload
mickqian May 7, 2026
36fe728
Skip module to() for DTensor layerwise offload
mickqian May 7, 2026
71a171e
Disable conflicting offload for layerwise components
mickqian May 7, 2026
e66ed84
Keep layer buffers resident under layerwise offload
mickqian May 7, 2026
fdf0227
Preserve layerwise offload alias CLI parsing
mickqian May 7, 2026
b64a766
upd
mickqian May 8, 2026
4b8050f
upd
mickqian May 8, 2026
3a826d3
Merge remote-tracking branch 'origin/main' into codex/component-resid…
mickqian May 14, 2026
c97a6fc
style: format flux2 modulation unpack
mickqian May 14, 2026
075ec0b
upd
mickqian May 14, 2026
b3d6eb5
Merge remote-tracking branch 'origin/main' into codex/component-resid…
mickqian May 14, 2026
988eb6b
lint
mickqian May 14, 2026
19f4125
upd
mickqian May 14, 2026
adf6a05
Fix diffusion consistency CLIP device in CI
mickqian May 14, 2026
e093024
lint
mickqian May 14, 2026
efd53db
Use tiled RealESRGAN under low GPU memory
mickqian May 14, 2026
15e522f
lint
mickqian May 14, 2026
936efe2
Constrain component layerwise replacement on multi-GPU
mickqian May 14, 2026
b65bb87
Preserve component CPU offload with multi-GPU layerwise DiT
mickqian May 15, 2026
12d6c26
Keep encoder CPU offload when multi-GPU DiT is resident
mickqian May 15, 2026
bd7ce96
Stage encoder loading for layerwise offload
mickqian May 15, 2026
653e441
Format image encoder startup staging
mickqian May 15, 2026
d1a2be0
upd
mickqian May 15, 2026
e463679
Clarify layerwise offload configuration flow
mickqian May 15, 2026
5b3913a
Name lazy component layerwise condition
mickqian May 15, 2026
0300715
Clarify layerwise offload server arg
mickqian May 15, 2026
dc3d505
Keep layerwise offload fields internal
mickqian May 15, 2026
2200582
Use component selection for layerwise offload
mickqian May 15, 2026
edcbc1e
lint
mickqian May 15, 2026
461bc08
Move DiT layerwise selection check to server args
mickqian May 15, 2026
f7698fe
upd
mickqian May 15, 2026
9cf1b88
upd
mickqian May 15, 2026
1d78f8a
Fix layerwise offload unit regressions
mickqian May 15, 2026
9476f04
upd
mickqian May 15, 2026
5afe306
Fix encoder layerwise load kwargs hook
mickqian May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/diffusion/api/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ HTTP server-only arguments are ignored by `sglang generate`.

For diffusers pipelines, Cache-DiT can be enabled with `SGLANG_CACHE_DIT_ENABLED=true` or `--cache-dit-config`. See [Cache-DiT](../performance/cache/cache_dit.md).

### Layerwise Offload

Use layerwise offload when a large component does not fit comfortably in GPU memory. By default, `--dit-layerwise-offload` only applies to legacy DiT components. Use `--layerwise-offload-components` to select pipeline component names explicitly (`--layerwise-offload-modules` is accepted as an alias):

```bash
sglang generate \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--dit-layerwise-offload \
--layerwise-offload-components transformer text_encoder \
--dit-offload-prefetch-size 0 \
--prompt "A quiet city street after rain"
```

The values must match keys in the selected pipeline's `pipeline.modules`, such as `transformer`, `text_encoder`, `image_encoder`, `vae`, `condition_image_encoder`, `spatial_upsampler`, or `vocoder`. Use `all` to select every layerwise-offloadable component. Prefer the smallest component set that solves the memory issue because layerwise offload can increase latency.

## Serve

`sglang serve` starts the HTTP server and keeps the model loaded for repeated requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class PipelineConfig:
# Image encoder configuration
image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)
image_encoder_precision: str = "fp32"
image_encoder_extra_args: dict = field(default_factory=lambda: {})

# Text encoder configuration
DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32",)
Expand All @@ -240,9 +241,6 @@ class PipelineConfig:
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))
text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}])

# image encoding
image_encoder_extra_args: dict = field(default_factory=lambda: {})

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig()

Expand All @@ -266,6 +264,10 @@ def postprocess_image(self, image):
# DMD parameters
dmd_denoising_steps: list[int] | None = field(default=None)

def get_model_deployment_config(self) -> ModelDeploymentConfig:
# return the model-specific config for optimal deployment setting
return ModelDeploymentConfig()

# Wan2.2 TI2V parameters
boundary_ratio: float | None = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ class WanT2V480PConfig(PipelineConfig):
vae_precision: str = "fp32"
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))

# WanConfig-specific added parameters

def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
component_name_to_loader_cls,
get_memory_usage_of_component,
)
from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import (
configure_layerwise_offload_modules,
is_layerwise_offloaded_module,
)
from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload_components import (
LAYERWISE_OFFLOAD_ALL_COMPONENTS,
LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS,
layerwise_component_matches_selection,
)
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
Expand Down Expand Up @@ -96,6 +105,66 @@ def target_device(self, should_offload):
else:
return get_local_torch_device()

def customized_load_kwargs_for_component(
self, _server_args: ServerArgs, _component_name: str
) -> dict[str, Any]:
return {}

@staticmethod
def _is_component_set_as_layerwise_load(
server_args: ServerArgs, component_name: str
) -> bool:
"""if a component should be loaded in a layerwise-fashion"""
selected_component_names = server_args.layerwise_offload_components
if selected_component_names is None:
return False
selected_component_names = set(selected_component_names)
if LAYERWISE_OFFLOAD_ALL_COMPONENTS in selected_component_names:
return True
explicit_component_names = selected_component_names - {
LAYERWISE_OFFLOAD_DEFAULT_COMPONENTS
}
return any(
layerwise_component_matches_selection(component_name, selected_component)
for selected_component in explicit_component_names
)

def _maybe_configure_layerwise_after_startup_cpu_staging(
self,
component: AutoModel,
server_args: ServerArgs,
component_name: str,
load_kwargs: dict[str, Any],
) -> AutoModel:
if not load_kwargs.get("cpu_offload_flag"):
return component
if not isinstance(component, nn.Module):
return component

# try to configure layerwise-offload with the component
configured_components = configure_layerwise_offload_modules(
{component_name: component},
server_args,
component_names=server_args.layerwise_offload_components,
warn_missing=False,
)
if is_layerwise_offloaded_module(component):
logger.info(
"Configured layerwise offload for %s immediately after startup CPU staging",
component_name,
)
return component

logger.warning(
"Layerwise startup CPU staging was requested for %s, but the loaded "
"module did not enable layerwise offload. Moving it to GPU.",
component_name,
)
# ensures the module is on GPU
if component_name in configured_components:
return component
return component.to(get_local_torch_device())

def load(
self,
component_model_path: str,
Expand Down Expand Up @@ -135,8 +204,15 @@ def load(
with component_attn_backend_context_manager(
attn_backend, component_name=component_attn_name
):
load_kwargs = self.customized_load_kwargs_for_component(
server_args, component_name
)
component = self.load_customized(
component_model_path, server_args, component_name
component_model_path, server_args, component_name, **load_kwargs
)
# configure layerwise to make enough VRAM headroom
component = self._maybe_configure_layerwise_after_startup_cpu_staging(
component, server_args, component_name, load_kwargs
)
source = "sgl-diffusion"
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def should_offload(self, server_args, model_config: ModelConfig | None = None):
return use_cpu_offload

def load_customized(
self, component_model_path: str, server_args: ServerArgs, *args
self,
component_model_path: str,
server_args: ServerArgs,
component_name: str = "image_encoder",
cpu_offload_flag: bool | None = None,
):
"""Load the text encoders based on the model path, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
Expand All @@ -53,5 +57,9 @@ def load_customized(
encoder_config,
server_args,
server_args.pipeline_config.image_encoder_precision,
cpu_offload_flag=server_args.image_encoder_cpu_offload,
cpu_offload_flag=(
cpu_offload_flag
if cpu_offload_flag is not None
else server_args.image_encoder_cpu_offload
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def should_offload(self, server_args, model_config: ModelConfig | None = None):
use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0
return use_cpu_offload

def customized_load_kwargs_for_component(
self, server_args: ServerArgs, component_name: str
) -> dict[str, bool]:
if ComponentLoader._is_component_set_as_layerwise_load(
server_args, component_name
):
logger.info(
"Loading %s on CPU first because it is selected for layerwise offload",
component_name,
)
return {"cpu_offload_flag": True}
return {}

def load_native(
self,
component_model_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
from sglang.multimodal_gen.runtime.loader.weight_utils import (
safetensors_weights_iterator,
)
from sglang.multimodal_gen.runtime.managers.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import (
is_layerwise_offloaded_module,
)
from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
Expand Down Expand Up @@ -114,7 +116,7 @@ def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None:
update_cpu_weights(); non-offloaded parameters use in-place copy.
"""
offload_managers: list = []
if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:
if is_layerwise_offloaded_module(module):
offload_managers = [m for m in module.layerwise_offload_managers if m.enabled]

if offload_managers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, server_args: "ServerArgs", gpu_id: int):
self._mode = getattr(server_args, "batching_mode", "dynamic")
self._user_max_batch_size = max(1, int(server_args.batching_max_size))
self._model_path = server_args.model_path
self._offload = bool(server_args.dit_layerwise_offload)
self._offload = bool(server_args.layerwise_offload_components)
self._device_memory_gb = self._get_device_memory_gb(gpu_id)
self._rules = load_batching_config(server_args.batching_config)
self._pipeline_config = server_args.pipeline_config
Expand Down
31 changes: 10 additions & 21 deletions python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
WeightsUpdater,
get_updatable_modules,
)
from sglang.multimodal_gen.runtime.managers.layerwise_offload import (
OffloadableDiTMixin,
from sglang.multimodal_gen.runtime.managers.memory_managers.layerwise_offload import (
configure_layerwise_offload_modules,
iter_materialized_weights,
)
from sglang.multimodal_gen.runtime.pipelines_core import (
Expand Down Expand Up @@ -165,23 +165,12 @@ def init_device_and_model(self) -> None:

# apply layerwise offload after lora is applied while building LoRAPipeline
# otherwise empty offloaded weights could fail lora converting
if self.server_args.dit_layerwise_offload:
# enable layerwise offload if possible
for module_name in [
"transformer",
"transformer_2",
"video_dit",
"video_dit_2",
"audio_dit",
]:
dit = self.pipeline.get_module(module_name)
if dit:
if isinstance(dit, OffloadableDiTMixin):
dit.configure_layerwise_offload(self.server_args)
else:
logger.info(
f"Module {type(dit).__name__} does not support layerwise offload. Skipping."
)
if self.server_args.layerwise_offload_components:
configure_layerwise_offload_modules(
self.pipeline.modules,
self.server_args,
component_names=self.server_args.layerwise_offload_components,
)

logger.info(
f"Worker {self.rank}: Initialized device, model, and distributed environment."
Expand Down Expand Up @@ -234,7 +223,7 @@ def _format_offload_disable_suggestions(self, components: List[str]) -> str:
elif component in ("text_encoder", "text_encoder_2"):
arg = "--text-encoder-cpu-offload"
elif component == "transformer":
if self.server_args.dit_layerwise_offload:
if self.server_args.is_dit_layerwise_offload_selected:
arg = "--dit-layerwise-offload"
elif self.server_args.dit_cpu_offload:
arg = "--dit-cpu-offload"
Expand Down Expand Up @@ -740,7 +729,7 @@ def get_can_stay_resident_components(
# If the flag is True, it is currently offloaded, so it is a candidate to stay resident.
offload_flags = {
"transformer": self.server_args.dit_cpu_offload
or self.server_args.dit_layerwise_offload,
or self.server_args.is_dit_layerwise_offload_selected,
"vae": self.server_args.vae_cpu_offload,
"text_encoder": self.server_args.text_encoder_cpu_offload,
"text_encoder_2": self.server_args.text_encoder_cpu_offload,
Expand Down
Loading
Loading