[diffusion] Generalize layerwise offload residency mixin to all components#24593
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the layerwise offload mechanism by renaming OffloadableDiTMixin to LayerwiseOffloadableModuleMixin and centralizing residency strategy logic in component_manager.py. It introduces helper functions like is_layerwise_offloaded_module and should_cpu_offload_component to simplify offload decisions across various model components. Feedback includes suggestions to remove a redundant bool() call, simplify multi-line tuple unpacking for better readability, and remove an unnecessary trailing comma in a tuple assignment.
| def is_layerwise_offloaded_module(module: torch.nn.Module) -> bool: | ||
| return ( | ||
| isinstance(module, LayerwiseOffloadableModuleMixin) | ||
| and bool(module.layerwise_offload_managers) |
There was a problem hiding this comment.
The bool() call here is redundant. In Python, an empty list is evaluated as False in a boolean context, so you can check for non-emptiness directly. The pythonic way is to use the list itself in the condition.
| and bool(module.layerwise_offload_managers) | |
| and module.layerwise_offload_managers |
| (shift_msa, scale_msa, gate_msa), ( | ||
| shift_mlp, | ||
| scale_mlp, | ||
| gate_mlp, | ||
| ) = temb_mod_params_img |
There was a problem hiding this comment.
| x_valid_lens, | ||
| cap_valid_lens, | ||
| ) = self.patchify_and_embed( | ||
| (x, cap_feats, x_size, x_valid_lens, cap_valid_lens,) = self.patchify_and_embed( |
There was a problem hiding this comment.
The trailing comma in this tuple unpacking is unnecessary. While valid syntax, it's typically used to define a single-element tuple. For multi-element tuples, it's unconventional and can be removed for clarity.
| (x, cap_feats, x_size, x_valid_lens, cap_valid_lens,) = self.patchify_and_embed( | |
| (x, cap_feats, x_size, x_valid_lens, cap_valid_lens) = self.patchify_and_embed( |
…ency-strategy-compat # Conflicts: # python/sglang/multimodal_gen/configs/pipeline_configs/base.py # python/sglang/multimodal_gen/configs/pipeline_configs/model_deployment_config.py # python/sglang/multimodal_gen/configs/pipeline_configs/mova.py # python/sglang/multimodal_gen/configs/pipeline_configs/wan.py # python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py # python/sglang/multimodal_gen/runtime/server_args.py # python/sglang/multimodal_gen/test/unit/test_server_args.py
…ents Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
LayerwiseOffloadableModuleMixin.--layerwise-offload-componentsto select layerwise offload by pipeline component name, with--layerwise-offload-modulesaccepted as an alias.--dit-layerwise-offloadlegacy behavior: when no component is named, only default DiT components are configured; encoder / VAE / bridge / upsampler / vocoder must be selected explicitly.Validation
Remote H200 container
/sgl-workspace/sglangatfdf022713606dc2e6262975145d94e7f7d504a0d:PYTHONPATH=/sgl-workspace/sglang/python python -m pytest python/sglang/multimodal_gen/test/unit/test_layerwise_offload.py python/sglang/multimodal_gen/test/unit/test_server_args.py->41 passed, 2 warningstorch_sdpabackend:--dit-layerwise-offload true --dit-offload-prefetch-size 0-> PASS, enabled['transformer']--layerwise-offload-components transformer --dit-offload-prefetch-size 0-> PASS, enabled['transformer']--layerwise-offload-components transformer text_encoder --dit-offload-prefetch-size 0-> PASS, enabled['text_encoder', 'transformer']--layerwise-offload-components all --dit-offload-prefetch-size 0-> PASS, enabled['text_encoder', 'vae', 'transformer']--layerwise-offload-modules transformer --dit-offload-prefetch-size 0-> PASS, enabled['transformer']--layerwise-offload-components missing_component --dit-offload-prefetch-size 0-> PASS with warning and no layerwise component0f0dba3e7d97aa3be19ef7d6d1cd3ea0e727c322153a8a4f9904089b8e9ee4c1CI States
Latest PR Test: Run #25928755979⚠️ Not enabled — add
Latest PR Test (Extra):
run-ci-extralabel to opt in.