Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/diffusion/offloader/test_module_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch.nn as nn

from vllm_omni.diffusion.offloader.module_collector import ModuleDiscovery


class DummyModule(nn.Module):
def __init__(self):
super().__init__()


class DummyPipeline(nn.Module):
def __init__(self):
super().__init__()
self.transformer = nn.Module()
self.text_encoder = nn.Module()
self.vae = nn.Module()


# LTX2TwoStagesPipeline-like nested pipeline
class NestedDummyPipeline(nn.Module):
def __init__(self):
super().__init__()
self.pipe = DummyPipeline()
self.upsample_pipe = DummyPipeline()


class TestModuleDiscovery:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if a DiT on both pipe and upsample_pipe is found? The current resolution seems to fail on it

def test_discover_basic(self):
pipeline = DummyPipeline()
modules = ModuleDiscovery.discover(pipeline)
assert len(modules.dits) > 0
assert len(modules.encoders) > 0
assert modules.vae is not None

def test_discover_nested(self):
pipeline = NestedDummyPipeline()
modules = ModuleDiscovery.discover(pipeline)
assert len(modules.dits) > 0
assert len(modules.encoders) > 0
assert modules.vae is not None
10 changes: 9 additions & 1 deletion vllm_omni/diffusion/offloader/module_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from torch import nn
from vllm.logger import init_logger

from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr

logger = init_logger(__name__)


Expand Down Expand Up @@ -40,7 +42,13 @@ def discover(pipeline: nn.Module) -> PipelineModules:
dit_names: list[str] = []
for attr in ModuleDiscovery.DIT_ATTRS:
if not hasattr(pipeline, attr):
continue
# Some pipeline like LTX2TwoStagesPipeline have recursive
# modules that have the transformer
module = find_module_with_attr(pipeline, attr)
if module is None:
continue
pipeline = module
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reassignment to pipeline here descend other dit module which it's going to look for under the current one - which I think might be not that stable as it discards the root

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to have some tracking from the outmost wrapper pipeline to the transformer module which contains offloadable layers


module_obj = getattr(pipeline, attr)
if module_obj is None:
continue
Expand Down
Loading