From 35a98fb248dbffa10cf2f567edef23a34cc7c8aa Mon Sep 17 00:00:00 2001 From: zjy0516 Date: Sun, 18 Jan 2026 15:23:17 +0800 Subject: [PATCH 1/4] init and update test Signed-off-by: zjy0516 --- .../test_diffusion_cpu_offload.py | 44 +--- .../test_zimage_tensor_parallel.py | 22 +- tests/utils.py | 43 +++ .../qwen_image/qwen_image_transformer.py | 244 ++++++++++-------- 4 files changed, 206 insertions(+), 147 deletions(-) diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py index 0066d49b161..cefda891571 100644 --- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -1,11 +1,10 @@ import sys -import threading -import time from pathlib import Path import pytest import torch +from tests.utils import GPUMemoryMonitor from vllm_omni.utils.platform_utils import is_npu, is_rocm # ruff: noqa: E402 @@ -15,39 +14,6 @@ from vllm_omni import Omni - -class GPUMemoryMonitor: - """Poll global device memory usage via CUDA APIs.""" - - def __init__(self, device_index: int, interval: float = 0.05): - self.device_index = device_index - self.interval = interval - self.peak_used_mb = 0.0 - self._stop_event = threading.Event() - self._thread: threading.Thread | None = None - - def start(self) -> None: - def monitor_loop() -> None: - while not self._stop_event.is_set(): - try: - with torch.cuda.device(self.device_index): - free_bytes, total_bytes = torch.cuda.mem_get_info() - used_mb = (total_bytes - free_bytes) / (1024**2) - self.peak_used_mb = max(self.peak_used_mb, used_mb) - except Exception: - pass - time.sleep(self.interval) - - self._thread = threading.Thread(target=monitor_loop, daemon=True) - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - self._stop_event.set() - self._thread.join(timeout=2.0) - - models = ["riverclouds/qwen_image_random"] @@ -73,13 +39,7 @@ def inference(offload: bool = True): generator=torch.Generator("cuda").manual_seed(42), ) - monitor.stop() - torch.cuda.synchronize(device_index) - fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2) - fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2) - peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved) - - return peak_memory_mb + return monitor.peak_used_mb offload_peak_memory = inference(offload=True) no_offload_peak_memory = inference(offload=False) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index d32bb2b8223..60686992278 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -17,6 +17,7 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from tests.utils import GPUMemoryMonitor from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.outputs import OmniRequestOutput @@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image: def _run_zimage_generate( *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int -) -> tuple[Image.Image, float]: +) -> tuple[Image.Image, float, float]: + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + m = Omni( model=_get_zimage_model(), parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), @@ -107,7 +113,10 @@ def _run_zimage_generate( pass median_time_s = float(np.median(per_request_times_s)) - return _extract_single_image([last_output]), median_time_s + + peak_memory_mb = monitor.peak_used_mb + + return _extract_single_image([last_output]), median_time_s, peak_memory_mb finally: m.close() cleanup_dist_env_and_memory() @@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): num_inference_steps = 2 seed = 42 - tp1_img, tp1_time_s = _run_zimage_generate( + tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate( tp_size=1, height=height, width=width, num_inference_steps=num_inference_steps, seed=seed, ) - tp2_img, tp2_time_s = _run_zimage_generate( + tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( tp_size=2, height=height, width=width, @@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}") assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})" + + print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}") + assert tp2_peak_mem < tp1_peak_mem, ( + f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})" + ) diff --git a/tests/utils.py b/tests/utils.py index 2a2dca238a8..8e5593d6501 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,6 +7,7 @@ import subprocess import sys import tempfile +import threading import time from collections.abc import Callable from contextlib import ExitStack, contextmanager, suppress @@ -14,6 +15,7 @@ import cloudpickle import pytest +import torch from typing_extensions import ParamSpec from vllm.platforms import current_platform from vllm.utils.torch_utils import cuda_device_count_stateless @@ -474,3 +476,44 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: return func return wrapper + + +class GPUMemoryMonitor: + """Poll global device memory usage via CUDA APIs.""" + + def __init__(self, device_index: int, interval: float = 0.05): + self.device_index = device_index + self.interval = interval + self._peak_used_mb = 0.0 + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + def monitor_loop() -> None: + while not self._stop_event.is_set(): + try: + with torch.cuda.device(self.device_index): + free_bytes, total_bytes = torch.cuda.mem_get_info() + used_mb = (total_bytes - free_bytes) / (1024**2) + self._peak_used_mb = max(self._peak_used_mb, used_mb) + except Exception: + pass + time.sleep(self.interval) + + self._thread = threading.Thread(target=monitor_loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + self._stop_event.set() + self._thread.join(timeout=2.0) + + @property + def peak_used_mb(self) -> float: + fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2) + fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2) + return max(self._peak_used_mb, fallback_alloc, fallback_reserved) + + def __del__(self): + self.stop() diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 8ac5014ce89..cbf0b7e10ac 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from diffusers.models.attention import FeedForward +import torch.nn.functional as F # TODO replace this with vLLM implementation from diffusers.models.embeddings import TimestepEmbedding, Timesteps @@ -16,7 +16,11 @@ from diffusers.models.normalization import AdaLayerNormContinuous from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.backends.abstract import ( @@ -287,79 +291,133 @@ def _compute_video_freqs(self, frame, height, width, idx=0): return freqs.clone().contiguous() +class ColumnParallelApproxGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." + + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), # placeholder for weight loading + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + class QwenImageCrossAttention(nn.Module): def __init__( self, dim: int, # query_dim num_heads: int, head_dim: int, - window_size=(-1, -1), - added_kv_proj_dim: int = None, + added_kv_proj_dim: int, + window_size: tuple[int, int] = (-1, -1), out_bias: bool = True, - qk_norm=True, # rmsnorm - eps=1e-6, - pre_only=False, + qk_norm: bool = True, + eps: float = 1e-6, + pre_only: bool = False, context_pre_only: bool = False, - parallel_attention=False, - out_dim: int = None, + out_dim: int | None = None, ) -> None: - assert dim % num_heads == 0 super().__init__() + assert dim % num_heads == 0 + self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm self.eps = eps - self.parallel_attention = parallel_attention - # layers - # self.to_q = ReplicatedLinear(dim, dim) - # self.to_k = ReplicatedLinear(dim, dim) - # self.to_v = ReplicatedLinear(dim, dim) self.to_qkv = QKVParallelLinear( hidden_size=dim, head_size=self.head_dim, total_num_heads=num_heads, - disable_tp=True, ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() - self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads - self.inner_kv_dim = self.inner_dim - if added_kv_proj_dim is not None: - assert context_pre_only is not None - # self.add_k_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_v_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_q_proj = ReplicatedLinear( - # added_kv_proj_dim, self.inner_dim, bias=True - # ) - self.add_kv_proj = QKVParallelLinear( - added_kv_proj_dim, - head_size=self.inner_kv_dim // self.num_heads, - total_num_heads=self.num_heads, - disable_tp=True, - ) - if context_pre_only is not None and not context_pre_only: - self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) - else: - self.to_add_out = None + self.inner_dim = out_dim if out_dim is not None else head_dim * self.total_num_heads - if not pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)) - else: - self.to_out = None + assert context_pre_only is not None + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=head_dim, + total_num_heads=num_heads, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + + assert not context_pre_only + self.to_add_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + assert not pre_only + self.to_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) self.norm_added_q = RMSNorm(head_dim, eps=eps) self.norm_added_k = RMSNorm(head_dim, eps=eps) self.attn = Attention( - num_heads=num_heads, + num_heads=self.query_num_heads, head_size=self.head_dim, softmax_scale=1.0 / (self.head_dim**0.5), causal=False, + num_kv_heads=self.kv_num_heads, ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -377,61 +435,55 @@ def forward( txt_freqs: torch.Tensor, hidden_states_mask: torch.Tensor | None = None, encoder_hidden_states_mask: torch.Tensor | None = None, - ): - # if mask is all true, set it to None + ) -> tuple[torch.Tensor, torch.Tensor]: if hidden_states_mask is not None and hidden_states_mask.all(): hidden_states_mask = None if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all(): encoder_hidden_states_mask = None - seq_len_txt = encoder_hidden_states.shape[1] - # Compute QKV for image stream (sample projections) - qkv, _ = self.to_qkv(hidden_states) - img_query, img_key, img_value = qkv.chunk(3, dim=-1) + img_qkv, _ = self.to_qkv(hidden_states) + q_size = self.query_num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1) - # Compute QKV for text stream (context projections) - qkv, _ = self.add_kv_proj(encoder_hidden_states) - txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1) + txt_qkv, _ = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_query_num_heads * self.head_dim + add_kv_size = self.add_kv_num_heads * self.head_dim + txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1) - # Reshape for multi-head attention - img_query = img_query.unflatten(-1, (self.num_heads, -1)) - img_key = img_key.unflatten(-1, (self.num_heads, -1)) - img_value = img_value.unflatten(-1, (self.num_heads, -1)) + img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim)) + img_key = img_key.unflatten(-1, (self.kv_num_heads, self.head_dim)) + img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim)) - txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) - txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) - txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim)) + txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) - # Apply QK normalization img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - # Apply RoPE img_cos = vid_freqs.real.to(img_query.dtype) img_sin = vid_freqs.imag.to(img_query.dtype) txt_cos = txt_freqs.real.to(txt_query.dtype) txt_sin = txt_freqs.imag.to(txt_query.dtype) + img_query = self.rope(img_query, img_cos, img_sin) img_key = self.rope(img_key, img_cos, img_sin) txt_query = self.rope(txt_query, txt_cos, txt_sin) txt_key = self.rope(txt_key, txt_cos, txt_sin) - # Concatenate for joint attention - # Order: [text, image] + seq_len_txt = encoder_hidden_states.shape[1] joint_query = torch.cat([txt_query, img_query], dim=1) joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention if ( self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1 and not get_forward_context().split_text_embed_in_sp ): - # if using sequence parallel, but not splitting text embed, - # we need to pass text embedding to attention layer as joint qkv attn_metadata = AttentionMetadata( joint_query=txt_query, joint_key=txt_key, @@ -443,22 +495,17 @@ def forward( if encoder_hidden_states_mask is not None: attn_metadata.joint_attn_mask = encoder_hidden_states_mask - joint_hidden_states = self.attn( - img_query, - img_key, - img_value, - attn_metadata, - ) + joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata) else: attn_metadata = None if hidden_states_mask is not None or encoder_hidden_states_mask is not None: - mask_list = [] + mask_list: list[torch.Tensor] = [] if encoder_hidden_states_mask is not None: mask_list.append(encoder_hidden_states_mask) else: mask_list.append( torch.ones( - [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]], + encoder_hidden_states.shape[:2], dtype=torch.bool, device=encoder_hidden_states.device, ) @@ -468,34 +515,22 @@ def forward( else: mask_list.append( torch.ones( - [hidden_states.shape[0], hidden_states.shape[1]], + hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device, ) ) - joint_mask = ( - None if len(mask_list) == 0 else torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] - ) + joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] attn_metadata = AttentionMetadata(attn_mask=joint_mask) - joint_hidden_states = self.attn( - joint_query, - joint_key, - joint_value, - attn_metadata, - ) - joint_hidden_states = joint_hidden_states.flatten(2, 3) - joint_hidden_states = joint_hidden_states.to(joint_query.dtype) - # Split attention outputs back - txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part - img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part + joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata) - # Apply output projections - img_attn_output, _ = self.to_out[0](img_attn_output) - if len(self.to_out) > 1: - (img_attn_output,) = self.to_out[1](img_attn_output) # dropout + joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype) + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] - txt_attn_output, _ = self.to_add_out(txt_attn_output) + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) return img_attn_output, txt_attn_output @@ -530,7 +565,7 @@ def __init__( head_dim=attention_head_dim, ) self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.img_mlp = FeedForward(dim=dim, dim_out=dim) # Text processing modules self.txt_mod = nn.Sequential( @@ -540,7 +575,7 @@ def __init__( self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.txt_mlp = FeedForward(dim=dim, dim_out=dim) self.zero_cond_t = zero_cond_t @@ -892,6 +927,8 @@ def get_rotary_emb_chunk(freqs, padding=0): if original_seq_len is not None: output = output[:, :original_seq_len, :] + torch.cuda.empty_cache() + return Transformer2DModelOutput(sample=output) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -916,17 +953,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + original_name = name + lookup_name = name for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + lookup_name = original_name.replace(weight_name, param_name) + param = params_dict[lookup_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - param = params_dict[name] + if lookup_name not in params_dict and ".to_out.0." in lookup_name: + lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + param = params_dict[lookup_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + loaded_params.add(original_name) + loaded_params.add(lookup_name) return loaded_params From 588ba5e013b4a4aaf808a5fef98f13ee4c2220e2 Mon Sep 17 00:00:00 2001 From: zjy0516 Date: Sun, 18 Jan 2026 17:38:20 +0800 Subject: [PATCH 2/4] update doc Signed-off-by: zjy0516 --- docs/user_guide/diffusion/parallelism_acceleration.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index dfacd2183ff..59741b37fef 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -23,13 +23,21 @@ The following table shows which models are currently supported by parallelism me | **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ | | **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ | | **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ❌ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ | ✅ (TP=2 only) | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | + +!!! note "TP Limitations for Diffusion Models" + We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the text_encoder component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + - Good news: The text_encoder typically has minimal impact on overall inference performance. + - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. + We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771). + + !!! note "Why Z-Image is TP=2 only" Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints. For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**. From f5b90f65171979094d7e64371a24c327dcea0ac9 Mon Sep 17 00:00:00 2001 From: zjy0516 Date: Sun, 18 Jan 2026 17:47:44 +0800 Subject: [PATCH 3/4] update doc Signed-off-by: zjy0516 --- docs/user_guide/diffusion/parallelism_acceleration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 59741b37fef..fdc370a07e2 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -33,8 +33,10 @@ The following table shows which models are currently supported by parallelism me !!! note "TP Limitations for Diffusion Models" We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the text_encoder component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + - Good news: The text_encoder typically has minimal impact on overall inference performance. - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. + We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771). From 03f1f2f7df82fc0dbaa3663b1def6795eec36881 Mon Sep 17 00:00:00 2001 From: zjy0516 Date: Sun, 18 Jan 2026 20:15:17 +0800 Subject: [PATCH 4/4] update Signed-off-by: zjy0516 --- .../diffusion/parallelism_acceleration.md | 2 +- .../offline_inference/image_to_image/image_edit.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index fdc370a07e2..324301158d8 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -32,7 +32,7 @@ The following table shows which models are currently supported by parallelism me !!! note "TP Limitations for Diffusion Models" - We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the text_encoder component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. - Good news: The text_encoder typically has minimal impact on overall inference performance. - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 5d2b1052bec..c31d098252b 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -181,6 +181,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.") parser.add_argument( "--resolution", @@ -301,7 +307,10 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() parallel_config = DiffusionParallelConfig( - ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, ) # Configure cache based on backend type @@ -351,7 +360,7 @@ def main(): else: print(f" Input image size: {input_image.size}") print( - f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" ) print(f"{'=' * 60}\n")