-
Notifications
You must be signed in to change notification settings - Fork 1k
[BugFix] Fix layerwise CPU offloading for LTX2 two-stages pipeline #2935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import torch.nn as nn | ||
|
|
||
| from vllm_omni.diffusion.offloader.module_collector import ModuleDiscovery | ||
|
|
||
|
|
||
| class DummyModule(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
|
|
||
| class DummyPipeline(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.transformer = nn.Module() | ||
| self.text_encoder = nn.Module() | ||
| self.vae = nn.Module() | ||
|
|
||
|
|
||
| # LTX2TwoStagesPipeline-like nested pipeline | ||
| class NestedDummyPipeline(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.pipe = DummyPipeline() | ||
| self.upsample_pipe = DummyPipeline() | ||
|
|
||
|
|
||
| class TestModuleDiscovery: | ||
| def test_discover_basic(self): | ||
| pipeline = DummyPipeline() | ||
| modules = ModuleDiscovery.discover(pipeline) | ||
| assert len(modules.dits) > 0 | ||
| assert len(modules.encoders) > 0 | ||
| assert modules.vae is not None | ||
|
|
||
| def test_discover_nested(self): | ||
| pipeline = NestedDummyPipeline() | ||
| modules = ModuleDiscovery.discover(pipeline) | ||
| assert len(modules.dits) > 0 | ||
| assert len(modules.encoders) > 0 | ||
| assert modules.vae is not None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,8 @@ | |
| from torch import nn | ||
| from vllm.logger import init_logger | ||
|
|
||
| from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -40,7 +42,13 @@ def discover(pipeline: nn.Module) -> PipelineModules: | |
| dit_names: list[str] = [] | ||
| for attr in ModuleDiscovery.DIT_ATTRS: | ||
| if not hasattr(pipeline, attr): | ||
| continue | ||
| # Some pipeline like LTX2TwoStagesPipeline have recursive | ||
| # modules that have the transformer | ||
| module = find_module_with_attr(pipeline, attr) | ||
| if module is None: | ||
| continue | ||
| pipeline = module | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reassignment to
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to have some tracking from the outmost wrapper pipeline to the transformer module which contains offloadable layers |
||
|
|
||
| module_obj = getattr(pipeline, attr) | ||
| if module_obj is None: | ||
| continue | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if a DiT on both pipe and upsample_pipe is found? The current resolution seems to fail on it