diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py new file mode 100644 index 00000000000..7df3c1bb1a1 --- /dev/null +++ b/tests/diffusion/offloader/test_layerwise_backend.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for LayerwiseOffloadHook.""" + +import gc +import os +import socket +from contextlib import contextmanager + +import pytest +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate + +import vllm_omni.diffusion.offloader.layerwise_backend as layerwise_backend_module +from vllm_omni.diffusion.offloader.layerwise_backend import LayerwiseOffloadHook +from vllm_omni.platforms import current_omni_platform + +pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model] + + +class DummyStream: + def wait_stream(self, _stream) -> None: + return None + + def wait_event(self, _event) -> None: + return None + + +class DummyEvent: + def record(self, _stream) -> None: + return None + + +@contextmanager +def dummy_stream(_stream): + yield None + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return int(s.getsockname()[1]) + + +def _set_dist_env(*, rank: int, world_size: int, master_port: int) -> None: + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + + +def _cleanup_distributed() -> None: + if dist.is_initialized(): + dist.destroy_process_group() + + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]: + os.environ.pop(key, None) + + gc.collect() + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + current_omni_platform.synchronize() + + +@pytest.fixture(scope="module") +def dist_group(): + master_port = _find_free_port() + _set_dist_env(rank=0, world_size=1, master_port=master_port) + + dist.init_process_group("gloo", rank=0, world_size=1) + try: + yield + finally: + _cleanup_distributed() + + +@pytest.fixture +def patched_offload_runtime(mocker): + mocker.patch.object(layerwise_backend_module.current_omni_platform, "Stream", DummyStream) + mocker.patch.object(layerwise_backend_module.current_omni_platform, "Event", DummyEvent) + mocker.patch.object(layerwise_backend_module.current_omni_platform, "current_stream", lambda: DummyStream()) + mocker.patch.object(layerwise_backend_module.current_omni_platform, "stream", dummy_stream) + + +class TinyBlock(nn.Module): + def __init__(self, values: torch.Tensor): + super().__init__() + mesh = DeviceMesh("cpu", [0]) + dtensor = DTensor.from_local(values, mesh, [Replicate()]) + self.weight = nn.Parameter(dtensor) + + +def _make_values(start: float) -> torch.Tensor: + return torch.arange(start, start + 4, dtype=torch.float32) + + +class TestLayerwiseOffloadHook: + def test_dtensor_wrapper_is_preserved_across_prefetch_and_offload(self, dist_group, patched_offload_runtime): + current_block = TinyBlock(_make_values(1.0)) + next_block = TinyBlock(_make_values(10.0)) + + hook = LayerwiseOffloadHook( + next_block=next_block, + device=torch.device("cpu"), + stream=DummyStream(), + pin_memory=False, + ) + + hook.initialize_hook(current_block) + + assert isinstance(next_block.weight, DTensor) + assert next_block.weight.to_local().is_meta + assert next_block.weight.to_local().shape == torch.Size([4]) + assert hook.dtype_metadata[next_block.weight.dtype][0]["shape"] == torch.Size([4]) + + hook.prefetch_layer(non_blocking=False) + assert isinstance(next_block.weight, DTensor) + assert torch.equal(next_block.weight.to_local(), _make_values(10.0)) + assert next_block.weight.to_local().shape == torch.Size([4]) + + hook.offload_layer() + assert isinstance(current_block.weight, DTensor) + assert current_block.weight.to_local().is_meta + assert current_block.weight.to_local().shape == torch.Size([4]) + assert not hook.is_materialized diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 5b66ae5ee22..20af5b5d828 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -6,6 +6,7 @@ import torch from torch import nn +from torch.distributed.tensor import DTensor from vllm.logger import init_logger from vllm_omni.diffusion.hooks import HookRegistry, ModelHook @@ -58,6 +59,31 @@ def __init__( self.dtype_cpu_flattened_weights: dict[torch.dtype, torch.Tensor] = {} self.dtype_metadata: dict[torch.dtype, list[dict[str, Any]]] = {} + @staticmethod + def _is_dtensor(t: torch.Tensor) -> bool: + return isinstance(t, DTensor) + + @staticmethod + def _set_tensor_storage(target: torch.Tensor, value: torch.Tensor) -> None: + if LayerwiseOffloadHook._is_dtensor(target): + target._local_tensor = value + else: + target.data = value + + @staticmethod + def _make_offload_placeholder(tensor: torch.Tensor) -> torch.Tensor: + if LayerwiseOffloadHook._is_dtensor(tensor): + local_shape = tuple(tensor.to_local().shape) + return torch.empty(local_shape, device="meta", dtype=tensor.dtype) + return torch.empty((0,), device=tensor.device, dtype=tensor.dtype) + + @staticmethod + def _is_materialized_tensor(t: torch.Tensor) -> bool: + if LayerwiseOffloadHook._is_dtensor(t): + local_t = t.to_local() + return not local_t.is_meta + return not t.is_meta and t.data.numel() > 0 + def initialize_hook(self, module: nn.Module) -> nn.Module: # This all happen during the hook instance being registered to hook registry; # the input module is kept intact @@ -71,7 +97,10 @@ def initialize_hook(self, module: nn.Module) -> nn.Module: # Pre-allocate gpu tensors in a flattened way self.dtype_cpu_flattened_weights, self.dtype_metadata = LayerwiseOffloadHook._to_cpu( - self.next_block_parameters, self.next_block_buffers, self.device, self.pin_memory + self.next_block_parameters, + self.next_block_buffers, + self.device, + self.pin_memory, ) return module @@ -106,13 +135,17 @@ def _to_cpu( for dtype, name2weights in dtype_grouped_weights.items(): # total # of parameters + buffers - total_numel = sum(t.numel() for _, t in name2weights.items()) + weights_with_local = [] + for name, t in name2weights.items(): + local_t = t.to_local() if hasattr(t, "to_local") else t + weights_with_local.append((name, t, local_t)) + total_numel = sum(local.numel() for _, _, local in weights_with_local) cpu_tensor = torch.empty(total_numel, dtype=dtype, device="cpu", pin_memory=pin_memory) current_offset = 0 - for name, param_or_buf in name2weights.items(): - numel = param_or_buf.numel() - cpu_tensor[current_offset : current_offset + numel].copy_(param_or_buf.flatten()) + for name, original_tensor, local_tensor in weights_with_local: + numel = local_tensor.numel() + cpu_tensor[current_offset : current_offset + numel].copy_(local_tensor.flatten()) if dtype not in dtype_metadata: dtype_metadata[dtype] = [] dtype_metadata[dtype].append( @@ -120,11 +153,13 @@ def _to_cpu( "name": name, "offset": current_offset, "numel": numel, - "shape": param_or_buf.shape, + "shape": local_tensor.shape, } ) - param_or_buf.data = torch.empty((), device=device, dtype=dtype) + LayerwiseOffloadHook._set_tensor_storage( + original_tensor, LayerwiseOffloadHook._make_offload_placeholder(original_tensor) + ) current_offset += numel dtype_cpu_flattened_weights[dtype] = cpu_tensor @@ -135,7 +170,7 @@ def _to_cpu( def is_materialized(self) -> bool: """Check whether this block's parameters hold real data on device.""" for param in self.block_parameters.values(): - return param.data.dim() > 0 + return LayerwiseOffloadHook._is_materialized_tensor(param) return True @@ -172,8 +207,9 @@ def prefetch_layer(self, non_blocking: bool = True) -> None: layer_params[target_name] if target_name in layer_params else layer_bufs[target_name] ) - target_param_or_buf.data = gpu_weight[metadata["offset"] : metadata["offset"] + metadata["numel"]].view( - metadata["shape"] + LayerwiseOffloadHook._set_tensor_storage( + target_param_or_buf, + gpu_weight[metadata["offset"] : metadata["offset"] + metadata["numel"]].view(metadata["shape"]), ) self._prefetch_done = evt @@ -191,9 +227,9 @@ def offload_layer(self) -> None: # free GPU residency for _, param in self.block_parameters.items(): - param.data = torch.empty((), device=self.device, dtype=param.dtype) + LayerwiseOffloadHook._set_tensor_storage(param, LayerwiseOffloadHook._make_offload_placeholder(param)) for _, buf in self.block_buffers.items(): - buf.data = torch.empty((), device=self.device, dtype=buf.dtype) + LayerwiseOffloadHook._set_tensor_storage(buf, LayerwiseOffloadHook._make_offload_placeholder(buf)) def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tuple, dict]: # if the previous hook was skipped and the weights are not on device, @@ -311,7 +347,11 @@ def enable(self, pipeline: nn.Module) -> None: # during the last layer compute of the previous request. last_block, first_block = blocks[-1], blocks[0] last_hook = apply_block_hook( - last_block, first_block, self.device, self.copy_stream, self.config.pin_cpu_memory + last_block, + first_block, + self.device, + self.copy_stream, + self.config.pin_cpu_memory, ) last_hook.prefetch_layer(non_blocking=False) @@ -319,7 +359,13 @@ def enable(self, pipeline: nn.Module) -> None: # Register hook for each of blocks for i, block in enumerate(blocks[:-1]): next_block = blocks[(i + 1) % num_blocks] - hook = apply_block_hook(block, next_block, self.device, self.copy_stream, self.config.pin_cpu_memory) + hook = apply_block_hook( + block, + next_block, + self.device, + self.copy_stream, + self.config.pin_cpu_memory, + ) block_hooks.append(hook) # NOTE(yuanheng-zhao): We make each hook gets a backward reference to the hook