Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions tests/diffusion/offloader/test_layerwise_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for LayerwiseOffloadHook."""

import gc
import os
import socket
from contextlib import contextmanager

import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate

import vllm_omni.diffusion.offloader.layerwise_backend as layerwise_backend_module
from vllm_omni.diffusion.offloader.layerwise_backend import LayerwiseOffloadHook
from vllm_omni.platforms import current_omni_platform

pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model]


class DummyStream:
def wait_stream(self, _stream) -> None:
return None

def wait_event(self, _event) -> None:
return None


class DummyEvent:
def record(self, _stream) -> None:
return None


@contextmanager
def dummy_stream(_stream):
yield None


def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return int(s.getsockname()[1])


def _set_dist_env(*, rank: int, world_size: int, master_port: int) -> None:
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(master_port)


def _cleanup_distributed() -> None:
if dist.is_initialized():
dist.destroy_process_group()

for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]:
os.environ.pop(key, None)

gc.collect()
if current_omni_platform.is_available():
current_omni_platform.empty_cache()
current_omni_platform.synchronize()


@pytest.fixture(scope="module")
def dist_group():
master_port = _find_free_port()
_set_dist_env(rank=0, world_size=1, master_port=master_port)

dist.init_process_group("gloo", rank=0, world_size=1)
try:
yield
finally:
_cleanup_distributed()


@pytest.fixture
def patched_offload_runtime(mocker):
mocker.patch.object(layerwise_backend_module.current_omni_platform, "Stream", DummyStream)
mocker.patch.object(layerwise_backend_module.current_omni_platform, "Event", DummyEvent)
mocker.patch.object(layerwise_backend_module.current_omni_platform, "current_stream", lambda: DummyStream())
mocker.patch.object(layerwise_backend_module.current_omni_platform, "stream", dummy_stream)


class TinyBlock(nn.Module):
def __init__(self, values: torch.Tensor):
super().__init__()
mesh = DeviceMesh("cpu", [0])
dtensor = DTensor.from_local(values, mesh, [Replicate()])
self.weight = nn.Parameter(dtensor)


def _make_values(start: float) -> torch.Tensor:
return torch.arange(start, start + 4, dtype=torch.float32)


class TestLayerwiseOffloadHook:
def test_dtensor_wrapper_is_preserved_across_prefetch_and_offload(self, dist_group, patched_offload_runtime):
current_block = TinyBlock(_make_values(1.0))
next_block = TinyBlock(_make_values(10.0))

hook = LayerwiseOffloadHook(
next_block=next_block,
device=torch.device("cpu"),
stream=DummyStream(),
pin_memory=False,
)

hook.initialize_hook(current_block)

assert isinstance(next_block.weight, DTensor)
assert next_block.weight.to_local().is_meta
assert next_block.weight.to_local().shape == torch.Size([4])
assert hook.dtype_metadata[next_block.weight.dtype][0]["shape"] == torch.Size([4])

hook.prefetch_layer(non_blocking=False)
assert isinstance(next_block.weight, DTensor)
assert torch.equal(next_block.weight.to_local(), _make_values(10.0))
assert next_block.weight.to_local().shape == torch.Size([4])

hook.offload_layer()
assert isinstance(current_block.weight, DTensor)
assert current_block.weight.to_local().is_meta
assert current_block.weight.to_local().shape == torch.Size([4])
assert not hook.is_materialized
74 changes: 60 additions & 14 deletions vllm_omni/diffusion/offloader/layerwise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from torch import nn
from torch.distributed.tensor import DTensor
from vllm.logger import init_logger

from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
Expand Down Expand Up @@ -58,6 +59,31 @@ def __init__(
self.dtype_cpu_flattened_weights: dict[torch.dtype, torch.Tensor] = {}
self.dtype_metadata: dict[torch.dtype, list[dict[str, Any]]] = {}

@staticmethod
def _is_dtensor(t: torch.Tensor) -> bool:
return isinstance(t, DTensor)

@staticmethod
def _set_tensor_storage(target: torch.Tensor, value: torch.Tensor) -> None:
if LayerwiseOffloadHook._is_dtensor(target):
target._local_tensor = value
else:
target.data = value

@staticmethod
def _make_offload_placeholder(tensor: torch.Tensor) -> torch.Tensor:
if LayerwiseOffloadHook._is_dtensor(tensor):
local_shape = tuple(tensor.to_local().shape)
return torch.empty(local_shape, device="meta", dtype=tensor.dtype)
return torch.empty((0,), device=tensor.device, dtype=tensor.dtype)

@staticmethod
def _is_materialized_tensor(t: torch.Tensor) -> bool:
if LayerwiseOffloadHook._is_dtensor(t):
local_t = t.to_local()
return not local_t.is_meta
return not t.is_meta and t.data.numel() > 0

def initialize_hook(self, module: nn.Module) -> nn.Module:
# This all happen during the hook instance being registered to hook registry;
# the input module is kept intact
Expand All @@ -71,7 +97,10 @@ def initialize_hook(self, module: nn.Module) -> nn.Module:

# Pre-allocate gpu tensors in a flattened way
self.dtype_cpu_flattened_weights, self.dtype_metadata = LayerwiseOffloadHook._to_cpu(
self.next_block_parameters, self.next_block_buffers, self.device, self.pin_memory
self.next_block_parameters,
self.next_block_buffers,
self.device,
self.pin_memory,
)

return module
Expand Down Expand Up @@ -106,25 +135,31 @@ def _to_cpu(

for dtype, name2weights in dtype_grouped_weights.items():
# total # of parameters + buffers
total_numel = sum(t.numel() for _, t in name2weights.items())
weights_with_local = []
for name, t in name2weights.items():
local_t = t.to_local() if hasattr(t, "to_local") else t
weights_with_local.append((name, t, local_t))
total_numel = sum(local.numel() for _, _, local in weights_with_local)
cpu_tensor = torch.empty(total_numel, dtype=dtype, device="cpu", pin_memory=pin_memory)

current_offset = 0
for name, param_or_buf in name2weights.items():
numel = param_or_buf.numel()
cpu_tensor[current_offset : current_offset + numel].copy_(param_or_buf.flatten())
for name, original_tensor, local_tensor in weights_with_local:
numel = local_tensor.numel()
cpu_tensor[current_offset : current_offset + numel].copy_(local_tensor.flatten())
if dtype not in dtype_metadata:
dtype_metadata[dtype] = []
dtype_metadata[dtype].append(
{
"name": name,
"offset": current_offset,
"numel": numel,
"shape": param_or_buf.shape,
"shape": local_tensor.shape,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefetch_layer uses metadata["shape"] to .view() the GPU slice and then assigns it to target_param_or_buf.data. With this PR, shape is now the local tensor shape, but target_param_or_buf is still a DTensor. That .data assignment replaces the DTensor internals with a plain tensor — doesn't this break FSDP/HSDP state tracking on the reload path? Same concern for offload_layer which does param.data = torch.empty(...).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

applied the same DTensor-safe storage update in prefetch_layer() as well, so both prefetch and offload follow the same handling

}
)

param_or_buf.data = torch.empty((), device=device, dtype=dtype)
LayerwiseOffloadHook._set_tensor_storage(
original_tensor, LayerwiseOffloadHook._make_offload_placeholder(original_tensor)
)
current_offset += numel

dtype_cpu_flattened_weights[dtype] = cpu_tensor
Expand All @@ -135,7 +170,7 @@ def _to_cpu(
def is_materialized(self) -> bool:
"""Check whether this block's parameters hold real data on device."""
for param in self.block_parameters.values():
return param.data.dim() > 0
return LayerwiseOffloadHook._is_materialized_tensor(param)

