40 changes: 36 additions & 4 deletions .claude/skills/add-diffusion-model/SKILL.md
@@ -282,10 +282,11 @@ For Omni or custom models, create:

Required updates:
1. `docs/user_guide/diffusion/parallelism_acceleration.md` — parallelism support table
-2. `docs/user_guide/diffusion/teacache.md` — if TeaCache supported
-3. `docs/user_guide/diffusion/cache_dit_acceleration.md` — if Cache-DiT supported
-4. `examples/offline_inference/xxx/README.md` — offline example docs
-5. `examples/online_serve/xxx/README.md` — online serve docs
+2. `docs/user_guide/diffusion/cpu_offload_diffusion.md` — if CPU offload supported (add to supported models table)
+3. `docs/user_guide/diffusion/teacache.md` — if TeaCache supported
+4. `docs/user_guide/diffusion/cache_dit_acceleration.md` — if Cache-DiT supported
+5. `examples/offline_inference/xxx/README.md` — offline example docs
+6. `examples/online_serve/xxx/README.md` — online serve docs

### Step 8: Add E2E Tests (Recommended)

@@ -512,6 +513,37 @@ After adding parallelism support, update:
1. `docs/user_guide/diffusion/parallelism_acceleration.md` — add your model to the support table
2. Record which parallelism methods are supported (USP, Ring, CFG, TP, HSDP, VAE-Patch)

### Step 11: Add CPU Offload Support

Implement `SupportsModuleOffload` on your pipeline class to enable
`--enable-cpu-offload` and `--enable-layerwise-offload`. The protocol
declares which submodules the offloader should manage:

```python
from typing import ClassVar

from torch import nn

from vllm_omni.diffusion.models.interface import SupportsModuleOffload


class YourPipeline(nn.Module, SupportsModuleOffload):
    # Each list holds attribute names (dotted paths allowed) on the pipeline.
    _dit_modules: ClassVar[list[str]] = ["transformer"]
    _encoder_modules: ClassVar[list[str]] = ["text_encoder"]
    _vae_modules: ClassVar[list[str]] = ["vae"]
    _resident_modules: ClassVar[list[str]] = []  # optional
```

- `_dit_modules`: denoising submodules (kept on GPU during diffusion loop)
- `_encoder_modules`: encoder/vision submodules (offloaded to CPU during diffusion loop)
- `_vae_modules`: VAE(s) (handled by both sequential and layerwise backends)
- `_resident_modules`: additional modules to pin on GPU during layerwise
offloading (e.g. embedders, connectors). Only used by the layerwise
backend. Optional — defaults to `[]`.

All attribute names support dotted paths for nested submodules
(e.g. `"pipe.transformer"`, `"bagel.time_embedder"`).

Pipelines without `SupportsModuleOffload` fall back to scanning
well-known attribute names (`transformer`, `text_encoder`, `vae`,
etc.), so submodules stored under non-standard names are not discovered.
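
A minimal sketch of how a dotted path such as `"pipe.transformer"` could be resolved, assuming a plain attribute walk. The helper name `resolve_dotted` is hypothetical and is not taken from the vllm-omni code base; stand-in objects replace real `nn.Module` instances so the snippet runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_dotted(root: Any, path: str) -> Optional[Any]:
    """Walk `path` one attribute at a time; return None if any hop is missing."""
    current = root
    for part in path.split("."):
        current = getattr(current, part, None)
        if current is None:
            return None
    return current


# Stand-in objects (assumption): a real pipeline would hold nn.Module instances.
pipeline = SimpleNamespace(
    pipe=SimpleNamespace(transformer="dit-module"),
    vae="vae-module",
)

assert resolve_dotted(pipeline, "pipe.transformer") == "dit-module"
assert resolve_dotted(pipeline, "vae") == "vae-module"
# A missing intermediate attribute yields None instead of raising.
assert resolve_dotted(pipeline, "nonexistent.transformer") is None
```

Returning `None` for a missing hop mirrors the behaviour the unit tests in this PR check for (missing attributes are skipped rather than raising), though the real implementation may differ in detail.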

---

## Iterative Development Tips
55 changes: 51 additions & 4 deletions docs/user_guide/diffusion/cpu_offload_diffusion.md
@@ -36,6 +36,45 @@ m = Omni(model="Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_cpu_offload=True)
```
vllm-omni serve diffusion Wan-AI/Wan2.2-T2V-A14B-Diffusers --enable-cpu-offload
```

### To Support a Model

Implement the `SupportsModuleOffload` protocol to declare which
submodules participate in offloading:

```python
from typing import ClassVar

from torch import nn

from vllm_omni.diffusion.models.interface import SupportsModuleOffload


class MyPipeline(nn.Module, SupportsModuleOffload):
    _dit_modules: ClassVar[list[str]] = ["transformer"]
    _encoder_modules: ClassVar[list[str]] = ["text_encoder", "vision_model"]
    _vae_modules: ClassVar[list[str]] = ["vae"]
    _resident_modules: ClassVar[list[str]] = []  # optional

    def __init__(self):
        super().__init__()
        self.transformer = ...  # DiT: stays on GPU during denoising
        self.text_encoder = ...  # Encoder: offloaded to CPU during denoising
        self.vision_model = ...  # Encoder: offloaded to CPU during denoising
        self.vae = ...  # VAE: always on GPU
```

- `_dit_modules`: attribute names of denoising submodules (kept on GPU
during the diffusion loop).
- `_encoder_modules`: attribute names of encoder/vision submodules
(offloaded to CPU during the diffusion loop).
- `_vae_modules`: attribute names of VAE(s) (always kept on GPU, not
part of the mutual exclusion hooks).
- `_resident_modules`: attribute names of small submodules that must
stay on GPU during layerwise offloading (e.g. embedders, connectors).
Optional — defaults to `[]`.

All attribute names support dotted paths for nested submodules
(e.g. `"pipe.transformer"`, `"bagel.time_embedder"`).

Both DiT and encoder lists are needed because the offload hooks use
mutual exclusion: when one group runs, the other moves to CPU.
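
The mutual-exclusion idea can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyModule` and `MutualExclusionOffloader` are hypothetical names, devices are tracked as plain strings, and the real hooks in vllm-omni are attached to `nn.Module` forward calls rather than invoked by hand.

```python
from dataclasses import dataclass


@dataclass
class ToyModule:
    """Stand-in for an nn.Module; only tracks which device it lives on."""
    name: str
    device: str = "cpu"


@dataclass
class MutualExclusionOffloader:
    """Keeps exactly one of the two groups (DiT or encoders) on the GPU."""
    dits: list[ToyModule]
    encoders: list[ToyModule]

    def _activate(self, active: list[ToyModule], inactive: list[ToyModule]) -> None:
        # Evict the other group first so GPU memory is free before loading.
        for module in inactive:
            module.device = "cpu"
        for module in active:
            module.device = "cuda"

    def before_encoder(self) -> None:
        self._activate(self.encoders, self.dits)

    def before_dit(self) -> None:
        self._activate(self.dits, self.encoders)


offloader = MutualExclusionOffloader(
    dits=[ToyModule("transformer")],
    encoders=[ToyModule("text_encoder"), ToyModule("vision_model")],
)

offloader.before_encoder()  # prompt encoding phase
assert [m.device for m in offloader.encoders] == ["cuda", "cuda"]
assert offloader.dits[0].device == "cpu"

offloader.before_dit()  # denoising phase
assert offloader.dits[0].device == "cuda"
assert [m.device for m in offloader.encoders] == ["cpu", "cpu"]
```

If either list were left undeclared, the offloader would not know which modules to evict when the other group runs, which is why both are needed.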

### Limitations
- Cold start latency increases
- Adds overhead from CPU-GPU transfers between encoder and denoising phases
@@ -116,11 +155,19 @@ class Flux2Transformer2DModel(nn.Module):

**Module Discovery**

-The offloader automatically discovers pipeline components:
+The offloader discovers pipeline components in two ways:

+1. **Protocol-based** (preferred): If the pipeline implements
+   `SupportsModuleOffload`, its `_dit_modules`, `_encoder_modules`,
+   `_vae_modules`, and `_resident_modules` class variables are used
+   directly. All attribute names support dotted paths (e.g.
+   `"pipe.transformer"`, `"bagel.time_embedder"`) for nested submodules.
+
-- **DiT modules**: `transformer`, `transformer_2`, `dit`
-- **Encoders**: `text_encoder`, `text_encoder_2`, `text_encoder_3`, `image_encoder`
-- **VAE**: `vae`
+2. **Fallback attribute scan**: Otherwise, the offloader scans for
+   well-known attribute names:
+   - **DiT modules**: `transformer`, `transformer_2`, `dit`, `sr_dit`, `language_model`, `transformer_blocks`, `model`
+   - **Encoders**: `text_encoder`, `text_encoder_2`, `text_encoder_3`, `image_encoder`
+   - **VAE**: `vae`, `audio_vae`
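
The two discovery paths can be sketched as follows. This is an illustrative approximation only: `discover_dit_names` and `FALLBACK_DIT_NAMES` are hypothetical, the presence of a `_dit_modules` class attribute stands in for the real `SupportsModuleOffload` check, only a subset of the fallback names is listed, and dotted paths are not handled.

```python
from typing import Any

# Subset of the well-known DiT names listed above (assumption: order matters).
FALLBACK_DIT_NAMES = ["transformer", "transformer_2", "dit"]


def discover_dit_names(pipeline: Any) -> list[str]:
    """Prefer declared names; otherwise scan well-known attribute names."""
    declared = getattr(type(pipeline), "_dit_modules", None)
    candidates = declared if declared is not None else FALLBACK_DIT_NAMES
    # Keep only names that actually exist on this pipeline instance.
    return [name for name in candidates if getattr(pipeline, name, None) is not None]


class DeclaringPipeline:
    """Declares its DiT explicitly, using a non-standard attribute name."""
    _dit_modules = ["gen_transformer"]

    def __init__(self) -> None:
        self.gen_transformer = object()
        self.transformer = object()  # present but undeclared, so ignored


class PlainPipeline:
    """No declaration: only well-known names are found."""

    def __init__(self) -> None:
        self.transformer = object()
        self.gen_transformer = object()  # non-standard name, missed by the scan


assert discover_dit_names(DeclaringPipeline()) == ["gen_transformer"]
assert discover_dit_names(PlainPipeline()) == ["transformer"]
```

The second assertion shows why the protocol is preferred: the scan silently misses `gen_transformer`.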

**Hook System**

240 changes: 240 additions & 0 deletions tests/diffusion/offloader/test_module_collector.py
@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for ModuleDiscovery and SupportsModuleOffload."""

from typing import ClassVar

import pytest
from torch import nn

from vllm_omni.diffusion.models.interface import SupportsModuleOffload
from vllm_omni.diffusion.offloader.module_collector import ModuleDiscovery

pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model]

# NOTE: tests for skipped/warned attributes verify the *behavioral*
# outcome (attribute excluded from results) but do not assert on log
# output. vllm's logger sets propagate=False, preventing caplog from
# capturing records. See https://github.com/pytest-dev/pytest/issues/3697


# ---------------------------------------------------------------------------
# Test pipelines
# ---------------------------------------------------------------------------


class FallbackPipeline(nn.Module):
    """Pipeline with standard attribute names (no protocol)."""

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(10, 10)
        self.text_encoder = nn.Linear(10, 10)
        self.text_encoder_2 = nn.Linear(10, 10)
        self.vae = nn.Linear(10, 10)


class NonModuleAttrPipeline(nn.Module):
    """Pipeline where an attribute is not an nn.Module (fallback path)."""

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(10, 10)
        self.text_encoder = "not_a_module"
        self.vae = nn.Linear(10, 10)


class DuplicateAttrPipeline(nn.Module):
    """Pipeline where two encoder attrs point to the same module."""

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(10, 10)
        encoder = nn.Linear(10, 10)
        self.text_encoder = encoder
        self.text_encoder_2 = encoder
        self.vae = nn.Linear(10, 10)


class ProtocolPipeline(nn.Module, SupportsModuleOffload):
    """Pipeline with non-standard names, using the protocol."""

    _dit_modules: ClassVar[list[str]] = ["gen_transformer"]
    _encoder_modules: ClassVar[list[str]] = ["mllm", "vision_model"]
    _vae_modules: ClassVar[list[str]] = ["gen_vae"]

    def __init__(self):
        super().__init__()
        self.gen_transformer = nn.Linear(10, 10)
        self.mllm = nn.Linear(10, 10)
        self.vision_model = nn.Linear(10, 10)
        self.gen_vae = nn.Linear(10, 10)
        # Standard name present but NOT declared — should be ignored
        self.transformer = nn.Linear(10, 10)


class MissingAttrPipeline(nn.Module, SupportsModuleOffload):
    """Pipeline that declares a non-existent attribute."""

    _dit_modules: ClassVar[list[str]] = ["transformer"]
    _encoder_modules: ClassVar[list[str]] = ["nonexistent_encoder"]
    _vae_modules: ClassVar[list[str]] = ["vae"]

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(10, 10)
        self.vae = nn.Linear(10, 10)


class MissingIntermediatePipeline(nn.Module, SupportsModuleOffload):
    """Pipeline with dotted path referencing non-existent intermediate."""

    _dit_modules: ClassVar[list[str]] = ["nonexistent.transformer"]
    _encoder_modules: ClassVar[list[str]] = []
    _vae_modules: ClassVar[list[str]] = []

    def __init__(self):
        super().__init__()


class NestedPipeline(nn.Module, SupportsModuleOffload):
    """Pipeline with nested modules accessed via dotted paths."""

    _dit_modules: ClassVar[list[str]] = ["pipe.transformer"]
    _encoder_modules: ClassVar[list[str]] = ["pipe.text_encoder"]
    _vae_modules: ClassVar[list[str]] = ["vae"]

    def __init__(self):
        super().__init__()
        self.pipe = nn.Module()
        self.pipe.transformer = nn.Linear(10, 10)
        self.pipe.text_encoder = nn.Linear(10, 10)
        self.vae = nn.Linear(10, 10)


class ResidentPipeline(nn.Module, SupportsModuleOffload):
    """Pipeline with resident modules that must stay on GPU."""

    _dit_modules: ClassVar[list[str]] = ["language_model.model"]
    _encoder_modules: ClassVar[list[str]] = []
    _vae_modules: ClassVar[list[str]] = ["vae"]
    _resident_modules: ClassVar[list[str]] = [
        "bagel.time_embedder",
        "bagel.vae2llm",
    ]

    def __init__(self):
        super().__init__()
        self.language_model = nn.Module()
        self.language_model.model = nn.Linear(10, 10)
        self.bagel = nn.Module()
        self.bagel.time_embedder = nn.Linear(10, 10)
        self.bagel.vae2llm = nn.Linear(10, 10)
        self.vae = nn.Linear(10, 10)


class MultiVaePipeline(nn.Module, SupportsModuleOffload):
    """Pipeline with multiple VAEs."""

    _dit_modules: ClassVar[list[str]] = ["transformer"]
    _encoder_modules: ClassVar[list[str]] = ["text_encoder"]
    _vae_modules: ClassVar[list[str]] = ["vae", "audio_vae"]

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(10, 10)
        self.text_encoder = nn.Linear(10, 10)
        self.vae = nn.Linear(10, 10)
        self.audio_vae = nn.Linear(10, 10)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestFallbackDiscovery:
    """Test the fallback attribute scan (no SupportsModuleOffload)."""

    def test_discovers_standard_attrs(self):
        pipeline = FallbackPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert not isinstance(pipeline, SupportsModuleOffload)
        assert result.dit_names == ["transformer"]
        assert result.dits[0] is pipeline.transformer
        assert result.encoder_names == ["text_encoder", "text_encoder_2"]
        assert result.vaes[0] is pipeline.vae
        assert result.resident_modules == []

    def test_deduplicates_encoders(self):
        pipeline = DuplicateAttrPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert len(result.encoders) == 1
        assert result.encoder_names == ["text_encoder"]

    def test_skips_non_module_attr(self):
        pipeline = NonModuleAttrPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert len(result.encoders) == 0


class TestProtocolDiscovery:
    """Test discovery via SupportsModuleOffload protocol."""

    def test_discovers_declared_attrs_and_ignores_undeclared(self):
        pipeline = ProtocolPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert isinstance(pipeline, SupportsModuleOffload)
        assert result.dit_names == ["gen_transformer"]
        assert result.encoder_names == ["mllm", "vision_model"]
        assert len(result.vaes) == 1
        # self.transformer exists but is NOT in _dit_modules
        assert "transformer" not in result.dit_names
        # No _resident_modules declared — defaults to empty
        assert result.resident_modules == []

    def test_skips_missing_attr(self):
        pipeline = MissingAttrPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert len(result.encoders) == 0

    def test_skips_missing_intermediate(self):
        result = ModuleDiscovery.discover(MissingIntermediatePipeline())

        assert len(result.dits) == 0

    def test_dotted_path_resolves_nested_modules(self):
        pipeline = NestedPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert result.dit_names == ["pipe.transformer"]
        assert result.dits[0] is pipeline.pipe.transformer
        assert result.encoder_names == ["pipe.text_encoder"]
        assert result.encoders[0] is pipeline.pipe.text_encoder
        assert result.vaes[0] is pipeline.vae

    def test_resident_modules(self):
        pipeline = ResidentPipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert result.resident_names == [
            "bagel.time_embedder",
            "bagel.vae2llm",
        ]
        assert result.resident_modules[0] is pipeline.bagel.time_embedder
        assert result.resident_modules[1] is pipeline.bagel.vae2llm
        assert result.dits[0] is pipeline.language_model.model

    def test_multiple_vaes(self):
        pipeline = MultiVaePipeline()
        result = ModuleDiscovery.discover(pipeline)

        assert len(result.vaes) == 2
        assert result.vaes[0] is pipeline.vae
        assert result.vaes[1] is pipeline.audio_vae
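
The note at the top of the test module explains that log output is not asserted because the logger does not propagate. One possible workaround, sketched here with the standard library only, is to attach a handler directly to the non-propagating logger. The logger name `demo.offloader` and the warning text are made up for illustration and do not come from vllm or vllm-omni.

```python
import logging


class ListHandler(logging.Handler):
    """Collects emitted records in a list so a test can inspect them."""

    def __init__(self) -> None:
        super().__init__()
        self.records: list[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.records.append(record)


# Hypothetical logger configured like the one described in the note.
logger = logging.getLogger("demo.offloader")
logger.propagate = False  # records never reach the root logger
logger.setLevel(logging.WARNING)

handler = ListHandler()
logger.addHandler(handler)  # attach to the logger itself, not to root
try:
    logger.warning("skipping %s: not an nn.Module", "text_encoder")
finally:
    logger.removeHandler(handler)  # always restore the logger's state

messages = [record.getMessage() for record in handler.records]
assert messages == ["skipping text_encoder: not an nn.Module"]
```

In a pytest test the same approach should work with pytest's own capture handler, by adding `caplog.handler` to the logger for the duration of the test; whether that is worth the extra setup here is a judgment call for the test authors.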