Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 2 additions & 42 deletions tests/e2e/offline_inference/test_diffusion_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import sys
import threading
import time
from pathlib import Path

import pytest
import torch

from tests.utils import GPUMemoryMonitor
from vllm_omni.utils.platform_utils import is_npu, is_rocm

# ruff: noqa: E402
Expand All @@ -15,39 +14,6 @@

from vllm_omni import Omni


class GPUMemoryMonitor:
    """Poll global device memory usage via CUDA APIs.

    A daemon thread samples ``torch.cuda.mem_get_info()`` for ``device_index``
    every ``interval`` seconds and keeps the largest ``total - free`` value
    seen in ``peak_used_mb``. The value is computed with a divisor of
    ``1024**2``, i.e. it is expressed in MiB.

    Args:
        device_index: Index of the CUDA device to sample.
        interval: Seconds to wait between two samples.
    """

    def __init__(self, device_index: int, interval: float = 0.05):
        self.device_index = device_index
        self.interval = interval
        # Largest sampled usage so far; read directly by callers.
        self.peak_used_mb = 0.0
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the background polling thread.

        Calling ``start()`` while a polling thread is already running is a
        no-op, so a second call cannot leak an extra thread.
        """
        if self._thread is not None and self._thread.is_alive():
            return

        # A previous stop() leaves the event set; clear it so the new thread
        # does not exit immediately.
        self._stop_event.clear()

        def monitor_loop() -> None:
            while not self._stop_event.is_set():
                try:
                    with torch.cuda.device(self.device_index):
                        free_bytes, total_bytes = torch.cuda.mem_get_info()
                except Exception:
                    # Deliberately best-effort: a failed sample (e.g. the
                    # device is unavailable) must not kill the monitor.
                    pass
                else:
                    used_mb = (total_bytes - free_bytes) / (1024**2)
                    self.peak_used_mb = max(self.peak_used_mb, used_mb)
                # Wait on the event rather than sleeping so stop() wakes the
                # loop immediately instead of after a full interval.
                self._stop_event.wait(self.interval)

        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the polling thread to exit and wait up to 2 s for it.

        Safe to call before ``start()`` and safe to call repeatedly.
        """
        thread = self._thread
        if thread is None:
            return
        self._stop_event.set()
        thread.join(timeout=2.0)
        # Only forget the thread once it has really exited; otherwise a later
        # start() could run two pollers side by side.
        if not thread.is_alive():
            self._thread = None


models = ["riverclouds/qwen_image_random"]


Expand All @@ -73,13 +39,7 @@ def inference(offload: bool = True):
generator=torch.Generator("cuda").manual_seed(42),
)

monitor.stop()
torch.cuda.synchronize(device_index)
fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2)
fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2)
peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved)

return peak_memory_mb
return monitor.peak_used_mb

offload_peak_memory = inference(offload=True)
no_offload_peak_memory = inference(offload=False)
Expand Down
22 changes: 18 additions & 4 deletions tests/e2e/offline_inference/test_zimage_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))

from tests.utils import GPUMemoryMonitor
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.outputs import OmniRequestOutput
Expand Down Expand Up @@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image:

def _run_zimage_generate(
*, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int
) -> tuple[Image.Image, float]:
) -> tuple[Image.Image, float, float]:
torch.cuda.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
Comment on lines +71 to +74

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reset CUDA peak stats before collecting TP memory

GPUMemoryMonitor.peak_used_mb falls back to torch.cuda.max_memory_allocated/reserved, which are process‑wide peaks and are not reset by empty_cache(). Since _run_zimage_generate is invoked twice in the same process, the TP=2 run will inherit the TP=1 peak and can never be lower even if it actually uses less memory, making the new assertion flaky. Consider calling torch.cuda.reset_peak_memory_stats(device_index) before starting the monitor (or dropping the max_memory fallback) so each run measures its own peak.

Useful? React with 👍 / 👎.


