From 161d94870369c5089dbb52fb7ad50e85e3718875 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 26 Feb 2026 01:42:13 +0800 Subject: [PATCH 01/10] [Feat] support for multi-block layerwise offloading Signed-off-by: Lancer --- .../diffusion/models/flux/flux_transformer.py | 1 + .../flux2_klein/flux2_klein_transformer.py | 1 + .../diffusion/offloader/layerwise_backend.py | 31 +++++++++++++++---- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py index faf6d08d3a..c6820f6388 100644 --- a/vllm_omni/diffusion/models/flux/flux_transformer.py +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -432,6 +432,7 @@ class FluxTransformer2DModel(nn.Module): # -- typically a transformer layer # used for torch compile optimizations _repeated_blocks = ["FluxTransformerBlock"] + _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"] def __init__( self, diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ee10d2e0e4..e7f9455aab 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -559,6 +559,7 @@ class Flux2Transformer2DModel(nn.Module): """ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index a69b13e64d..d1b41b8397 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -261,9 +261,10 @@ def enable(self, pipeline: nn.Module) -> None: logger.info(f"Applying hooks on {dit_name} ({dit_module.__class__.__name__})") blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(dit_module) + blocks_attrs_names = getattr(dit_module.__class__, "_layerwise_offload_blocks_attrs", None) blocks = LayerWiseOffloadBackend.get_blocks_from_dit(dit_module) - if not blocks_attr_name or not blocks: + if not blocks: logger.warning( "Target layers (blocks) not found. Skipping offloading on %s (%s)", dit_name, @@ -284,11 +285,17 @@ def enable(self, pipeline: nn.Module) -> None: # Move non-block modules to GPU (they stay resident) for name, m in dit_module.named_children(): - if name == blocks_attr_name: - logger.debug(f"Skipped blocks module {name}") - continue - m.to(self.device) - logger.debug(f"Moved {name} to device {self.device}") + if name != blocks_attr_name and (not blocks_attrs_names or name not in blocks_attrs_names): + m.to(self.device) + + # Move top-level params/buffers to GPU (dit_module's own, not sub-modules) + for param in dit_module._parameters.values(): + if param is not None: + param.data = param.data.to(self.device, non_blocking=True) + + for buffer in dit_module._buffers.values(): + if buffer is not None: + buffer.data = buffer.data.to(self.device, non_blocking=True) # Pre-fetch the first layer by manually calling the hook function on the last layer; # For subsequent requests, the first layer/block will be pre-fetched @@ -344,13 +351,25 @@ class WanTransformer3DModel(nn.Module): ``` """ blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(model) + + # Handle multiple block types (_layerwise_offload_blocks_attrs) if blocks_attr_name is None: + blocks_attrs_names = getattr(model.__class__, "_layerwise_offload_blocks_attrs", None) + if blocks_attrs_names: + all_blocks = [block for name in blocks_attrs_names for block in getattr(model, name, [])] + if all_blocks: + logger.info(f"{len(all_blocks)} blocks from {blocks_attrs_names}") + return all_blocks + logger.warning(f"No blocks found in {blocks_attrs_names}") + return [] + logger.warning( f"No _layerwise_offload_blocks_attr defined for {model.__class__.__name__}, " "skipping layerwise offloading" ) return [] + # Standard single attribute handling _blocks = getattr(model, blocks_attr_name, None) if _blocks is None: logger.warning( From d6324f5dd599468422498b6e24e65a3c361503e8 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 26 Feb 2026 13:49:03 +0800 Subject: [PATCH 02/10] upd Signed-off-by: Lancer --- .../diffusion/cpu_offload_diffusion.md | 11 +++- .../qwen_image/qwen_image_transformer.py | 2 +- .../models/wan2_2/wan2_2_transformer.py | 2 +- .../diffusion/offloader/layerwise_backend.py | 60 ++++++++----------- 4 files changed, 35 insertions(+), 40 deletions(-) diff --git a/docs/user_guide/diffusion/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md index 8786ae9649..be72efffa5 100644 --- a/docs/user_guide/diffusion/cpu_offload_diffusion.md +++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md @@ -91,12 +91,19 @@ Models must define the blocks attribute name for layerwise offloading: ```python class WanTransformer3DModel(nn.Module): - _layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks + _layerwise_offload_blocks_attrs = ["blocks"] # Attribute names containing transformer blocks def __init__(self): self.blocks = nn.ModuleList([...]) # Transformer blocks ``` +For models with multiple block types: + +```python +class Flux2Transformer2DModel(nn.Module): + _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"] +``` + ### Limitations - Cold start latency increases because of 1) components are loaded to CPU first at the very first during initialization, @@ -140,4 +147,4 @@ Factory function `get_offload_backend()` selects the appropriate backend based o **Notes:** - Model-Level Offloading is expected to be supported by all common diffusion models (DiT and encoders) naturally -- Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attr` pointing to transformer blocks +- Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attrs` pointing to transformer blocks diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 2d8d49eee9..2da51a902b 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -845,7 +845,7 @@ class QwenImageTransformer2DModel(CachedTransformer): # -- typically a transformer layer # used for torch compile optimizations _repeated_blocks = ["QwenImageTransformerBlock"] - _layerwise_offload_blocks_attr = "transformer_blocks" + _layerwise_offload_blocks_attrs = ["transformer_blocks"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index 4de44119f8..d5facd852c 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -711,7 +711,7 @@ class WanTransformer3DModel(nn.Module): """ _repeated_blocks = ["WanTransformerBlock"] - _layerwise_offload_blocks_attr = "blocks" + _layerwise_offload_blocks_attrs = ["blocks"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], } diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index d1b41b8397..88d025878d 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -260,9 +260,7 @@ def enable(self, pipeline: nn.Module) -> None: dit_name = modules.dit_names[i] logger.info(f"Applying hooks on {dit_name} ({dit_module.__class__.__name__})") - blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(dit_module) - blocks_attrs_names = getattr(dit_module.__class__, "_layerwise_offload_blocks_attrs", None) - blocks = LayerWiseOffloadBackend.get_blocks_from_dit(dit_module) + blocks_attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(dit_module) if not blocks: logger.warning( @@ -285,7 +283,7 @@ def enable(self, pipeline: nn.Module) -> None: # Move non-block modules to GPU (they stay resident) for name, m in dit_module.named_children(): - if name != blocks_attr_name and (not blocks_attrs_names or name not in blocks_attrs_names): + if blocks_attr_names and name not in blocks_attr_names: m.to(self.device) # Move top-level params/buffers to GPU (dit_module's own, not sub-modules) @@ -330,52 +328,42 @@ def disable(self) -> None: logger.info("Layer-wise offloading disabled") @staticmethod - def get_blocks_attr_name(model: nn.Module) -> str | None: - """Retrieve blocks attribute name from provided DiT model""" - return getattr(model.__class__, "_layerwise_offload_blocks_attr", None) + def get_blocks_attr_names(model: nn.Module) -> list[str]: + """Get block attribute names from model class.""" + return getattr(model.__class__, "_layerwise_offload_blocks_attrs", []) @staticmethod - def set_blocks_attr_name(model: nn.Module, name: str) -> None: - if not hasattr(model.__class__, "_layerwise_offload_blocks_attr"): - setattr(model.__class__, "_layerwise_offload_blocks_attr", name) + def set_blocks_attr_names(model: nn.Module, names: list[str]) -> None: + if not hasattr(model.__class__, "_layerwise_offload_blocks_attrs"): + setattr(model.__class__, "_layerwise_offload_blocks_attrs", names) @staticmethod - def get_blocks_from_dit(model: nn.Module) -> list[nn.Module]: + def get_blocks_from_dit(model: nn.Module) -> tuple[list[str], list[nn.Module]]: """ - Retrieve a list of blocks from provided DiT model. Blocks attribute name - are found by `_layerwise_offload_blocks_attr` set to DiT models. For example, + Retrieve blocks and attribute names from provided DiT model. Blocks attribute names + are found by `_layerwise_offload_blocks_attrs` set to DiT models. For example, ``` class WanTransformer3DModel(nn.Module): - _layerwise_offload_blocks_attr = "blocks" + _layerwise_offload_blocks_attrs = ["blocks"] ``` - """ - blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(model) - - # Handle multiple block types (_layerwise_offload_blocks_attrs) - if blocks_attr_name is None: - blocks_attrs_names = getattr(model.__class__, "_layerwise_offload_blocks_attrs", None) - if blocks_attrs_names: - all_blocks = [block for name in blocks_attrs_names for block in getattr(model, name, [])] - if all_blocks: - logger.info(f"{len(all_blocks)} blocks from {blocks_attrs_names}") - return all_blocks - logger.warning(f"No blocks found in {blocks_attrs_names}") - return [] + Returns: + Tuple of (blocks_attr_names, blocks) + """ + blocks_attr_names = LayerWiseOffloadBackend.get_blocks_attr_names(model) + if not blocks_attr_names: logger.warning( - f"No _layerwise_offload_blocks_attr defined for {model.__class__.__name__}, " + f"No _layerwise_offload_blocks_attrs defined for {model.__class__.__name__}, " "skipping layerwise offloading" ) - return [] + return [], [] - # Standard single attribute handling - _blocks = getattr(model, blocks_attr_name, None) - if _blocks is None: + blocks = [block for name in blocks_attr_names for block in getattr(model, name, [])] + if not blocks: logger.warning( - f"Blocks (layers) '{blocks_attr_name}' not found on {model.__class__.__name__}, " - "skipping layerwise offloading" + f"No blocks found in {blocks_attr_names} for {model.__class__.__name__}, skipping layerwise offloading" ) - return [] + return [], [] - return list(_blocks) + return blocks_attr_names, blocks From f1da0c7e31550c4f0459033582e58454070601a3 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sat, 28 Feb 2026 15:56:40 +0800 Subject: [PATCH 03/10] upd Signed-off-by: Lancer --- .../diffusion/offloader/layerwise_backend.py | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 88d025878d..07502afb07 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -283,7 +283,7 @@ def enable(self, pipeline: nn.Module) -> None: # Move non-block modules to GPU (they stay resident) for name, m in dit_module.named_children(): - if blocks_attr_names and name not in blocks_attr_names: + if name not in blocks_attr_names: m.to(self.device) # Move top-level params/buffers to GPU (dit_module's own, not sub-modules) @@ -330,7 +330,19 @@ def disable(self) -> None: @staticmethod def get_blocks_attr_names(model: nn.Module) -> list[str]: """Get block attribute names from model class.""" - return getattr(model.__class__, "_layerwise_offload_blocks_attrs", []) + attrs: list[str] = getattr(model.__class__, "_layerwise_offload_blocks_attrs", []) + + if not attrs: + old_attr = getattr(model.__class__, "_layerwise_offload_blocks_attr", None) + if old_attr is not None: + logger.warning( + "'_layerwise_offload_blocks_attr' is deprecated, " + "please use '_layerwise_offload_blocks_attrs' instead. " + "Example: _layerwise_offload_blocks_attrs = ['blocks']" + ) + attrs = [old_attr] if isinstance(old_attr, str) else list(old_attr) + + return attrs @staticmethod def set_blocks_attr_names(model: nn.Module, names: list[str]) -> None: @@ -359,10 +371,23 @@ class WanTransformer3DModel(nn.Module): ) return [], [] - blocks = [block for name in blocks_attr_names for block in getattr(model, name, [])] + blocks = [] + for name in blocks_attr_names: + attr = getattr(model, name, None) + if attr is None: + logger.error( + "Attribute '%s' in _layerwise_offload_blocks_attrs does not exist on model %s", + name, + model.__class__.__name__, + ) + continue + blocks.extend(attr) + if not blocks: logger.warning( - f"No blocks found in {blocks_attr_names} for {model.__class__.__name__}, skipping layerwise offloading" + "No blocks found in %s for %s, skipping layerwise offloading", + blocks_attr_names, + model.__class__.__name__, ) return [], [] From faa68392adb23664ca30dc5cba18a884256fab09 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sat, 28 Feb 2026 16:13:27 +0800 Subject: [PATCH 04/10] upd Signed-off-by: Lancer --- vllm_omni/diffusion/models/z_image/z_image_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index a4f073faa5..1203458231 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -566,6 +566,7 @@ class ZImageTransformer2DModel(CachedTransformer): """ _repeated_blocks = ["ZImageTransformerBlock"] + _layerwise_offload_blocks_attrs = ["layers"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], "w13": ["w1", "w3"], From 9429c92d8555329c8521824324adae11cd34937f Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 5 Mar 2026 19:54:23 +0800 Subject: [PATCH 05/10] upd Signed-off-by: Lancer --- vllm_omni/diffusion/offloader/layerwise_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 9795b857be..a3c62f51d8 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -284,6 +284,9 @@ def enable(self, pipeline: nn.Module) -> None: for name, m in dit_module.named_children(): if name not in blocks_attr_names: m.to(self.device) + logger.debug(f"Moved {name} to device {self.device}") + else: + logger.debug(f"Skipped blocks module {name}") # Move top-level params/buffers to GPU (dit_module's own, not sub-modules) for param in dit_module._parameters.values(): From cfacd00ab5e9d3b8575285c2cd1226a53b92ced2 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 15 Mar 2026 10:10:02 +0800 Subject: [PATCH 06/10] upd Signed-off-by: Lancer --- .../offloader/test_layerwise_backend.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/diffusion/offloader/test_layerwise_backend.py diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py new file mode 100644 index 0000000000..60ab74f19c --- /dev/null +++ b/tests/diffusion/offloader/test_layerwise_backend.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for LayerWiseOffloadBackend block discovery utilities.""" + +import pytest +import torch +from torch import nn + +from vllm_omni.diffusion.offloader.layerwise_backend import LayerWiseOffloadBackend + +pytestmark = [pytest.mark.diffusion, pytest.mark.cpu] + + +class _DummyBlock(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(10, 10)) + + +class _SingleBlockModel(nn.Module): + _layerwise_offload_blocks_attrs = ["blocks"] + + def __init__(self, num_blocks: int = 3): + super().__init__() + self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)]) + + +class _MultiBlockModel(nn.Module): + _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"] + + def __init__(self, num_transformer: int = 2, num_single: int = 2): + super().__init__() + self.transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_transformer)]) + self.single_transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_single)]) + + +class _EmptyBlocksModel(nn.Module): + _layerwise_offload_blocks_attrs = ["blocks"] + + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([]) + + +class _InvalidAttrModel(nn.Module): + _layerwise_offload_blocks_attrs = ["nonexistent_blocks", "blocks"] + + def __init__(self, num_blocks: int = 2): + super().__init__() + self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)]) + + +class _DeprecatedSingleAttrModel(nn.Module): + _layerwise_offload_blocks_attr = "blocks" + + def __init__(self, num_blocks: int = 2): + super().__init__() + self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)]) + + +class _NoAttrsModel(nn.Module): + def __init__(self, num_blocks: int = 2): + super().__init__() + self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)]) + + +class TestGetBlocksFromDit: + def test_get_blocks_from_dit_single_block_attr(self): + model = _SingleBlockModel(num_blocks=3) + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert attr_names == ["blocks"] + assert len(blocks) == 3 + assert all(isinstance(b, _DummyBlock) for b in blocks) + + def test_get_blocks_from_dit_multi_block_attrs(self): + model = _MultiBlockModel(num_transformer=2, num_single=3) + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert set(attr_names) == {"transformer_blocks", "single_transformer_blocks"} + assert len(blocks) == 5 + assert all(isinstance(b, _DummyBlock) for b in blocks) + + def test_get_blocks_from_dit_empty_blocks(self): + model = _EmptyBlocksModel() + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert attr_names == [] + assert blocks == [] + + def test_get_blocks_from_dit_invalid_attr_name(self): + model = _InvalidAttrModel(num_blocks=2) + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert set(attr_names) == {"nonexistent_blocks", "blocks"} + assert len(blocks) == 2 + + def test_get_blocks_from_dit_no_attrs_defined(self): + model = _NoAttrsModel(num_blocks=3) + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert attr_names == [] + assert blocks == [] + + def test_get_blocks_from_dit_deprecated_single_attr(self): + model = _DeprecatedSingleAttrModel(num_blocks=2) + attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) + assert attr_names == ["blocks"] + assert len(blocks) == 2 + + +class TestGetBlocksAttrNames: + def test_get_blocks_attr_names_new_format(self): + model = _MultiBlockModel() + attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model) + assert attrs == ["transformer_blocks", "single_transformer_blocks"] + + def test_get_blocks_attr_names_no_attrs(self): + model = _NoAttrsModel() + attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model) + assert attrs == [] + + def test_set_blocks_attr_names(self): + model = _NoAttrsModel() + LayerWiseOffloadBackend.set_blocks_attr_names(model, ["new_blocks"]) + assert hasattr(model.__class__, "_layerwise_offload_blocks_attrs") + assert model.__class__._layerwise_offload_blocks_attrs == ["new_blocks"] From f8ad03287f467c28e047e0f80190f3e54cbf2a9b Mon Sep 17 00:00:00 2001 From: Lancer Date: Tue, 31 Mar 2026 12:02:29 +0800 Subject: [PATCH 07/10] upd Signed-off-by: Lancer --- .../model/adding_diffusion_model.md | 2 +- .../models/helios/helios_transformer.py | 2 +- .../hunyuan_video_15_transformer.py | 2 +- .../diffusion/offloader/layerwise_backend.py | 21 ++++++++++++++++++- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 366903433e..6c81e8b307 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -820,7 +820,7 @@ omni = Omni(model="your-model", enable_layerwise_offload=True) ```python class WanTransformer3DModel(nn.Module): - _layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks + _layerwise_offload_blocks_attrs = ["blocks"] # Attribute name containing transformer blocks def __init__(self): self.blocks = nn.ModuleList([...]) # Transformer blocks diff --git a/vllm_omni/diffusion/models/helios/helios_transformer.py b/vllm_omni/diffusion/models/helios/helios_transformer.py index 17b9ec37e2..dd0fae8aa9 100644 --- a/vllm_omni/diffusion/models/helios/helios_transformer.py +++ b/vllm_omni/diffusion/models/helios/helios_transformer.py @@ -558,7 +558,7 @@ class HeliosTransformer3DModel(nn.Module): """ _repeated_blocks = ["HeliosTransformerBlock"] - _layerwise_offload_blocks_attr = "blocks" + _layerwise_offload_blocks_attrs = ["blocks"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], } diff --git a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py index 2f7318cefc..3884d7ab73 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py @@ -538,7 +538,7 @@ class HunyuanVideo15Transformer3DModel(nn.Module): """ _repeated_blocks = ["HunyuanVideo15TransformerBlock"] - _layerwise_offload_blocks_attr = "transformer_blocks" + _layerwise_offload_blocks_attrs = ["transformer_blocks"] packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 452b5f4ab8..11e9512fa3 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -410,7 +410,26 @@ class WanTransformer3DModel(nn.Module): model.__class__.__name__, ) continue - blocks.extend(attr) + try: + attr_iter = iter(attr) + except TypeError: + if isinstance(attr, nn.Module): + logger.warning( + "Attribute '%s' on %s is not iterable; treating it as one block.", + name, + model.__class__.__name__, + ) + blocks.append(attr) + continue + + logger.warning( + "Attribute '%s' on %s is not iterable (got %s); skipping it.", + name, + model.__class__.__name__, + type(attr).__name__, + ) + else: + blocks.extend(attr_iter) if not blocks: logger.warning( From 0fd342d23469171e31650f976866c2cea724e10b Mon Sep 17 00:00:00 2001 From: Lancer <402430575@qq.com> Date: Wed, 1 Apr 2026 19:28:13 +0800 Subject: [PATCH 08/10] Update tests/diffusion/offloader/test_layerwise_backend.py Co-authored-by: Didan Deng <33117903+wtomin@users.noreply.github.com> Signed-off-by: Lancer <402430575@qq.com> --- tests/diffusion/offloader/test_layerwise_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py index 60ab74f19c..335c555997 100644 --- a/tests/diffusion/offloader/test_layerwise_backend.py +++ b/tests/diffusion/offloader/test_layerwise_backend.py @@ -8,7 +8,7 @@ from vllm_omni.diffusion.offloader.layerwise_backend import LayerWiseOffloadBackend -pytestmark = [pytest.mark.diffusion, pytest.mark.cpu] +pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model] class _DummyBlock(nn.Module): From caa05f968a28bc2ff541f9c4cf897ac15b918328 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 1 Apr 2026 20:51:21 +0800 Subject: [PATCH 09/10] upd Signed-off-by: Lancer --- .../offloader/test_layerwise_backend.py | 8 +++-- .../online_serving/test_flux2_expansion.py | 31 +++++++++++++++++++ .../online_serving/test_zimage_expansion.py | 30 +++++++++++++++++- .../diffusion/offloader/layerwise_backend.py | 8 ++--- 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py index 335c555997..c13b594299 100644 --- a/tests/diffusion/offloader/test_layerwise_backend.py +++ b/tests/diffusion/offloader/test_layerwise_backend.py @@ -87,9 +87,11 @@ def test_get_blocks_from_dit_empty_blocks(self): def test_get_blocks_from_dit_invalid_attr_name(self): model = _InvalidAttrModel(num_blocks=2) - attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model) - assert set(attr_names) == {"nonexistent_blocks", "blocks"} - assert len(blocks) == 2 + with pytest.raises( + AttributeError, + match="Attribute 'nonexistent_blocks' declared in _layerwise_offload_blocks_attrs does not exist", + ): + LayerWiseOffloadBackend.get_blocks_from_dit(model) def test_get_blocks_from_dit_no_attrs_defined(self): model = _NoAttrsModel(num_blocks=3) diff --git a/tests/e2e/online_serving/test_flux2_expansion.py b/tests/e2e/online_serving/test_flux2_expansion.py index 0e9e8c89a6..8afa5f6c86 100644 --- a/tests/e2e/online_serving/test_flux2_expansion.py +++ b/tests/e2e/online_serving/test_flux2_expansion.py @@ -1,6 +1,11 @@ """ Tests for Flux2 Klein; currently Dev is implemented separately, but ideally these models will fold together in the future. + +Coverage: +- FP8 + CacheDiT + Ulysses=2 + TP=2 +- Layerwise CPU offload + Ulysses=2 + Ring=2 +- Layerwise CPU offload + TP=2 """ import pytest @@ -42,6 +47,32 @@ def _get_diffusion_feature_cases(model: str): ), marks=FOUR_CARD_FEATURE_MARKS, ), + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--enable-layerwise-offload", + "--ulysses-degree", + "2", + "--ring", + "2", + ], + ), + id="layerwise_ulysses2_ring2", + marks=FOUR_CARD_FEATURE_MARKS, + ), + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--enable-layerwise-offload", + "--tensor-parallel-size", + "2", + ], + ), + id="layerwise_tp2", + marks=FOUR_CARD_FEATURE_MARKS, + ), ] diff --git a/tests/e2e/online_serving/test_zimage_expansion.py b/tests/e2e/online_serving/test_zimage_expansion.py index dfca76ca25..15cce4f1f8 100644 --- a/tests/e2e/online_serving/test_zimage_expansion.py +++ b/tests/e2e/online_serving/test_zimage_expansion.py @@ -3,9 +3,11 @@ for Z-Image. Coverage is intentionally limited to the minimal 4xL4 cases that -exercise Z-Image's supported parallel feature combinations: +exercise Z-Image's supported feature combinations: - CacheDiT + FP8 + Ring=2 + TP=2 - TeaCache + FP8 + Ulysses=2 + Ring=2 +- Layerwise CPU offload + Ulysses=2 + Ring=2 +- Layerwise CPU offload + TP=2 """ import pytest @@ -60,6 +62,32 @@ def _get_diffusion_feature_cases(): id="parallel_teacache_fp8_ulysses2_ring2", marks=FOUR_CARD_MARKS, ), + pytest.param( + OmniServerParams( + model=MODEL, + server_args=[ + "--enable-layerwise-offload", + "--ulysses-degree", + "2", + "--ring", + "2", + ], + ), + id="layerwise_ulysses2_ring2", + marks=FOUR_CARD_MARKS, + ), + pytest.param( + OmniServerParams( + model=MODEL, + server_args=[ + "--enable-layerwise-offload", + "--tensor-parallel-size", + "2", + ], + ), + id="layerwise_tp2", + marks=FOUR_CARD_MARKS, + ), ] diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 11e9512fa3..6b1dbfc00c 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -404,12 +404,10 @@ class WanTransformer3DModel(nn.Module): for name in blocks_attr_names: attr = getattr(model, name, None) if attr is None: - logger.error( - "Attribute '%s' in _layerwise_offload_blocks_attrs does not exist on model %s", - name, - model.__class__.__name__, + raise AttributeError( + f"Attribute '{name}' declared in _layerwise_offload_blocks_attrs " + f"does not exist on model {model.__class__.__name__}" ) - continue try: attr_iter = iter(attr) except TypeError: From a24eaf7e53c23f9e44934c8c969024bda1886411 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 1 Apr 2026 22:52:00 +0800 Subject: [PATCH 10/10] upd Signed-off-by: Lancer --- .../diffusion/offloader/test_layerwise_backend.py | 2 +- tests/e2e/online_serving/test_flux2_expansion.py | 14 ++++++++++++++ tests/e2e/online_serving/test_zimage_expansion.py | 5 +++-- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py index ee3a0f20b1..5fd80e75c2 100644 --- a/tests/diffusion/offloader/test_layerwise_backend.py +++ b/tests/diffusion/offloader/test_layerwise_backend.py @@ -12,7 +12,7 @@ import torch import torch.distributed as dist from torch import nn -from torch.distributed.tensor import DTensor, DeviceMesh, Replicate +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate import vllm_omni.diffusion.offloader.layerwise_backend as layerwise_backend_module from vllm_omni.diffusion.offloader.layerwise_backend import LayerWiseOffloadBackend, LayerwiseOffloadHook diff --git a/tests/e2e/online_serving/test_flux2_expansion.py b/tests/e2e/online_serving/test_flux2_expansion.py index 8afa5f6c86..336bd83a1d 100644 --- a/tests/e2e/online_serving/test_flux2_expansion.py +++ b/tests/e2e/online_serving/test_flux2_expansion.py @@ -6,6 +6,7 @@ - FP8 + CacheDiT + Ulysses=2 + TP=2 - Layerwise CPU offload + Ulysses=2 + Ring=2 - Layerwise CPU offload + TP=2 +- Layerwise CPU offload + HSDP """ import pytest @@ -73,6 +74,19 @@ def _get_diffusion_feature_cases(model: str): id="layerwise_tp2", marks=FOUR_CARD_FEATURE_MARKS, ), + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--enable-layerwise-offload", + "--use-hsdp", + "--hsdp-shard-size", + "2", + ], + ), + id="layerwise_hsdp", + marks=FOUR_CARD_FEATURE_MARKS, + ), ] diff --git a/tests/e2e/online_serving/test_zimage_expansion.py b/tests/e2e/online_serving/test_zimage_expansion.py index f31393314b..e24e868a1a 100644 --- a/tests/e2e/online_serving/test_zimage_expansion.py +++ b/tests/e2e/online_serving/test_zimage_expansion.py @@ -8,7 +8,7 @@ - TeaCache + FP8 + Ulysses=2 + Ring=2 - Layerwise CPU offload + Ulysses=2 + Ring=2 - Layerwise CPU offload + TP=2 -- HSDP +- Layerwise CPU offload + HSDP """ import pytest @@ -93,12 +93,13 @@ def _get_diffusion_feature_cases(): OmniServerParams( model=MODEL, server_args=[ + "--enable-layerwise-offload", "--use-hsdp", "--hsdp-shard-size", "2", ], ), - id="parallel_hsdp", + id="layerwise_hsdp", marks=FOUR_CARD_MARKS, ), ]