diff --git a/tests/diffusion/offloader/test_module_collector.py b/tests/diffusion/offloader/test_module_collector.py new file mode 100644 index 00000000000..0090772fad7 --- /dev/null +++ b/tests/diffusion/offloader/test_module_collector.py @@ -0,0 +1,40 @@ +import torch.nn as nn + +from vllm_omni.diffusion.offloader.module_collector import ModuleDiscovery + + +class DummyModule(nn.Module): + def __init__(self): + super().__init__() + + +class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Module() + self.text_encoder = nn.Module() + self.vae = nn.Module() + + +# LTX2TwoStagesPipeline-like nested pipeline +class NestedDummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.pipe = DummyPipeline() + self.upsample_pipe = DummyPipeline() + + +class TestModuleDiscovery: + def test_discover_basic(self): + pipeline = DummyPipeline() + modules = ModuleDiscovery.discover(pipeline) + assert len(modules.dits) > 0 + assert len(modules.encoders) > 0 + assert modules.vae is not None + + def test_discover_nested(self): + pipeline = NestedDummyPipeline() + modules = ModuleDiscovery.discover(pipeline) + assert len(modules.dits) > 0 + assert len(modules.encoders) > 0 + assert modules.vae is not None diff --git a/vllm_omni/diffusion/offloader/module_collector.py b/vllm_omni/diffusion/offloader/module_collector.py index a09a337001e..7f80fe7d43e 100644 --- a/vllm_omni/diffusion/offloader/module_collector.py +++ b/vllm_omni/diffusion/offloader/module_collector.py @@ -6,6 +6,8 @@ from torch import nn from vllm.logger import init_logger +from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr + logger = init_logger(__name__) @@ -40,7 +42,13 @@ def discover(pipeline: nn.Module) -> PipelineModules: dit_names: list[str] = [] for attr in ModuleDiscovery.DIT_ATTRS: if not hasattr(pipeline, attr): - continue + # Some pipeline like LTX2TwoStagesPipeline have recursive + # modules that have the transformer + module = find_module_with_attr(pipeline, attr) + if module is None: + continue + pipeline = module + module_obj = getattr(pipeline, attr) if module_obj is None: continue