return True

Expand Down Expand Up @@ -172,8 +207,9 @@ def prefetch_layer(self, non_blocking: bool = True) -> None:
layer_params[target_name] if target_name in layer_params else layer_bufs[target_name]
)

target_param_or_buf.data = gpu_weight[metadata["offset"] : metadata["offset"] + metadata["numel"]].view(
metadata["shape"]
LayerwiseOffloadHook._set_tensor_storage(
target_param_or_buf,
gpu_weight[metadata["offset"] : metadata["offset"] + metadata["numel"]].view(metadata["shape"]),
)

self._prefetch_done = evt
Expand All @@ -191,9 +227,9 @@ def offload_layer(self) -> None:

# free GPU residency
for _, param in self.block_parameters.items():
param.data = torch.empty((), device=self.device, dtype=param.dtype)
LayerwiseOffloadHook._set_tensor_storage(param, LayerwiseOffloadHook._make_offload_placeholder(param))
for _, buf in self.block_buffers.items():
buf.data = torch.empty((), device=self.device, dtype=buf.dtype)
LayerwiseOffloadHook._set_tensor_storage(buf, LayerwiseOffloadHook._make_offload_placeholder(buf))

def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tuple, dict]:
# if the previous hook was skipped and the weights are not on device,
Expand Down Expand Up @@ -311,15 +347,25 @@ def enable(self, pipeline: nn.Module) -> None:
# during the last layer compute of the previous request.
last_block, first_block = blocks[-1], blocks[0]
last_hook = apply_block_hook(
last_block, first_block, self.device, self.copy_stream, self.config.pin_cpu_memory
last_block,
first_block,
self.device,
self.copy_stream,
self.config.pin_cpu_memory,
)
last_hook.prefetch_layer(non_blocking=False)

block_hooks: list[LayerwiseOffloadHook] = [last_hook]
# Register hook for each of blocks
for i, block in enumerate(blocks[:-1]):
next_block = blocks[(i + 1) % num_blocks]
hook = apply_block_hook(block, next_block, self.device, self.copy_stream, self.config.pin_cpu_memory)
hook = apply_block_hook(
block,
next_block,
self.device,
self.copy_stream,
self.config.pin_cpu_memory,
)
block_hooks.append(hook)

# NOTE(yuanheng-zhao): We make each hook gets a backward reference to the hook
Expand Down
Loading