Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/user_guide/diffusion/parallelism_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@ The following table shows which models are currently supported by parallelism me
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ❌ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ | ✅ (TP=2 only) |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ |


!!! note "TP Limitations for Diffusion Models"
We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.

- Good news: The text_encoder typically has minimal impact on overall inference performance.
- Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste.

We are actively refactoring this design to eliminate the duplicated text_encoder weights. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771).


!!! note "Why Z-Image is TP=2 only"
Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints.
Comment thread
hsliuustc0106 marked this conversation as resolved.
For example, the model has `n_heads=30` and a final projection output dimension of `64`, so a valid TP size must divide both 30 and 64; their only common divisors are **1 and 2**.
Expand Down
13 changes: 11 additions & 2 deletions examples/offline_inference/image_to_image/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for ring sequence parallelism.",
)
parser.add_argument(
"--tensor_parallel_size",
type=int,
default=1,
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
)
parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.")
parser.add_argument(
"--resolution",
Expand Down Expand Up @@ -301,7 +307,10 @@ def main():
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()
parallel_config = DiffusionParallelConfig(
ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size
ulysses_degree=args.ulysses_degree,
ring_degree=args.ring_degree,
cfg_parallel_size=args.cfg_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
)

# Configure cache based on backend type
Expand Down Expand Up @@ -351,7 +360,7 @@ def main():
else:
print(f" Input image size: {input_image.size}")
print(
f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}"
)
print(f"{'=' * 60}\n")

Expand Down
44 changes: 2 additions & 42 deletions tests/e2e/offline_inference/test_diffusion_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import sys
import threading
import time
from pathlib import Path

import pytest
import torch

from tests.utils import GPUMemoryMonitor
from vllm_omni.utils.platform_utils import is_npu, is_rocm

# ruff: noqa: E402
Expand All @@ -15,39 +14,6 @@

from vllm_omni import Omni


class GPUMemoryMonitor:
    """Background sampler that records the highest device memory usage seen.

    A daemon thread repeatedly queries ``torch.cuda.mem_get_info()`` for the
    configured device and stores the largest ``total - free`` value, in MiB,
    in ``peak_used_mb``.

    Args:
        device_index: Index of the CUDA device to sample.
        interval: Seconds to sleep between two samples.
    """

    def __init__(self, device_index: int, interval: float = 0.05):
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self.peak_used_mb = 0.0
        self.interval = interval
        self.device_index = device_index

    def _sample_until_stopped(self) -> None:
        """Sample device memory until the stop event is set."""
        while True:
            if self._stop_event.is_set():
                break
            try:
                with torch.cuda.device(self.device_index):
                    free_bytes, total_bytes = torch.cuda.mem_get_info()
                in_use_mb = (total_bytes - free_bytes) / (1024**2)
                if in_use_mb > self.peak_used_mb:
                    self.peak_used_mb = in_use_mb
            except Exception:
                # Sampling is best-effort: a failed query is simply skipped.
                pass
            time.sleep(self.interval)

    def start(self) -> None:
        """Launch the daemon sampling thread."""
        worker = threading.Thread(target=self._sample_until_stopped, daemon=True)
        self._thread = worker
        worker.start()

    def stop(self) -> None:
        """Signal the sampling thread to finish and wait up to two seconds for it."""
        worker = self._thread
        if worker is not None:
            self._stop_event.set()
            worker.join(timeout=2.0)


models = ["riverclouds/qwen_image_random"]


Expand All @@ -73,13 +39,7 @@ def inference(offload: bool = True):
generator=torch.Generator("cuda").manual_seed(42),
)

monitor.stop()
torch.cuda.synchronize(device_index)
fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2)
fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2)
peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved)

return peak_memory_mb
return monitor.peak_used_mb

offload_peak_memory = inference(offload=True)
no_offload_peak_memory = inference(offload=False)
Expand Down
22 changes: 18 additions & 4 deletions tests/e2e/offline_inference/test_zimage_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))

from tests.utils import GPUMemoryMonitor
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.outputs import OmniRequestOutput
Expand Down Expand Up @@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image:

def _run_zimage_generate(
*, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int
) -> tuple[Image.Image, float]:
) -> tuple[Image.Image, float, float]:
torch.cuda.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
Comment on lines +71 to +74
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reset CUDA peak stats before collecting TP memory

