Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py
- pytest -s -v tests/e2e/offline_inference/test_diffusion_layerwise_offload.py

- label: "Diffusion Cache Backend Test"
timeout_in_minutes: 15
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from tests.utils import GPUMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
Expand Down Expand Up @@ -68,7 +67,6 @@ def run_inference(
return peak


@pytest.mark.skipif(current_omni_platform.is_npu() or current_omni_platform.is_rocm(), reason="Hardware not supported")
@pytest.mark.parametrize("model_name", MODELS_SAVED_MEMORY_MB.keys())
def test_layerwise_offload_diffusion_model(model_name: str):
"""Test that layerwise offloading reduces GPU memory usage.
Expand Down
19 changes: 9 additions & 10 deletions vllm_omni/diffusion/offloader/layerwise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,18 @@ def __init__(
self,
next_block: nn.Module,
device: torch.device,
stream: torch.cuda.Stream | None = None,
stream: current_omni_platform.Stream | None = None,
pin_memory: bool = True,
):
assert isinstance(next_block, nn.Module), "transformer block must be type `torch.nn.Module`"
assert current_omni_platform.is_cuda(), "Layerwise offloading is only supported on cuda devices for now"

self.next_block = next_block
self.device = device
self.copy_stream = stream or torch.cuda.current_stream()
self.copy_stream = stream or current_omni_platform.current_stream()
self.pin_memory = pin_memory

# Per-block synchronization primitive: set after H2D copy completes.
self._prefetch_done: torch.cuda.Event | None = None
self._prefetch_done: current_omni_platform.Event | None = None

self.next_block_parameters: dict[str, nn.Parameter] = {}
self.next_block_buffers: dict[str, torch.Tensor] = {}
Expand Down Expand Up @@ -136,15 +135,15 @@ def prefetch_layer(self, non_blocking: bool = True) -> None:
Pre-fetch target block in an asynchronous way with compute - memory copy overlap,
with non_blocking set to True.
"""
self.copy_stream.wait_stream(torch.cuda.current_stream())
self.copy_stream.wait_stream(current_omni_platform.current_stream())

layer_params = self.next_block_parameters
layer_bufs = self.next_block_buffers

evt = torch.cuda.Event()
evt = current_omni_platform.Event()
gpu_weights: dict[torch.dtype, torch.Tensor] = {}

with torch.cuda.stream(self.copy_stream):
with current_omni_platform.stream(self.copy_stream):
for dtype, cpu_weight in self.dtype_cpu_flattened_weights.items():
gpu_weight = torch.empty(cpu_weight.shape, dtype=dtype, device=self.device)
gpu_weight.copy_(cpu_weight, non_blocking=non_blocking)
Expand Down Expand Up @@ -175,7 +174,7 @@ def offload_layer(self) -> None:
"""
evt = self._prefetch_done
if evt is not None:
torch.cuda.current_stream().wait_event(evt)
current_omni_platform.current_stream().wait_event(evt)

self._prefetch_done = None

Expand All @@ -200,7 +199,7 @@ def apply_block_hook(
module: nn.Module,
next_block: nn.Module,
device: torch.device,
stream: torch.cuda.Stream | None = None,
stream: current_omni_platform.Stream | None = None,
pin_memory: bool = True,
) -> LayerwiseOffloadHook:
registry = HookRegistry.get_or_create(module)
Expand Down Expand Up @@ -228,7 +227,7 @@ class LayerWiseOffloadBackend(OffloadBackend):
def __init__(self, config: OffloadConfig, device: torch.device):
super().__init__(config, device)

self.copy_stream = torch.cuda.Stream()
self.copy_stream = current_omni_platform.Stream()
self._blocks: list[list[nn.Module]] = []

def enable(self, pipeline: nn.Module) -> None:
Expand Down