Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/contributing/model/adding_diffusion_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ omni = Omni(model="your-model", enable_layerwise_offload=True)

```python
class WanTransformer3DModel(nn.Module):
_layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks
_layerwise_offload_blocks_attrs = ["blocks"] # Attribute name containing transformer blocks

def __init__(self):
self.blocks = nn.ModuleList([...]) # Transformer blocks
Expand Down
11 changes: 9 additions & 2 deletions docs/user_guide/diffusion/cpu_offload_diffusion.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,19 @@ Models must define the blocks attribute name for layerwise offloading:

```python
class WanTransformer3DModel(nn.Module):
_layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks
_layerwise_offload_blocks_attrs = ["blocks"] # Attribute names containing transformer blocks

def __init__(self):
self.blocks = nn.ModuleList([...]) # Transformer blocks
```
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds multi-block layerwise offloading but provides no test coverage. Add tests to verify: (1) multi-block offloading works correctly with different block types, (2) memory usage is reduced as expected, (3) output quality is maintained, and (4) edge cases like empty or invalid block attributes are handled.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept these out of e2e, as they're pure logic tests for block parsing (single, multi, empty, invalid, etc.). Just put them in a new file instead. pls take a look.


For models with multiple block types:

```python
class Flux2Transformer2DModel(nn.Module):
_layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
```

### Limitations
- Cold start latency increases because of
1) components are loaded to CPU first at the very first during initialization,
Expand Down Expand Up @@ -140,4 +147,4 @@ Factory function `get_offload_backend()` selects the appropriate backend based o

**Notes:**
- Model-Level Offloading is expected to be supported by all common diffusion models (DiT and encoders) naturally
- Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attr` pointing to transformer blocks
- Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attrs` pointing to transformer blocks
117 changes: 115 additions & 2 deletions tests/diffusion/offloader/test_layerwise_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for LayerwiseOffloadHook."""
"""Unit tests for LayerwiseOffloadHook and LayerWiseOffloadBackend utilities."""

import gc
import os
Expand All @@ -15,7 +15,7 @@
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate

import vllm_omni.diffusion.offloader.layerwise_backend as layerwise_backend_module
from vllm_omni.diffusion.offloader.layerwise_backend import LayerwiseOffloadHook
from vllm_omni.diffusion.offloader.layerwise_backend import LayerWiseOffloadBackend, LayerwiseOffloadHook
from vllm_omni.platforms import current_omni_platform

pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model]
Expand Down Expand Up @@ -127,3 +127,116 @@ def test_dtensor_wrapper_is_preserved_across_prefetch_and_offload(self, dist_gro
assert current_block.weight.to_local().is_meta
assert current_block.weight.to_local().shape == torch.Size([4])
assert not hook.is_materialized


class _DummyBlock(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(10, 10))


class _SingleBlockModel(nn.Module):
_layerwise_offload_blocks_attrs = ["blocks"]

def __init__(self, num_blocks: int = 3):
super().__init__()
self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])


class _MultiBlockModel(nn.Module):
_layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]

def __init__(self, num_transformer: int = 2, num_single: int = 2):
super().__init__()
self.transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_transformer)])
self.single_transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_single)])


class _EmptyBlocksModel(nn.Module):
_layerwise_offload_blocks_attrs = ["blocks"]

def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([])


class _InvalidAttrModel(nn.Module):
_layerwise_offload_blocks_attrs = ["nonexistent_blocks", "blocks"]

def __init__(self, num_blocks: int = 2):
super().__init__()
self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])


class _DeprecatedSingleAttrModel(nn.Module):
_layerwise_offload_blocks_attr = "blocks"

def __init__(self, num_blocks: int = 2):
super().__init__()
self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])


class _NoAttrsModel(nn.Module):
def __init__(self, num_blocks: int = 2):
super().__init__()
self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])


class TestGetBlocksFromDit:
def test_get_blocks_from_dit_single_block_attr(self):
model = _SingleBlockModel(num_blocks=3)
attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
assert attr_names == ["blocks"]
assert len(blocks) == 3
assert all(isinstance(b, _DummyBlock) for b in blocks)

def test_get_blocks_from_dit_multi_block_attrs(self):
model = _MultiBlockModel(num_transformer=2, num_single=3)
attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
assert set(attr_names) == {"transformer_blocks", "single_transformer_blocks"}
assert len(blocks) == 5
assert all(isinstance(b, _DummyBlock) for b in blocks)

def test_get_blocks_from_dit_empty_blocks(self):
model = _EmptyBlocksModel()
attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
assert attr_names == []
assert blocks == []

def test_get_blocks_from_dit_invalid_attr_name(self):
model = _InvalidAttrModel(num_blocks=2)
with pytest.raises(
AttributeError,
match="Attribute 'nonexistent_blocks' declared in _layerwise_offload_blocks_attrs does not exist",
):
LayerWiseOffloadBackend.get_blocks_from_dit(model)

def test_get_blocks_from_dit_no_attrs_defined(self):
model = _NoAttrsModel(num_blocks=3)
attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
assert attr_names == []
assert blocks == []

def test_get_blocks_from_dit_deprecated_single_attr(self):
model = _DeprecatedSingleAttrModel(num_blocks=2)
attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
assert attr_names == ["blocks"]
assert len(blocks) == 2


class TestGetBlocksAttrNames:
def test_get_blocks_attr_names_new_format(self):
model = _MultiBlockModel()
attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model)
assert attrs == ["transformer_blocks", "single_transformer_blocks"]

def test_get_blocks_attr_names_no_attrs(self):
model = _NoAttrsModel()
attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model)
assert attrs == []

def test_set_blocks_attr_names(self):
model = _NoAttrsModel()
LayerWiseOffloadBackend.set_blocks_attr_names(model, ["new_blocks"])
assert hasattr(model.__class__, "_layerwise_offload_blocks_attrs")
assert model.__class__._layerwise_offload_blocks_attrs == ["new_blocks"]
45 changes: 45 additions & 0 deletions tests/e2e/online_serving/test_flux2_expansion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""
Tests for Flux2 Klein; currently Dev is implemented separately,
but ideally these models will fold together in the future.

Coverage:
- FP8 + CacheDiT + Ulysses=2 + TP=2
- Layerwise CPU offload + Ulysses=2 + Ring=2
- Layerwise CPU offload + TP=2
- Layerwise CPU offload + HSDP
"""

import pytest
Expand Down Expand Up @@ -42,6 +48,45 @@ def _get_diffusion_feature_cases(model: str):
),
marks=FOUR_CARD_FEATURE_MARKS,
),
pytest.param(
OmniServerParams(
model=model,
server_args=[
"--enable-layerwise-offload",
"--ulysses-degree",
"2",
"--ring",
"2",
],
),
id="layerwise_ulysses2_ring2",
marks=FOUR_CARD_FEATURE_MARKS,
),
pytest.param(
OmniServerParams(
model=model,
server_args=[
"--enable-layerwise-offload",
"--tensor-parallel-size",
"2",
],
),
id="layerwise_tp2",
marks=FOUR_CARD_FEATURE_MARKS,
),
pytest.param(
OmniServerParams(
model=model,
server_args=[
"--enable-layerwise-offload",
"--use-hsdp",
"--hsdp-shard-size",
"2",
],
),
id="layerwise_hsdp",
marks=FOUR_CARD_FEATURE_MARKS,
),
]


Expand Down
34 changes: 32 additions & 2 deletions tests/e2e/online_serving/test_zimage_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
for Z-Image.

Coverage is intentionally limited to the minimal 4xL4 cases that
exercise Z-Image's supported parallel feature combinations:
exercise Z-Image's supported feature combinations:
- CacheDiT + FP8 + Ring=2 + TP=2
- TeaCache + FP8 + Ulysses=2 + Ring=2
- Layerwise CPU offload + Ulysses=2 + Ring=2
- Layerwise CPU offload + TP=2
- Layerwise CPU offload + HSDP
"""

import pytest
Expand Down Expand Up @@ -64,12 +67,39 @@ def _get_diffusion_feature_cases():
OmniServerParams(
model=MODEL,
server_args=[
"--enable-layerwise-offload",
"--ulysses-degree",
"2",
"--ring",
"2",
],
),
id="layerwise_ulysses2_ring2",
marks=FOUR_CARD_MARKS,
),
pytest.param(
OmniServerParams(
model=MODEL,
server_args=[
"--enable-layerwise-offload",
"--tensor-parallel-size",
"2",
],
),
id="layerwise_tp2",
marks=FOUR_CARD_MARKS,
),
pytest.param(
OmniServerParams(
model=MODEL,
server_args=[
"--enable-layerwise-offload",
"--use-hsdp",
"--hsdp-shard-size",
"2",
],
),
id="parallel_hsdp",
id="layerwise_hsdp",
marks=[*FOUR_CARD_MARKS, pytest.mark.skip(reason="issue #2435")],
),
]
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/models/flux/flux_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ class FluxTransformer2DModel(nn.Module):
# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["FluxTransformerBlock"]
_layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]

@staticmethod
def _is_transformer_block(name: str, module) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ class Flux2Transformer2DModel(nn.Module):
"""

_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
_layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]

@staticmethod
def _is_transformer_block(name: str, module) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/models/helios/helios_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class HeliosTransformer3DModel(nn.Module):
"""

_repeated_blocks = ["HeliosTransformerBlock"]
_layerwise_offload_blocks_attr = "blocks"
_layerwise_offload_blocks_attrs = ["blocks"]
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ class HunyuanVideo15Transformer3DModel(nn.Module):
"""

_repeated_blocks = ["HunyuanVideo15TransformerBlock"]
_layerwise_offload_blocks_attr = "transformer_blocks"
_layerwise_offload_blocks_attrs = ["transformer_blocks"]
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ class QwenImageTransformer2DModel(CachedTransformer):
# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["QwenImageTransformerBlock"]
_layerwise_offload_blocks_attr = "transformer_blocks"
_layerwise_offload_blocks_attrs = ["transformer_blocks"]
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ class WanTransformer3DModel(nn.Module):
"""

_repeated_blocks = ["WanTransformerBlock"]
_layerwise_offload_blocks_attr = "blocks"
_layerwise_offload_blocks_attrs = ["blocks"]
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ class ZImageTransformer2DModel(CachedTransformer):
"""

_repeated_blocks = ["ZImageTransformerBlock"]
_layerwise_offload_blocks_attrs = ["layers"]

@staticmethod
def _is_transformer_block(name: str, module) -> bool:
Expand Down
Loading
Loading