diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6d731f127cc..49d1b2ece64 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -77,6 +77,7 @@ steps: depends_on: image-build commands: - pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py + - pytest -s -v tests/e2e/offline_inference/test_diffusion_layerwise_offload.py agents: queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU plugins: diff --git a/docs/user_guide/diffusion/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md index 533b6b3b964..1f82fd6089a 100644 --- a/docs/user_guide/diffusion/cpu_offload_diffusion.md +++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md @@ -1,19 +1,32 @@ # CPU Offloading for Diffusion Model ## Overview + +vLLM-Omni provides two offloading strategies to reduce GPU memory usage for diffusion models, allowing you to run larger models on GPUs with limited VRAM: + +1. **Model-level (Component) Offloading**: Swaps entire model components (DiT transformer, VAE, encoders) between GPU and CPU. +2. **Layerwise (Blockwise) Offloading**: Keeps only a single or a few transformer blocks on GPU at a time, with compute - memory copy overlap. + +Both approaches use pinned memory for faster CPU-GPU transfers. For now, the two offloading strategies could not be used at the same time. + + +## Model-level CPU Offloading + +### Implementation + CPU offload lets the diffusion worker move large model components between GPU and CPU memory on demand. It keeps the DiT transformer resident on GPU only while it is actively running, and swaps it out when encoders modules need the device. This reduces peak VRAM usage so bigger checkpoints run on smaller GPUs, or multiple requests can share the same GPU. -## Execution Model +**Execution Flow**: 1. Text encoders run on GPU while the DiT transformer is offloaded to CPU. 2. Before denoising, weights are prefetched back to GPU, honoring pinned-memory copies for speed. 3. After the diffusion step, the transformer returns to CPU and the process repeats as needed. Transfers use pinned host buffers, and the worker coordinates swaps via mutex-style hooks so components never compete for memory. -## Configuration +### Configuration You can enable CPU offload in two ways: -- **Python API**: set `enable_cpu_offload=True`. +1. **Python API**: set `enable_cpu_offload=True`. ```python from vllm_omni import Omni @@ -23,7 +36,66 @@ if __name__ == "__main__": m = Omni(model="Qwen/Qwen-Image",enable_cpu_offload=True) ``` -- **CLI**: pass `--enable-cpu-offload` to the diffusion service entrypoint. +2. **CLI**: pass `--enable-cpu-offload` to the diffusion service entrypoint. -## Known Limitations +### Limitations - Cold start latency increases for over one minute for some models(e.g., Qwen-Image) + + +## Layerwise (Blockwise) Offloading + +### Implementation +Layerwise offload operates at transformer block granularity, keeping a single transformer block, or a specified number of blocks, on GPU while others stay in CPU memory. + +Unlike full model-wise CPU offload which swaps entire components like DiT and encoders, layerwise offloading applies a sliding window way of loading and offloading weights between gpu and cpu: while block `i` computes, block `i+1` get prefetched asynchronously via pinned memory. In this way, only partial blocks(s) reside on GPU at any moment during inference, so that greatly decrease the memory occupancy. + +**Execution Flow**: + +1. During model initialization, all components are loaded to CPU first. Then components other than DiT model(s) in the pipeline, such as VAE and encoders, are moved to GPU. The weights of target transformer blocks are collected as contiguous tensors per layer on CPU with pinned memory; and non-block modules (embeddings, norms, etc) in the DiT model are moved to and stay on GPU. +2. The first block(s) are transferred to GPU during initialization of `LayerwiseOffloader`, before the first denoising step of the very first request. +3. As each block executes, the next block prefetches on a separate CUDA stream for compute - memory copy overlap. After execution, the current block is immediately freed from GPU memory. +4. When the last block completes, the first block prefetches for the next denoising step. + + +Example of hook executions of a DiT model with n layers, by default keep a single layer on GPU: +| Layer (block) idx | forward pre-hook | forward | forward post-hook | +|-------------------|--------------------------------|------------------|---------------------------| +| layer-0 | prefetch layer 1 (copy stream) | compute layer 0 | free layer-0 gpu weights | +| layer-1 | prefetch layer 2 (copy stream) | compute layer 1 | free layer-1 gpu weights | +| layer-2 | prefetch layer 3 (copy stream) | compute layer 2 | free layer-2 gpu weights | +| ... | ... | ... | ... | +| layer-(n-1) | **prefetch layer 0 (copy stream)** | compute layer (n-1) | free layer (n-1) gpu weights | + + +### Configuration + +1. **Python API**: set `enable_layerwise_offload=True` and optionally `layerwise_num_gpu_layers`. + +```python +from vllm_omni import Omni + +if __name__ == "__main__": + m = Omni( + model="Wan-AI/Wan2.2-T2V-A14B-Diffusers", + enable_layerwise_offload=True, + ... + ) +``` + +2. **CLI**: pass `--enable-layerwise-offload` and `--layerwise-num-gpu-layers` to the diffusion service entrypoint. + +### Supported Models + +| Architecture | Models | Example HF Models | DiT Model Cls | Blocks Attr Name | +|--------------|--------|-------------------|----------|----------| +| `QwenImagePipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image` | `QwenImageTransformer2DModel` | "transformer_blocks" | +| `Wan22Pipeline` | Wan2.2 | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `WanTransformer3DModel` | "blocks" | + +NOTE: Models must define `_layerwise_offload_blocks_attr` class attribute so that the layerwise offloader finds the target transformer blocks. + +### Limitations +- Cold start latency increases because of + 1) components are loaded to CPU first at the very first during initialization, + 2) weight consolidation and pinning +- Performance depends on CPU <-> GPU interconnection (e.g., PCIe bandwidth). +- Support single GPU only for now diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 5ffab87f0cf..8f330e09d20 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -295,6 +295,17 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) return parser.parse_args() @@ -350,6 +361,8 @@ def main(): # Initialize Omni with appropriate pipeline omni = Omni( model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, cache_backend=args.cache_backend, diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 4fca7825c74..1785287849d 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -74,6 +74,17 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) return parser.parse_args() @@ -112,6 +123,8 @@ def main(): omni = Omni( model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, boundary_ratio=args.boundary_ratio, diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 3e0628a15a3..a79e5d640d2 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -107,6 +107,17 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) parser.add_argument( "--tensor_parallel_size", type=int, @@ -172,6 +183,8 @@ def main(): omni = Omni( model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, cache_backend=args.cache_backend, diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index b5c003dfb22..63474987faa 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -79,6 +79,17 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) parser.add_argument( "--ulysses_degree", type=int, @@ -128,6 +139,8 @@ def main(): omni = Omni( model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, boundary_ratio=args.boundary_ratio, diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py new file mode 100644 index 00000000000..87a9e0a9e5f --- /dev/null +++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py @@ -0,0 +1,110 @@ +import sys +from pathlib import Path + +import pytest +import torch +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory + +from tests.utils import GPUMemoryMonitor +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +# Models to test and expected saved memory in MB, correspondingly +MODELS_SAVED_MEMORY_MB = {"riverclouds/qwen_image_random": 4500} + + +def run_inference( + model_name: str, + layerwise_offload: bool = False, + num_gpu_layers: int = 1, + num_inference_steps: int = 3, +) -> float: + # For now, only support on GPU, so apply torch.cuda operations here + # NPU / ROCm platforms are expected to be detected and skipped this test function + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + + m = Omni( + model=model_name, + enable_layerwise_offload=layerwise_offload, + layerwise_num_gpu_layers=num_gpu_layers, + boundary_ratio=0.875, + flow_shift=5.0, + ) + + torch.cuda.reset_peak_memory_stats(device=device_index) + + # Refer to tests/e2e/offline_inference/test_t2v_model.py + # Use minimal settings for testing + height = 480 + width = 640 + num_frames = 5 + + m.generate( + "A cat sitting on a table", + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=torch.Generator("cuda").manual_seed(42), + guidance_scale=1.0, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + ), + ) + + peak = monitor.peak_used_mb + monitor.stop() + + return peak + + +@pytest.mark.skipif(current_omni_platform.is_npu() or current_omni_platform.is_rocm(), reason="Hardware not supported") +@pytest.mark.parametrize("model_name", MODELS_SAVED_MEMORY_MB.keys()) +def test_layerwise_offload_diffusion_model(model_name: str): + """Test that layerwise offloading reduces GPU memory usage. + + This test verifies that layerwise offloading significantly reduces peak + GPU memory usage compared to loading the entire model on GPU. The layerwise + offloader keeps only a single transformer block on GPU at a time, with + prefetching for compute-memory overlap. + """ + try: + # Run without layerwise offloading (baseline) + no_offload_peak_memory = run_inference(model_name, layerwise_offload=False) + cleanup_dist_env_and_memory() + + # Run with layerwise offloading (1 layer on device) + layerwise_offload_peak_memory = run_inference(model_name, layerwise_offload=True, num_gpu_layers=1) + cleanup_dist_env_and_memory() + + # Run with 2 layers on device + layerwise_offload_two_layers_peak = run_inference(model_name, layerwise_offload=True, num_gpu_layers=2) + except Exception: + pytest.fail("Inference failed") + + print(f"Layerwise offload peak memory (1 GPU layer): {layerwise_offload_peak_memory} MB") + print(f"Layerwise offload peak memory (2 GPU layers): {layerwise_offload_two_layers_peak} MB") + print(f"No offload peak memory: {no_offload_peak_memory} MB") + + # Verify that layerwise offloading significantly reduces memory usage + # Passes only if the actual savings exceeds the expected savings + assert layerwise_offload_peak_memory + MODELS_SAVED_MEMORY_MB[model_name] < no_offload_peak_memory, ( + f"Layerwise offload peak memory {layerwise_offload_peak_memory} MB " + f"should be significantly less than no offload peak memory {no_offload_peak_memory} MB" + ) + + # Verify that 2 GPU layers uses more memory than 1 GPU layer + # But not excessively more (should be a reasonable increase) + assert layerwise_offload_peak_memory < layerwise_offload_two_layers_peak, ( + f"1 GPU layer peak {layerwise_offload_peak_memory} MB should be < " + f"2 GPU layers peak {layerwise_offload_two_layers_peak} MB" + ) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 286ab9153dc..4248071f010 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -288,6 +288,12 @@ class OmniDiffusionConfig: # - Text encoders run on GPU while DiT is on CPU # - DiT runs on GPU while encoders are on CPU enable_cpu_offload: bool = False + + # Layer-wise offloading (block-level offloading) parameters + enable_layerwise_offload: bool = False + # Number of transformer blocks ready for computation to keep on GPU + layerwise_num_gpu_layers: int = 1 + use_fsdp_inference: bool = False pin_cpu_memory: bool = True # Use pinned memory for faster transfers when offloading diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 054a5cb36d6..47098ff8a1b 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -758,6 +758,7 @@ class QwenImageTransformer2DModel(CachedTransformer): # -- typically a transformer layer # used for torch compile optimizations _repeated_blocks = ["QwenImageTransformerBlock"] + _layerwise_offload_blocks_attr = "transformer_blocks" packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index e4ba6318bca..ab92ad0c882 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -531,6 +531,7 @@ class WanTransformer3DModel(nn.Module): """ _repeated_blocks = ["WanTransformerBlock"] + _layerwise_offload_blocks_attr = "blocks" packed_modules_mapping = { "to_qkv": ["to_q", "to_k", "to_v"], } diff --git a/vllm_omni/diffusion/offload.py b/vllm_omni/diffusion/offload.py index 174c6434276..6f1d8a0db06 100644 --- a/vllm_omni/diffusion/offload.py +++ b/vllm_omni/diffusion/offload.py @@ -12,7 +12,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from functools import partial +from itertools import chain +from typing import TYPE_CHECKING, Any import torch from torch import nn @@ -66,7 +68,7 @@ def _to_cpu(self, module: nn.Module) -> None: # Release allocator blocks when tensors leave the GPU. if previous_device.type != "cpu": - current_omni_platform.empty_cache() + torch.cuda.empty_cache() if self.pin_memory: for p in module.parameters(): @@ -125,6 +127,267 @@ def remove(self) -> None: self._handles = [] +class LayerwiseOffloader: + """Layer-wise CPU offloading for transformer blocks. + + Keeps only a sliding window of layers (blocks), by default a single layer, on GPU, + prefetching the next block while the current block computes to approach compute - memcpy overlap. + Unused blocks are freed on GPU. + + Based on implementations from: + https://github.com/sgl-project/sglang/blob/v0.5.8/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py + """ + + def __init__( + self, + blocks: list[nn.Module], + device: torch.device, + pin_memory: bool = True, + num_gpu_layers: int = 1, + ): + assert all(isinstance(m, nn.Module) for m in blocks), "All transformer blocks must be torch.nn.Module" + assert current_omni_platform.is_cuda(), "Layerwise offloading is only supported on cuda devices for now" + + self.blocks = blocks + self.device = device + self.pin_memory = pin_memory + self.num_gpu_layers = num_gpu_layers + self.num_blocks = len(self.blocks) + if self.num_blocks == 0: + raise ValueError("LayerwiseOffloader requires at least one block, but found 0.") + if not (1 <= self.num_gpu_layers <= self.num_blocks): + raise ValueError(f"Invalid num_gpu_layers {self.num_gpu_layers} with {self.num_blocks} blocks") + + self._pre_hook_handles: list = [] + self._post_hook_handles: list = [] + + self._copy_stream = torch.cuda.Stream() + + # Per-layer synchronization primitive: set after H2D copy completes. + self._prefetch_done: list[torch.cuda.Event | None] = [None] * self.num_blocks + + # Simple state to avoid redundant work. + self._resident: list[bool] = [False] * self.num_blocks + + # Pre-allocate gpu tensors + # layer-id -> {dtype -> flattened aggregated cpu tensor} + self.layer_cpu_weights: list[dict[torch.dtype, torch.Tensor]] = [] + self.layer_metadata: list[dict[torch.dtype, list[dict[str, Any]]]] = [] + + self.block_parameters: dict[int, dict[str, nn.Parameter]] = {} + self.block_buffers: dict[int, dict[str, torch.Tensor]] = {} + for layer_idx, block in enumerate(self.blocks): + self.block_parameters[layer_idx] = dict(block.named_parameters()) + self.block_buffers[layer_idx] = dict(block.named_buffers()) + + dtype_cpu_flattened_weights, dtype_metadata = self._to_cpu( + self.block_parameters[layer_idx], self.block_buffers[layer_idx] + ) + self.layer_cpu_weights.append(dtype_cpu_flattened_weights) + self.layer_metadata.append(dtype_metadata) + + if self.num_blocks != len(self.layer_cpu_weights): + logger.error( + f"Inconsistent block layers happened: # of blocks: {self.num_blocks}; " + f"# of layer cpu weights: {len(self.layer_cpu_weights)}" + ) + + # Register pre and post forward hooks on each of the blocks + self.register_block_hooks() + + # Pre-fetch the first layer + # For subsequent requests, the first layer/block will be pre-fetched + # during the last layer compute of the previous request. + self.prefetch_layer(0, non_blocking=False) + + def _to_cpu( + self, params: dict[str, nn.Parameter], bufs: dict[str, torch.Tensor] + ) -> tuple[dict[torch.dtype, torch.Tensor], dict[torch.dtype, list[dict[str, Any]]]]: + """Move block parameters and buffers to CPU, flattening by dtype. + + Consolidates parameters and buffers into contiguous CPU tensors grouped by dtype + for GPU transfers. Replaces original tensors with empty placeholders. + + Returns: + Tuple of + flattened CPU tensors by dtype, + metadata for reconstruction by dtype + """ + dtype_grouped_weights: dict[torch.dtype, dict[str, torch.Tensor]] = {} + dtype_cpu_flattened_weights: dict[torch.dtype, torch.Tensor] = {} + # order does matter + dtype_metadata: dict[torch.dtype, list[dict[str, Any]]] = {} + + for name, param_or_buf in chain(params.items(), bufs.items()): + dtype = param_or_buf.dtype + if dtype not in dtype_grouped_weights: + dtype_grouped_weights[dtype] = {} + dtype_grouped_weights[dtype][name] = param_or_buf + + for dtype, name2weights in dtype_grouped_weights.items(): + # total # of parameters + buffers + total_numel = sum(t.numel() for _, t in name2weights.items()) + cpu_tensor = torch.empty(total_numel, dtype=dtype, device="cpu", pin_memory=self.pin_memory) + + current_offset = 0 + for name, param_or_buf in name2weights.items(): + numel = param_or_buf.numel() + cpu_tensor[current_offset : current_offset + numel].copy_(param_or_buf.flatten()) + if dtype not in dtype_metadata: + dtype_metadata[dtype] = [] + dtype_metadata[dtype].append( + { + "name": name, + "offset": current_offset, + "numel": numel, + "shape": param_or_buf.shape, + } + ) + + param_or_buf.data = torch.empty((), device=self.device, dtype=dtype) + current_offset += numel + + dtype_cpu_flattened_weights[dtype] = cpu_tensor + + return dtype_cpu_flattened_weights, dtype_metadata + + def register_block_hooks(self) -> None: + """Register forward hooks on blocks for prefetching and offloading.""" + + def _pre_hook(module: nn.Module, args: tuple, *, layer_idx: int) -> None: + # For the last block / layer, prefetch layer 0 (the first layer) + next_id = (layer_idx + 1) % self.num_blocks + self.prefetch_layer(next_id, non_blocking=True) + + def _post_hook(module: nn.Module, args: tuple, output: tuple, *, layer_idx: int) -> None: + self.offload_layer(layer_idx) + self._resident[layer_idx] = False + self._prefetch_done[layer_idx] = None + + for i, layer in enumerate(self.blocks): + pre_hook_fn = partial(_pre_hook, layer_idx=i) + handle = layer.register_forward_pre_hook(pre_hook_fn) + self._pre_hook_handles.append(handle) + + post_hook_fn = partial(_post_hook, layer_idx=i) + handle = layer.register_forward_hook(post_hook_fn) + self._post_hook_handles.append(handle) + + @torch.compiler.disable + def prefetch_layer(self, layer_idx: int, non_blocking: bool = True) -> None: + """Copy layer weights from CPU -> GPU. + + Pre-fetch target layer in an asynchronous way with compute - memory copy overlap, + with non_blocking set to True. + """ + if layer_idx >= self.num_blocks or layer_idx < 0: + logger.warning(f"Invalid layer id specified: {layer_idx}") + return + + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + layers_to_fetch = [(layer_idx + i) % self.num_blocks for i in range(self.num_gpu_layers)] + + for idx in layers_to_fetch: + if self._resident[idx]: + continue + + layer_params = self.block_parameters[idx] + layer_bufs = self.block_buffers[idx] + + evt = torch.cuda.Event() + gpu_weights: dict[torch.dtype, torch.Tensor] = {} + + with torch.cuda.stream(self._copy_stream): + for dtype, cpu_weight in self.layer_cpu_weights[idx].items(): + gpu_weight = torch.empty(cpu_weight.shape, dtype=dtype, device=self.device) + gpu_weight.copy_(cpu_weight, non_blocking=non_blocking) + gpu_weights[dtype] = gpu_weight + + evt.record(self._copy_stream) + + for dtype in self.layer_metadata[idx]: + ordered_metadata: list[dict[str, Any]] = self.layer_metadata[idx][dtype] + + gpu_weight = gpu_weights[dtype] + + for metadata in ordered_metadata: + target_name = metadata["name"] + target_param_or_buf = ( + layer_params[target_name] if target_name in layer_params else layer_bufs[target_name] + ) + + target_param_or_buf.data = gpu_weight[ + metadata["offset"] : metadata["offset"] + metadata["numel"] + ].view(metadata["shape"]) + + self._prefetch_done[idx] = evt + self._resident[idx] = True + + @torch.compiler.disable + def offload_layer(self, layer_idx: int) -> None: + """Free GPU memory for layer by replacing tensors with empty placeholders.""" + if layer_idx >= self.num_blocks or layer_idx < 0: + logger.warning(f"Invalid layer id specified: {layer_idx}") + return + if not self._resident[layer_idx]: + logger.warning(f"{layer_idx} is not residing on GPU") + return + + evt = self._prefetch_done[layer_idx] + if evt is not None: + torch.cuda.current_stream().wait_event(evt) + + # free GPU residency + for _, param in self.block_parameters[layer_idx].items(): + param.data = torch.empty((), device=self.device, dtype=param.dtype) + for _, buf in self.block_buffers[layer_idx].items(): + buf.data = torch.empty((), device=self.device, dtype=buf.dtype) + + def remove_all_hooks(self) -> None: + """Remove all hooks.""" + for h in self._pre_hook_handles: + h.remove() + for h in self._post_hook_handles: + h.remove() + self._pre_hook_handles.clear() + self._post_hook_handles.clear() + + @staticmethod + def get_blocks_attr_name(model: nn.Module) -> str | None: + """Retrieve blocks attribute name from provided DiT model""" + return getattr(model.__class__, "_layerwise_offload_blocks_attr", None) + + @staticmethod + def get_blocks_from_dit(model: nn.Module) -> list[nn.Module]: + """ + Retrieve a list of blocks from provided DiT model. Blocks attribute name + are found by `_layerwise_offload_blocks_attr` set to DiT models. For example, + + ``` + class WanTransformer3DModel(nn.Module): + _layerwise_offload_blocks_attr = "blocks" + ``` + """ + blocks_attr_name = LayerwiseOffloader.get_blocks_attr_name(model) + if blocks_attr_name is None: + logger.warning( + f"No _layerwise_offload_blocks_attr defined for {model.__class__.__name__}, " + "skipping layerwise offloading" + ) + return [] + + _blocks = getattr(model, blocks_attr_name, None) + if _blocks is None: + logger.warning( + f"Blocks (layers) '{blocks_attr_name}' not found on {model.__class__.__name__}, " + "skipping layerwise offloading" + ) + return [] + + return list(_blocks) + + def apply_offload_hooks( model: nn.Module, od_config: OmniDiffusionConfig, @@ -142,7 +405,24 @@ def apply_offload_hooks( model: Diffusion pipeline model od_config: OmniDiffusionConfig with offload settings """ - if not getattr(od_config, "enable_cpu_offload", False): + enable_cpu_offload = getattr(od_config, "enable_cpu_offload", False) + enable_layerwise_offload = getattr(od_config, "enable_layerwise_offload", False) + pin_cpu_memory = getattr(od_config, "pin_cpu_memory", True) + + if not enable_cpu_offload and not enable_layerwise_offload: + return + if enable_cpu_offload and enable_layerwise_offload: + # NOTE: Model-wise and layerwise cpu offloading are not supported together at this moment, + # consider layerwise offloading has higher priority than model-wise offloading + enable_cpu_offload = False + logger.info( + "Model-wise and layer-wise CPU offloading are not supported together at this moment. " + "Automatically disabled model-wise offloading." + ) + # For now, model-wise and layer-wise (block-wise) offloading + # are functioning as expected when cuda device is available + if not current_omni_platform.is_cuda() or current_omni_platform.get_device_count() < 1: + logger.info("CPU Offloading requires cuda devices available. Skipping for now...") return # Find DiT/transformer modules @@ -167,7 +447,6 @@ def apply_offload_hooks( if not dit_modules: logger.warning("enable_cpu_offload enabled but no transformer/dit/unet found") return - if device is None: try: device = next(dit_modules[0].parameters()).device @@ -175,7 +454,8 @@ def apply_offload_hooks( try: device = current_omni_platform.get_torch_device() except (NotImplementedError, AttributeError): - device = torch.device("cpu") + logger.error("Fail to get device of pipeline. Skipping applying offloading hooks") + return # Collect all encoders encoders: list[nn.Module] = [] @@ -184,29 +464,72 @@ def apply_offload_hooks( if hasattr(model, attr) and getattr(model, attr) is not None: encoders.append(getattr(model, attr)) encoder_names.append(attr) - - if not encoders: + if not encoders and enable_cpu_offload: logger.warning("enable_cpu_offload enabled but no encoders found") return + for enc in encoders: + enc.to(device) - # Initial state: keep DiT modules on CPU (encoders typically run first) - pin = getattr(od_config, "pin_cpu_memory", True) - for dit_mod in dit_modules: - dit_mod.to("cpu") - - current_omni_platform.empty_cache() + # Collect VAE + for name in ["vae"]: + module = getattr(model, name, None) + if module is None: + continue + try: + module.to(device, non_blocking=True) + except Exception as exc: + logger.debug("Failed to move %s to GPU: %s", name, exc) - if pin: + if enable_cpu_offload: + # Initial state: keep DiT modules on CPU (encoders typically run first) for dit_mod in dit_modules: - for p in dit_mod.parameters(): - if p.data.device.type == "cpu" and not p.data.is_pinned(): - p.data = p.data.pin_memory() - - # Register sequential offload hooks - SequentialOffloader(dit_modules, encoders, device, pin).register() - - logger.info( - "CPU offload enabled: %s <-> %s (mutual exclusion)", - ", ".join(dit_names), - ", ".join(encoder_names), - ) + dit_mod.to("cpu") + + torch.cuda.empty_cache() + + if pin_cpu_memory: + for dit_mod in dit_modules: + for p in dit_mod.parameters(): + if p.data.device.type == "cpu" and not p.data.is_pinned(): + p.data = p.data.pin_memory() + + # Register sequential offload hooks + SequentialOffloader(dit_modules, encoders, device, pin_cpu_memory).register() + logger.info( + "CPU offload enabled: %s <-> %s (mutual exclusion)", + ", ".join(dit_names), + ", ".join(encoder_names), + ) + elif enable_layerwise_offload: + logger.info(f"Applying offloading hooks on {dit_names}") + + for i, dit_module in enumerate(dit_modules): + logger.info(f"Applying hook on {dit_names[i]} ({dit_module.__class__.__name__})") + blocks_attr_name = LayerwiseOffloader.get_blocks_attr_name(dit_module) + blocks = LayerwiseOffloader.get_blocks_from_dit(dit_module) + + if not blocks_attr_name or not blocks: + logger.warning( + "Target layers (blocks) are not found. " + f"Skipping offloading on {dit_names[i]} ({dit_module.__class__.__name__})" + ) + continue + + # move modules other than blocks to gpu and keep them on gpu + for name, m in dit_module.named_children(): + # Skip the blocks module (layers to be offloaded) + if name == blocks_attr_name: + logger.debug(f"Skipped module {name}") + continue + + m.to(device) + logger.debug(f"Moved {name} to device {device}") + + # set to the module (transformer) + offloader = LayerwiseOffloader(blocks, device, pin_cpu_memory, od_config.layerwise_num_gpu_layers) + setattr(dit_module, "_layerwise_offloader", offloader) + + logger.info( + f"Layerwise offloading enabled on {len(blocks)} layers (blocks), " + f"with {od_config.layerwise_num_gpu_layers} kept on device" + ) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index b60d2c2f5b9..b0bec085216 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -76,7 +76,9 @@ def load_model( memory_pool_context_fn: Optional function that returns a context manager for memory pool allocation (used for sleep mode). """ - load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device) + load_device = ( + "cpu" if self.od_config.enable_cpu_offload or self.od_config.enable_layerwise_offload else str(self.device) + ) def get_memory_context(): if memory_pool_context_fn is not None: @@ -104,17 +106,8 @@ def get_memory_context(): ) logger.info("Model runner: Model loaded successfully.") - # Apply CPU offloading (DiT <-> encoders mutual exclusion) - if self.od_config.enable_cpu_offload: - for name in ["vae"]: - module = getattr(self.pipeline, name, None) - if module is None: - continue - try: - module.to(self.device, non_blocking=True) - except Exception as exc: - logger.debug("Failed to move %s to GPU: %s", name, exc) - + # Apply CPU offloading + if self.od_config.enable_cpu_offload or self.od_config.enable_layerwise_offload: apply_offload_hooks(self.pipeline, self.od_config, device=self.device) # Apply torch.compile if not in eager mode diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index ec2a18ac4d8..ab829beaae4 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -165,6 +165,8 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "cache_config": cache_config, "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), + "layerwise_num_gpu_layers": kwargs.get("layerwise_num_gpu_layers", False), "enforce_eager": kwargs.get("enforce_eager", False), }, "final_output": True, diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index bbd16711194..5613bdaeb53 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -199,6 +199,17 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu action="store_true", help="Enable CPU offloading for diffusion models.", ) + serve_parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + serve_parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of layers (blocks) to keep on GPU during generation.", + ) # Video model parameters (e.g., Wan2.2) - engine-level omni_config_group.add_argument(