GPUMemoryMonitor.peak_used_mb falls back to torch.cuda.max_memory_allocated/reserved, which are process‑wide peaks and are not reset by empty_cache(). Since _run_zimage_generate is invoked twice in the same process, the TP=2 run will inherit the TP=1 peak and can never be lower even if it actually uses less memory, making the new assertion flaky. Consider calling torch.cuda.reset_peak_memory_stats(device_index) before starting the monitor (or dropping the max_memory fallback) so each run measures its own peak.

Useful? React with 👍 / 👎.


m = Omni(
model=_get_zimage_model(),
parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size),
Expand Down Expand Up @@ -107,7 +113,10 @@ def _run_zimage_generate(
pass

median_time_s = float(np.median(per_request_times_s))
return _extract_single_image([last_output]), median_time_s

peak_memory_mb = monitor.peak_used_mb

return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
m.close()
cleanup_dist_env_and_memory()
Expand All @@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
num_inference_steps = 2
seed = 42

tp1_img, tp1_time_s = _run_zimage_generate(
tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate(
tp_size=1,
height=height,
width=width,
num_inference_steps=num_inference_steps,
seed=seed,
)
tp2_img, tp2_time_s = _run_zimage_generate(
tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate(
tp_size=2,
height=height,
width=width,
Expand Down Expand Up @@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):

print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"

print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
assert tp2_peak_mem < tp1_peak_mem, (
f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})"
)
43 changes: 43 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import subprocess
import sys
import tempfile
import threading
import time
from collections.abc import Callable
from contextlib import ExitStack, contextmanager, suppress
from typing import Any, Literal

import cloudpickle
import pytest
import torch
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
Expand Down Expand Up @@ -474,3 +476,44 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return func

return wrapper


class GPUMemoryMonitor:
    """Poll global device memory usage via CUDA APIs.

    A daemon thread samples ``torch.cuda.mem_get_info()`` for ``device_index``
    every ``interval`` seconds and remembers the largest ``total - free`` value
    (in MiB) observed between ``start()`` and ``stop()``.

    The monitor can be restarted after ``stop()`` and can be used as a context
    manager (``with GPUMemoryMonitor(0) as monitor: ...``).

    Args:
        device_index: Index of the CUDA device to sample.
        interval: Seconds to wait between two samples.
    """

    def __init__(self, device_index: int, interval: float = 0.05):
        self.device_index = device_index
        self.interval = interval
        self._peak_used_mb = 0.0
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Begin a fresh measurement window; a no-op if already running."""
        if self._thread is not None and self._thread.is_alive():
            return

        self._peak_used_mb = 0.0
        # ``peak_used_mb`` also folds in the caching allocator's peak counters,
        # which accumulate over the process lifetime. Reset them so this window
        # does not inherit the peak of an earlier run in the same process.
        # Skipped when CUDA was never initialised here: the counters are still
        # zero and we must not force CUDA initialisation as a side effect.
        if torch.cuda.is_initialized():
            try:
                torch.cuda.reset_peak_memory_stats(self.device_index)
            except RuntimeError:
                # Best-effort, like the sampling itself.
                pass

        # A fresh event per run: a stale thread from a timed-out join keeps
        # seeing its own (set) event and cannot be revived by a restart.
        stop_event = threading.Event()
        self._stop_event = stop_event

        def monitor_loop() -> None:
            while not stop_event.is_set():
                try:
                    with torch.cuda.device(self.device_index):
                        free_bytes, total_bytes = torch.cuda.mem_get_info()
                    used_mb = (total_bytes - free_bytes) / (1024**2)
                    self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Sampling is deliberately best-effort: a failed query
                    # must not kill the monitor thread.
                    pass
                # Wait on the event instead of sleeping so stop() returns promptly.
                stop_event.wait(self.interval)

        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Stop sampling and wait up to two seconds for the thread to exit."""
        thread = self._thread
        if thread is None:
            return
        self._stop_event.set()
        thread.join(timeout=2.0)
        self._thread = None

    @property
    def peak_used_mb(self) -> float:
        """Peak memory in MiB: the max of the sampled usage and the allocator peaks."""
        fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.stop()

    def __del__(self):
        # __init__ may not have completed, or the interpreter may be shutting
        # down; a finalizer must never raise.
        try:
            self.stop()
        except Exception:
            pass
Loading