m = Omni(
model=_get_zimage_model(),
parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size),
Expand Down Expand Up @@ -107,7 +113,10 @@ def _run_zimage_generate(
pass

median_time_s = float(np.median(per_request_times_s))
return _extract_single_image([last_output]), median_time_s

peak_memory_mb = monitor.peak_used_mb

return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
m.close()
cleanup_dist_env_and_memory()
Expand All @@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
num_inference_steps = 2
seed = 42

tp1_img, tp1_time_s = _run_zimage_generate(
tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate(
tp_size=1,
height=height,
width=width,
num_inference_steps=num_inference_steps,
seed=seed,
)
tp2_img, tp2_time_s = _run_zimage_generate(
tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate(
tp_size=2,
height=height,
width=width,
Expand Down Expand Up @@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):

print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"

print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
assert tp2_peak_mem < tp1_peak_mem, (
f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})"
)
43 changes: 43 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import subprocess
import sys
import tempfile
import threading
import time
from collections.abc import Callable
from contextlib import ExitStack, contextmanager, suppress
from typing import Any, Literal

import cloudpickle
import pytest
import torch
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
Expand Down Expand Up @@ -474,3 +476,44 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return func

return wrapper


class GPUMemoryMonitor:
    """Poll global device memory usage via CUDA APIs.

    A daemon thread samples ``torch.cuda.mem_get_info()`` for ``device_index``
    every ``interval`` seconds and remembers the largest ``total - free``
    value seen. ``peak_used_mb`` reports that value combined with this
    process's CUDA allocator peak counters. All values are computed with a
    divisor of ``1024**2``, i.e. they are expressed in MiB.

    The monitor can be used as a context manager, which guarantees that the
    polling thread is stopped::

        with GPUMemoryMonitor(device_index=0) as monitor:
            ...
        peak = monitor.peak_used_mb

    Args:
        device_index: Index of the CUDA device to sample.
        interval: Seconds to wait between two samples.
    """

    def __init__(self, device_index: int, interval: float = 0.05):
        self.device_index = device_index
        self.interval = interval
        self._peak_used_mb = 0.0
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Begin a fresh measurement and start the background polling thread.

        The polled peak and the allocator's peak counters are reset first, so
        the reported peak belongs to this measurement only and is not
        inherited from earlier work in the same process. Calling ``start()``
        while a polling thread is already running is a no-op.
        """
        if self._thread is not None and self._thread.is_alive():
            return

        # A previous stop() leaves the event set; clear it so the new thread
        # does not exit immediately.
        self._stop_event.clear()
        self._peak_used_mb = 0.0
        # max_memory_allocated/reserved are process-lifetime peaks unless
        # reset. Without CUDA initialised there are no counters to reset.
        if torch.cuda.is_initialized():
            torch.cuda.reset_peak_memory_stats(self.device_index)

        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def _monitor_loop(self) -> None:
        """Sample device memory until the stop event is set."""
        while not self._stop_event.is_set():
            try:
                with torch.cuda.device(self.device_index):
                    free_bytes, total_bytes = torch.cuda.mem_get_info()
            except Exception:
                # Deliberately best-effort: a failed sample (e.g. the device
                # is unavailable) must not kill the monitor.
                pass
            else:
                used_mb = (total_bytes - free_bytes) / (1024**2)
                self._peak_used_mb = max(self._peak_used_mb, used_mb)
            # Wait on the event rather than sleeping so stop() wakes the loop
            # immediately instead of after a full interval.
            self._stop_event.wait(self.interval)

    def stop(self) -> None:
        """Signal the polling thread to exit and wait up to 2 s for it.

        Safe to call before ``start()`` and safe to call repeatedly.
        """
        thread = self._thread
        if thread is None:
            return
        self._stop_event.set()
        # Joining the current thread raises RuntimeError; this can only
        # happen if stop() is reached from the polling thread itself.
        if thread is not threading.current_thread():
            thread.join(timeout=2.0)
        # Only forget the thread once it has really exited; otherwise a later
        # start() could run two pollers side by side.
        if not thread.is_alive():
            self._thread = None

    @property
    def peak_used_mb(self) -> float:
        """Peak memory in MiB observed since the last ``start()``.

        Returns the maximum of the polled device-wide peak and this process's
        ``torch.cuda.max_memory_allocated`` / ``max_memory_reserved``
        counters (presumably included to cover spikes that fall between two
        samples). When CUDA has not been initialised in this process only the
        polled value is returned.
        """
        if not torch.cuda.is_initialized():
            return self._peak_used_mb
        fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

    def __enter__(self) -> "GPUMemoryMonitor":
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.stop()

    def __del__(self):
        # Last-resort cleanup only: the running thread references this
        # object, so callers should call stop() (or use ``with``) explicitly.
        # The getattr guard covers an instance whose __init__ did not finish.
        if getattr(self, "_thread", None) is not None:
            self.stop()
Loading