From ef7c35a7677ae205581c88baf1bc1bf1e938f27f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:20:30 +0800 Subject: [PATCH 01/85] udpate diffusion config Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 70 ++++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 8210bedab51..36bf24b60a8 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -6,15 +6,67 @@ import random from collections.abc import Callable from dataclasses import dataclass, field, fields -from typing import Any +from pydantic import Field, model_validator +from typing import Any +from typing_extensions import Self import torch from vllm.logger import init_logger +from vllm.config.utils import config from vllm_omni.diffusion.utils.network_utils import is_port_available logger = init_logger(__name__) +@config +@dataclass +class DiffusionParallelConfig: + """Configuration for diffusion model distributed execution.""" + + pipeline_parallel_size: int = 1 + """Number of pipeline parallel stages.""" + + data_parallel_size: int = 1 + """Number of data parallel groups.""" + + tensor_parallel_size: int = 1 + """Number of tensor parallel groups.""" + + sequence_parallel_size: int = 1 + """Number of sequence parallel groups. 
sequence_parallel_size = ring_degree * ulysses_degree""" + + ulysses_degree: int = 1 + """Number of GPUs used for ulysses sequence parallelism.""" + + ring_degree: int = 1 + """Number of GPUs used for ring sequence parallelism.""" + + cfg_parallel_size: int = 1 + """Number of Classifier Free Guidance (CFG) parallel groups.""" + + @model_validator(mode="after") + def _validate_parallel_config(self) -> Self: + """Validates the config relationships among the parallel strategies.""" + assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0" + assert self.data_parallel_size > 0, "Data parallel size must be > 0" + assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0" + assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0" + assert self.ulysses_degree > 0, "Ulysses degree must be > 0" + assert self.ring_degree > 0, "Ring degree must be > 0" + assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" + assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, "Sequence parallel size must be equal to the product of ulysses degree and ring degree, but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" + return self + + def __post_init__(self) -> None: + + self.world_size = ( + self.pipeline_parallel_size + * self.data_parallel_size + * self.tensor_parallel_size + * self.ulysses_degree + * self.ring_degree + * self.cfg_parallel_size + ) @dataclass class TransformerConfig: @@ -180,6 +232,7 @@ class OmniDiffusionConfig: # Cache strategy (legacy) cache_strategy: str = "none" + parallel_config: DiffusionParallelConfig = Field(default_factory=DiffusionParallelConfig) # Cache backend configuration (NEW) cache_backend: str = "none" # "tea_cache", "deep_cache", etc. 
@@ -193,21 +246,6 @@ class OmniDiffusionConfig: trust_remote_code: bool = False revision: str | None = None - # Parallelism - num_gpus: int = 1 - tp_size: int = -1 - sp_degree: int = -1 - # sequence parallelism - ulysses_degree: int | None = None - ring_degree: int | None = None - # data parallelism - # number of data parallelism groups - dp_size: int = 1 - # number of gpu in a dp group - dp_degree: int = 1 - # cfg parallel - enable_cfg_parallel: bool = False - hsdp_replicate_dim: int = 1 hsdp_shard_dim: int = -1 dist_timeout: int | None = None # timeout for torch.distributed From 641b62649d66e3ff4056696837f9fc939d52555a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:52:25 +0800 Subject: [PATCH 02/85] update usp Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 127 ++ vllm_omni/diffusion/distributed/comm.py | 257 ++++ .../distributed/group_coordinator.py | 1082 +++++++++++++++++ .../diffusion/distributed/parallel_state.py | 737 +++++++++++ vllm_omni/diffusion/envs.py | 269 ++++ vllm_omni/diffusion/worker/gpu_worker.py | 17 +- 6 files changed, 2484 insertions(+), 5 deletions(-) create mode 100644 vllm_omni/diffusion/distributed/comm.py create mode 100644 vllm_omni/diffusion/distributed/group_coordinator.py create mode 100644 vllm_omni/diffusion/distributed/parallel_state.py create mode 100644 vllm_omni/diffusion/envs.py diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 85d3da80265..c141a8b7df9 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team & Jiarui Fang +# Adapted from +# 
https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py import torch import torch.nn as nn @@ -9,6 +14,11 @@ ) from vllm_omni.diffusion.attention.selector import get_attn_backend +from typing import Any +from torch import Tensor +from yunchang.kernels import AttnType, select_flash_attn_impl # FIXME: replace it with vllm-omni attention +import torch.distributed as dist +from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D class Attention(nn.Module): def __init__( @@ -41,3 +51,120 @@ def forward( # shape: (batch_size, seq_len, num_heads, head_size) attn_output = self.attention.forward(query, key, value, attn_metadata) return attn_output + + +class UlyssesAttention(torch.nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed. 
+ attn_type (AttnType): attention type enum + """ + + def __init__( + self, + sequence_process_group: dist.ProcessGroup = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + attn_type : AttnType = AttnType.FA, + ) -> None: + + super(UlyssesAttention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.use_sync = use_sync + self.attn_type = attn_type + + try: + import torch_npu + device = torch.device("npu") + except: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + gpu_name = torch.cuda.get_device_name(device) + if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: + self.attn_type = AttnType.TORCH + self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd") + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *args: Any + ) -> Tensor: + """forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ + # TODO Merge three alltoall calls into one + # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
+ # in shape : e.g., [s/p:h:] + # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) + + # scatter 2, gather 1 + q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync) + k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync) + v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + + if self.attn_type is AttnType.NPU: + context_layer = self.attn_fn( + q, + k, + v, + num_heads = q.shape[-2], + input_layout = "BSND", + scale = softmax_scale, + softmax_lse_flag = True, + pre_tokens=65535, + next_tokens=65535, + ) + else: + context_layer = self.attn_fn( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale = softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + + if isinstance(context_layer, tuple): + context_layer = context_layer[0] + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = SeqAllToAll4D.apply( + self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync + ) + + # out e.g., [s/p::h] + return output diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py new file mode 100644 index 00000000000..8d031463768 --- /dev/null +++ b/vllm_omni/diffusion/distributed/comm.py @@ -0,0 +1,257 @@ +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team & Jiarui Fang +# from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/all_to_all.py +import torch + +from typing import Any, Tuple +from torch import Tensor + +import torch.distributed as dist + + +def all_to_all_4D( + input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, 
use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.cuda.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc 
= shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + use_sync: bool = False, + ) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply( + ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync + ), + None, + None, + 
None, + ) + + +def all_to_all_5D( + input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + forward (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs) + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync: whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) + """ + assert ( + input.dim() == 5 + ), f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 3 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs) + bs, shard_seqlen, t_cnt, hc, hs = input.shape + + assert t_cnt == 3 + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
+ # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs) + input_t = ( + input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs) + .transpose(0, 3) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, 3, bs, shard_hc, hs) + + # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs) + output = output.transpose(0, 2).transpose(1, 2).contiguous() + + return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous() + elif scatter_idx == 1 and gather_idx == 3: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, _, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
+ # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs) + .transpose(0, 4) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.cuda.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, 3, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 3).contiguous() + + return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous() + else: + raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3") + + +class SeqAllToAll5D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int = 3, + gather_idx: int = 1, + use_sync: bool = False, + ) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + + return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll5D.apply( + ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync + ), + None, + None, + None, + ) \ No newline at end of file diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py new file 
mode 100644 index 00000000000..00a7b8625e5 --- /dev/null +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -0,0 +1,1082 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch.distributed +from torch.cuda import synchronize +from torch.distributed import Backend, ProcessGroup + +try: + import torch_musa + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + +from vllm_omni.diffusion.envs import envs +if envs._is_npu(): + print("torch.npu synchronize") + from torch.npu import synchronize + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + +env_info = envs.PACKAGES_CHECKER.get_packages_info() + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
+ device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
+ """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = envs.get_device(local_rank) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = 
self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, op=op, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. 
+ torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." 
+ ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
def broadcast_tensor_dict(
    self,
    tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Broadcast a (possibly nested) dictionary of tensors and plain values.

    The metadata list produced by `_split_tensor_dict` is broadcast first;
    every non-empty tensor is then broadcast asynchronously and all handles
    are awaited before returning.

    NOTE: `src` is the local rank of the source rank.

    Args:
        tensor_dict: Dictionary to send. Must be a dict on the source rank;
            ignored on the other ranks.
        src: Index of the source rank inside ``self.ranks``.
        group: Accepted for interface compatibility; it is overwritten with
            ``self.device_group`` below.
        metadata_group: Accepted for interface compatibility; it is
            overwritten with ``self.cpu_group`` below.

    Returns:
        `tensor_dict` unchanged when torch.distributed is not initialized
        or the group has a single member; the input dict on the source
        rank; a freshly built dict on every other rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"
    src = self.ranks[src]

    def _launch(tensor: torch.Tensor):
        # CPU tensors travel over the metadata (CPU) group, all others
        # over the device group.
        chosen_group = metadata_group if tensor.is_cpu else group
        return torch.distributed.broadcast(
            tensor, src=src, group=chosen_group, async_op=True
        )

    pending = []
    if self.rank == src:
        assert isinstance(
            tensor_dict, dict
        ), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # NOTE(review): `src` has already been translated through
        # `self.ranks` here; confirm `broadcast_object` expects that value.
        self.broadcast_object(metadata_list, src=src)
        for tensor in tensor_list:
            # Empty tensors are never put on the wire.
            if tensor.numel() != 0:
                pending.append(_launch(tensor))
    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        for key, value in metadata_list:
            if not isinstance(value, TensorMetadata):
                _update_nested_dict(tensor_dict, key, value)
                continue
            tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
            # Empty tensors are stored as allocated, without a broadcast.
            if tensor.numel() != 0:
                pending.append(_launch(tensor))
            _update_nested_dict(tensor_dict, key, tensor)

    for handle in pending:
        handle.wait()
    return tensor_dict
def recv_tensor_dict(
    self, src: Optional[int] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Receive a tensor dictionary from another rank of this group.

    The metadata list is obtained through `recv_object`; each non-empty
    tensor described by it is then received into a freshly allocated
    buffer.

    NOTE: `src` is the local rank of the source rank.

    Args:
        src: Index of the sending rank inside ``self.ranks``. Defaults to
            ``self.group_prev_rank``.

    Returns:
        ``None`` when torch.distributed is not initialized or the group has
        a single member; otherwise the reconstructed dictionary.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    device_pg = self.device_group
    cpu_pg = self.cpu_group

    if src is None:
        src = self.group_prev_rank
    assert src < self.world_size, f"Invalid src rank ({src})"

    received: Dict[str, Any] = {}
    for key, value in self.recv_object(src=src):
        if not isinstance(value, TensorMetadata):
            _update_nested_dict(received, key, value)
            continue
        tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
        # Empty tensors are stored as allocated, without a matching recv.
        if tensor.numel() != 0:
            # CPU tensors use the CPU group, everything else the device group.
            torch.distributed.recv(
                tensor,
                src=self.ranks[src],
                group=cpu_pg if tensor.is_cpu else device_pg,
            )
        _update_nested_dict(received, key, tensor)
    return received
class PipelineGroupCoordinator(GroupCoordinator):
    """
    Group coordinator for point-to-point traffic between pipeline stages.

    available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    difference between `local_rank` and `rank_in_group`:
    if we have a group of size 4 across two nodes:
    Process | Node | Rank | Local Rank | Rank in Group
      0     |   0  |  0   |     0      |       0
      1     |   0  |  1   |     1      |       1
      2     |   1  |  2   |     0      |       2
      3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication

    NOTE(review): `prev_rank`, `next_rank` and `skip_rank` are read by the
    methods below but never assigned in this class; presumably they are
    provided by `GroupCoordinator` - confirm.
    """

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
    ):
        """Create the process groups this rank needs for pipeline traffic.

        Every rank iterates over all entries of `group_ranks` and calls
        `torch.distributed.new_group` for each of them, keeping only the
        groups it is a member of.

        Args:
            group_ranks: One list of global ranks per pipeline group.
            local_rank: Local rank, passed to `envs.get_device`.
            torch_distributed_backend: Backend for the device groups; the
                CPU groups always use ``gloo``.
        """
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None
        self.cpu_groups = []
        self.device_groups = []
        if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1:
            for ranks in group_ranks:
                device_group = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_group = device_group
                    self.cpu_group = cpu_group
        # when pipeline parallelism is 2, we need to create two groups to avoid
        #   communication stall.
        # *_group_0_1 represents the group for communication from device 0 to
        #   device 1.
        # *_group_1_0 represents the group for communication from device 1 to
        #   device 0.
        elif len(group_ranks[0]) == 2:
            for ranks in group_ranks:
                device_group_0_1 = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                device_group_1_0 = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo")
                cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_groups = [device_group_0_1, device_group_1_0]
                    self.cpu_groups = [cpu_group_0_1, cpu_group_1_0]
                    self.device_group = device_group_0_1
                    self.cpu_group = cpu_group_0_1

        assert self.cpu_group is not None
        assert self.device_group is not None

        self.device = envs.get_device(local_rank)

        self.recv_buffer_set: bool = False
        self.recv_tasks_queue: List[Tuple[str, int]] = []
        self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
        self.dtype: Optional[torch.dtype] = None
        self.num_pipefusion_patches: Optional[int] = None

        self.recv_shape: Dict[str, Dict[int, torch.Size]] = {}
        self.send_shape: Dict[str, Dict[int, torch.Size]] = {}
        # name -> segment index -> receive buffer (tensors, not shapes).
        self.recv_buffer: Dict[str, Dict[int, torch.Tensor]] = {}
        # Filled by `set_extra_tensors_recv_buffer`; it has to exist before
        # that method performs its first item assignment.
        self.extra_tensors_recv_buffer: Dict[str, List[torch.Tensor]] = {}

        self.skip_tensor_recv_buffer_set: bool = False
        self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = []
        self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
        self.skip_tensor_recv_buffer: Optional[
            Union[List[torch.Tensor], torch.Tensor]
        ] = None
        self.skip_device_group = None
        # A separate device group per pipeline group for the skip traffic.
        for ranks in group_ranks:
            skip_device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            if self.rank in ranks:
                self.skip_device_group = skip_device_group
        assert self.skip_device_group is not None

    def reset_buffer(self):
        """Drop queued receive tasks, cached shapes and receive buffers."""
        self.recv_tasks_queue = []
        self.receiving_tasks = []
        self.recv_shape = {}
        self.send_shape = {}
        self.recv_buffer = {}

        self.recv_skip_tasks_queue = []
        self.receiving_skip_tasks = []
        self.skip_tensor_recv_buffer = {}

    def set_config(self, dtype: torch.dtype):
        """Set the dtype used when receive buffers are allocated."""
        self.dtype = dtype

    def set_recv_buffer(
        self,
        num_pipefusion_patches: int,
        patches_shape_list: List[List[int]],
        feature_map_shape: List[int],
        dtype: torch.dtype,
    ):
        """Pre-allocate one zero buffer per patch shape plus one feature map.

        NOTE(review): this replaces `self.recv_buffer` with a list, while
        `_check_shape_and_buffer` and `pipeline_recv` treat it as a dict
        keyed by name; the two usages look mutually exclusive - confirm.
        """
        assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object"
        assert (
            isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1
        ), "num_pipefusion_patches must be greater than or equal to 1"
        self.dtype = dtype
        self.num_pipefusion_patches = num_pipefusion_patches
        self.recv_buffer = [
            torch.zeros(*shape, dtype=self.dtype, device=self.device)
            for shape in patches_shape_list
        ]
        self.recv_buffer.append(
            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
        )
        self.recv_buffer_set = True

    def set_extra_tensors_recv_buffer(
        self,
        name: str,
        shape: List[int],
        num_buffers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """Allocate `num_buffers` zero tensors of `shape` under `name`."""
        self.extra_tensors_recv_buffer[name] = [
            torch.zeros(*shape, dtype=dtype, device=self.device)
            for _ in range(num_buffers)
        ]

    def _check_shape_and_buffer(
        self,
        tensor_send_to_next=None,
        recv_prev=False,
        name: Optional[str] = None,
        segment_idx: int = 0,
    ):
        """Exchange shapes the first time a (name, segment) pair is used.

        A shape is sent only if none is cached in `send_shape` for the
        pair, and received only if none is cached in `recv_shape`. After a
        shape has been received, a matching zero buffer is (re)allocated in
        `recv_buffer`.
        """
        send_flag = False
        name = name or "latent"
        if tensor_send_to_next is not None:
            shape_list = self.send_shape.get(name, None)
            if shape_list is None:
                self.send_shape[name] = {segment_idx: tensor_send_to_next.shape}
                send_flag = True
            elif shape_list.get(segment_idx, None) is None:
                self.send_shape[name][segment_idx] = tensor_send_to_next.shape
                send_flag = True

        recv_flag = False
        if recv_prev:
            shape_list = self.recv_shape.get(name, None)
            if shape_list is None:
                recv_flag = True
            elif shape_list.get(segment_idx, None) is None:
                recv_flag = True

        recv_prev_shape = self._communicate_shapes(
            tensor_send_to_next=tensor_send_to_next if send_flag else None,
            recv_prev=recv_flag,
        )

        if recv_flag:
            if self.recv_shape.get(name, None) is None:
                self.recv_shape[name] = {segment_idx: recv_prev_shape}
            else:
                self.recv_shape[name][segment_idx] = recv_prev_shape

            if self.recv_buffer.get(name, None) is None:
                self.recv_buffer[name] = {
                    segment_idx: torch.zeros(
                        recv_prev_shape, device=self.device, dtype=self.dtype
                    )
                }
            else:
                if self.recv_buffer[name].get(segment_idx, None) is not None:
                    logger.warning(
                        f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..."
                    )
                self.recv_buffer[name][segment_idx] = torch.zeros(
                    recv_prev_shape, device=self.device, dtype=self.dtype
                )

    def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False):
        """Communicate tensor shapes between stages. Used to communicate
        tensor shapes before the actual tensor communication happens.

        The exchange takes two rounds: the number of dimensions first, then
        the shape itself.

        Args:
            tensor_send_to_next: tensor to send to next rank (no tensor sent
                if set to None).
            recv_prev: boolean for whether tensor should be received from
                previous rank.

        Returns:
            The received shape as a `torch.Size`, or ``torch.Size([0, 0, 0])``
            when nothing was received.
        """

        ops = []
        if recv_prev:
            recv_prev_dim_tensor = torch.empty(
                (1), device=self.device, dtype=torch.int64
            )
            recv_prev_dim_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_dim_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_dim_op)

        if tensor_send_to_next is not None:
            send_next_dim_tensor = torch.tensor(
                tensor_send_to_next.dim(), device=self.device, dtype=torch.int64
            )
            send_next_dim_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_dim_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_dim_op)

        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against race condition when using batch_isend_irecv().
        # should take this out once the bug with batch_isend_irecv is resolved.
        synchronize()

        ops = []
        recv_prev_shape_tensor = None
        if recv_prev:
            recv_prev_shape_tensor = torch.empty(
                torch.Size(recv_prev_dim_tensor), device=self.device, dtype=torch.int64
            )
            recv_prev_shape_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_shape_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_shape_op)

        if tensor_send_to_next is not None:
            send_next_shape_tensor = torch.tensor(
                tensor_send_to_next.size(), device=self.device, dtype=torch.int64
            )
            send_next_shape_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_shape_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_shape_op)

        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        synchronize()

        recv_prev_shape = [0, 0, 0]
        if recv_prev_shape_tensor is not None:
            recv_prev_shape = recv_prev_shape_tensor
        return torch.Size(recv_prev_shape)

    def pipeline_send(
        self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
    ) -> None:
        """Send `tensor` to the next stage and wait for completion."""
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(
            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
        )
        self._pipeline_isend(tensor).wait()

    def pipeline_isend(
        self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
    ) -> None:
        """Start sending `tensor` to the next stage without waiting.

        The work handle is discarded, so callers cannot wait on it.
        """
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(
            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
        )
        self._pipeline_isend(tensor)

    def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor:
        """Receive into the buffer for (`name`, `idx`) and return it."""
        name = name or "latent"
        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
        self._pipeline_irecv(self.recv_buffer[name][idx]).wait()
        return self.recv_buffer[name][idx]

    def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"):
        """Queue a receive for (`name`, `idx`); `recv_next` starts it."""
        name = name or "latent"
        self.recv_tasks_queue.append((name, idx))

    def recv_next(self):
        """Start the oldest queued receive.

        Raises:
            ValueError: If no receive task is queued.
        """
        if not self.recv_tasks_queue:
            raise ValueError("No more tasks to receive")
        name, idx = self.recv_tasks_queue.pop(0)
        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
        self.receiving_tasks.append(
            (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)
        )

    def get_pipeline_recv_data(
        self, idx: int = -1, name: str = "latent"
    ) -> torch.Tensor:
        """Wait for the oldest in-flight receive and return its buffer.

        The in-flight task must be the one for (`name`, `idx`).
        """
        assert (
            len(self.receiving_tasks) > 0
        ), "No tasks to receive, call add_pipeline_recv_task first"
        receiving_task = self.receiving_tasks.pop(0)
        receiving_task[0].wait()
        assert (
            receiving_task[1] == name and receiving_task[2] == idx
        ), "Received tensor does not match the requested"
        return self.recv_buffer[name][idx]

    def _pipeline_irecv(self, tensor: torch.Tensor):
        """Start a receive from `prev_rank`.

        With exactly two stages the device group is chosen by
        ``(rank_in_group + 1) % 2``, the counterpart of the index used in
        `_pipeline_isend`.
        """
        return torch.distributed.irecv(
            tensor,
            src=self.prev_rank,
            group=(
                self.device_groups[(self.rank_in_group + 1) % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )

    def _pipeline_isend(self, tensor: torch.Tensor):
        """Start a send to `next_rank`.

        With exactly two stages the device group is chosen by
        ``rank_in_group % 2``.
        """
        return torch.distributed.isend(
            tensor,
            dst=self.next_rank,
            group=(
                self.device_groups[self.rank_in_group % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )

    def set_skip_tensor_recv_buffer(
        self,
        patches_shape_list: List[List[int]],
        feature_map_shape: List[int],
    ):
        """Pre-allocate the skip receive buffers (patches plus feature map)."""
        self.skip_tensor_recv_buffer = [
            torch.zeros(*shape, dtype=self.dtype, device=self.device)
            for shape in patches_shape_list
        ]
        self.skip_tensor_recv_buffer.append(
            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
        )
        self.skip_tensor_recv_buffer_set = True

    def pipeline_send_skip(self, tensor: torch.Tensor) -> None:
        """Send `tensor` to `skip_rank` and wait for completion."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor).wait()

    def pipeline_isend_skip(self, tensor: torch.Tensor) -> None:
        """Start sending `tensor` to `skip_rank` without waiting."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor)

    def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor:
        """Receive into skip buffer `idx` and return it."""
        self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait()
        return self.skip_tensor_recv_buffer[idx]

    def add_pipeline_recv_skip_task(self, idx: int = -1):
        """Queue a skip receive for buffer `idx`."""
        self.recv_skip_tasks_queue.append(idx)

    def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor:
        """Wait for the oldest in-flight skip receive and return its buffer."""
        assert (
            len(self.receiving_skip_tasks) > 0
        ), "No tasks to receive, call add_pipeline_recv_skip_task first"
        receiving_skip_task = self.receiving_skip_tasks.pop(0)
        receiving_skip_task[0].wait()
        assert (
            receiving_skip_task[2] == idx
        ), "Received tensor does not match the requested"
        return self.skip_tensor_recv_buffer[idx]

    def recv_skip_next(self):
        """Start the oldest queued skip receive.

        Raises:
            ValueError: If no skip receive task is queued.
        """
        if not self.recv_skip_tasks_queue:
            raise ValueError("No more tasks to receive")
        idx = self.recv_skip_tasks_queue.pop(0)
        self.receiving_skip_tasks.append(
            (
                self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]),
                None,
                idx,
            )
        )

    def _pipeline_irecv_skip(self, tensor: torch.Tensor):
        """Start a receive from `skip_rank` on the skip device group."""
        return torch.distributed.irecv(
            tensor, src=self.skip_rank, group=self.skip_device_group
        )

    def _pipeline_isend_skip(self, tensor: torch.Tensor):
        """Start a send to `skip_rank` on the skip device group."""
        return torch.distributed.isend(
            tensor, dst=self.skip_rank, group=self.skip_device_group
        )
class SequenceParallelGroupCoordinator(GroupCoordinator):
    """Group coordinator that also records the ulysses and ring sub-groups."""

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        **kwargs,
    ):
        """Initialise the sequence parallel group.

        Args:
            group_ranks: One list of global ranks per group.
            local_rank: Local rank of this process.
            torch_distributed_backend: Backend forwarded to the base class.
            **kwargs: Must contain `ulysses_group` and `ring_group`, the
                process groups of the two sequence parallel flavours.

        Raises:
            RuntimeError: If `ulysses_group` or `ring_group` is missing or
                ``None``.
        """
        ulysses_group = kwargs.get("ulysses_group", None)
        ring_group = kwargs.get("ring_group", None)
        # Validate before the base class is initialised so that a
        # misconfigured call fails without side effects.
        if ulysses_group is None:
            raise RuntimeError(
                "Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator"
            )
        if ring_group is None:
            raise RuntimeError(
                "Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator"
            )

        super().__init__(
            group_ranks=group_ranks,
            local_rank=local_rank,
            torch_distributed_backend=torch_distributed_backend,
        )

        self.ulysses_group = ulysses_group
        self.ring_group = ring_group

        self.ulysses_world_size = torch.distributed.get_world_size(
            self.ulysses_group
        )
        self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group)
        self.ring_world_size = torch.distributed.get_world_size(self.ring_group)
        self.ring_rank = torch.distributed.get_rank(self.ring_group)
def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    r"""Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]):
            The parallel size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_parallel_size = 3, data_parallel_size = 4,
            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

        mask (List[bool]):
            The mask controls which parallel methods the generated groups represent. If mask[i] is
            True, it means the generated group contains the i-th parallelism method. For example,
            if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False, True], then
            the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
            generated group is the `pp` group.

    Returns:
        One list of global ranks per group; there are
        ``world_size // group_size`` groups, where ``group_size`` is the
        product of the masked sizes.

    Algorithm:
        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
        local_rank satisfy the following equation:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
                tp_rank \in [0, tp_size)
                dp_rank \in [0, dp_size)
                pp_rank \in [0, pp_size)

        If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
        For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
        dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
        The tp_rank and pp_rank will be combined to form the `dp_group_index`.
            dp_group_index = tp_rank + pp_rank * tp_size (2)

        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy
        equation (1).

        This function solves this math problem.

    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
    and the mask = [False, True, False]. Then,
        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
        ...
        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
        ...
        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
        # [init, init*a[0], init*a[0]*a[1], ...]; one element longer than `a`.
        r = [init]
        for v in a:
            init = init * v
            r.append(init)
        return r

    def inner_product(a: List[int], b: List[int]) -> int:
        return sum(x * y for x, y in zip(a, b))

    def decompose(index, shape, stride=None):
        """
        This function solves the math problem below:
            There is an equation:
                index = sum(idx[i] * stride[i])
            And given the value of index, stride.
            Return the idx.
        It is used to get the tp/dp/pp_rank from group_index and
        rank_in_group.
        """
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result. And the value of stride[-1]
        # is not used.
        assert (
            sum(x * y for x, y in zip(idx, stride[:-1])) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    # zip() stops at len(mask), so the trailing total product is dropped.
    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    ranks = []
    for group_index in range(num_of_group):
        # get indices from unmasked for group_index.
        decomposed_group_idx = decompose(group_index, unmasked_shape)
        group_offset = inner_product(decomposed_group_idx, unmasked_stride)
        rank = []
        for rank_in_group in range(group_size):
            # get indices from masked for rank_in_group.
            decomposed_rank_idx = decompose(rank_in_group, masked_shape)
            rank.append(inner_product(decomposed_rank_idx, masked_stride) + group_offset)
        ranks.append(rank)
    return ranks
class RankGenerator(object):
    """Build rank groups for tp/sp/pp/cfg/dp from their sizes and an order.

    `order` is a hyphen-separated list of the names ``tp``, ``sp``, ``pp``,
    ``cfg`` and ``dp``. Names missing from `order` are appended to it when
    their size is 1.
    """

    def __init__(
        self,
        tp: int,
        sp: int,
        pp: int,
        cfg: int,
        dp: int,
        order: str,
        rank_offset: int = 0,
    ) -> None:
        """Store the sizes and derive the completed order.

        Args:
            tp: Tensor parallel size.
            sp: Sequence parallel size.
            pp: Pipeline parallel size.
            cfg: Classifier-free-guidance parallel size.
            dp: Data parallel size.
            order: Hyphen-separated parallel order (case-insensitive).
            rank_offset: Value added to every generated rank when positive.

        Raises:
            RuntimeError: If a parallel type with size != 1 does not appear
                in `order`.
        """
        self.tp = tp
        self.sp = sp
        self.pp = pp
        self.cfg = cfg
        self.dp = dp
        self.rank_offset = rank_offset
        self.world_size = tp * sp * pp * cfg * dp

        self.name_to_size = {
            "tp": self.tp,
            "sp": self.sp,
            "pp": self.pp,
            "cfg": self.cfg,
            "dp": self.dp,
        }
        order = order.lower()

        for name, size in self.name_to_size.items():
            if name in order:
                continue
            if size != 1:
                # `self.order` is not assigned yet at this point, so the
                # message has to be built from the local `order`.
                raise RuntimeError(
                    f"The size of ({name}) is ({size}), but you haven't specified the order ({order})."
                )
            order = order + "-" + name

        self.order = order
        self.ordered_size = [self.name_to_size[token] for token in order.split("-")]

    def get_mask(self, order: str, token: str):
        """Return a boolean mask over `order` selecting the names in `token`.

        Both arguments are hyphen-separated name lists.
        """
        ordered_token = order.split("-")
        mask = [False] * len(ordered_token)
        for t in token.split("-"):
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token):
        """Get rank group by input token.

        Arguments:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

        Returns:
            One list of ranks per group, each shifted by `rank_offset` when
            that offset is positive.
        """
        mask = self.get_mask(self.order, token)
        ranks = generate_masked_orthogonal_rank_groups(
            self.world_size, self.ordered_size, mask
        )
        if self.rank_offset > 0:
            ranks = [[r + self.rank_offset for r in group] for group in ranks]
        return ranks
def get_sp_group() -> SequenceParallelGroupCoordinator:
    """Return the sequence parallel group coordinator.

    Raises:
        AssertionError: If the module-level `_SP` is still ``None``.
    """
    assert _SP is not None, "sequence parallel group is not initialized"
    return _SP
def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group coordinator.

    Raises:
        AssertionError: If the module-level `_DP` is still ``None``.
    """
    assert _DP is not None, "data parallel group is not initialized"
    return _DP
get_vae_parallel_group() -> GroupCoordinator: + assert _VAE is not None, "VAE parallel group is not initialized" + return _VAE + + +def get_vae_parallel_world_size(): + """Return world size for the VAE parallel group.""" + return get_vae_parallel_group().world_size + + +def get_vae_parallel_rank(): + """Return my rank for the VAE parallel group.""" + return get_vae_parallel_group().rank_in_group + + +# * SET + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: Optional[str] = None, +): + if backend is None: + backend = envs.get_torch_distributed_backend() + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + set_device(torch.distributed.get_rank() % device_count()) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, 
backend) + else: + assert ( + _WORLD.world_size == torch.distributed.get_world_size() + ), "world group already initialized with a different world size" + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _DP is not None + and _CFG is not None + and _SP is not None + and _PP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "data", + "pipeline", + "tensor", + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "pipeline": + return PipelineGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + elif parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_dit_group( + dit_parallel_size: int, + backend: str, +): + global _DIT + _DIT = torch.distributed.new_group( + ranks=list(range(dit_parallel_size)), backend=backend + ) + + +def get_dit_group(): + assert _DIT is not None, "DIT group is not initialized" + return _DIT + + +def init_vae_group( + dit_parallel_size: int, + vae_parallel_size: int, + backend: str, +): + # Initialize VAE group first + global _VAE + assert _VAE is None, "VAE parallel group is already initialized" + vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size)) + _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend) + + +def initialize_model_parallel( + data_parallel_degree: int = 1, + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: Optional[int] = None, + 
ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + pipeline_parallel_degree: int = 1, + vae_parallel_size: int = 0, + backend: Optional[str] = None, +) -> None: + if backend is None: + backend = envs.get_torch_distributed_backend() + """ + Initialize model parallel groups. + + Arguments: + data_parallel_degree: number of data parallelism groups. + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. sequence_parallel_degree = ulysses_degree * ring_degree + ulysses_degree: number of GPUs used for ulysses sequence parallelism. + ring_degree: number of GPUs used for ring sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + pipeline_parallel_degree: number of GPUs used for pipeline parallelism. + backend: distributed backend of pytorch collective comm. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize + splited batch caused by CFG, and 2 GPUs to parallelize sequence. + + dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16. + + The present function will create 8 data-parallel groups, + 8 CFG group, 8 pipeline-parallel group, and + 8 sequence-parallel groups: + 8 data-parallel groups: + [g0, g8], [g1, g9], [g2, g10], [g3, g11], + [g4, g12], [g5, g13], [g6, g14], [g7, g15] + 8 CFG-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7], + [g8, g12], [g9, g13], [g10, g14], [g11, g15] + 8 sequence-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], + [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 8 pipeline-parallel groups: + [g0, g2], [g4, g6], [g8, g10], [g12, g14], + [g1, g3], [g5, g7], [g9, g11], [g13, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. 
For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if sequence_parallel_degree is None: + sequence_parallel_degree = ring_degree * ulysses_degree + logger.info( + f"sequence_parallel_degree is not provided, using ring_degree * ulysses_degree = {sequence_parallel_degree}" + ) + + if sequence_parallel_degree != ring_degree * ulysses_degree: + raise ValueError( + f"sequence_parallel_degree is not equal to ring_degree * ulysses_degree, {sequence_parallel_degree} != {ring_degree} * {ulysses_degree}" + ) + + # FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch, + # the pipefusion is not ready for npu yet + if envs._is_npu(): + assert pipeline_parallel_degree == 1, "Current pipefusion is not ready for NPU" + + dit_parallel_size = ( + data_parallel_degree + * classifier_free_guidance_degree + * sequence_parallel_degree + * pipeline_parallel_degree + * tensor_parallel_degree + ) + + if world_size < dit_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than " + f"tensor_parallel_degree ({tensor_parallel_degree}) x " + f"pipeline_parallel_degree ({pipeline_parallel_degree}) x" + f"sequence_parallel_degree ({sequence_parallel_degree}) x" + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x" + f"data_parallel_degree ({data_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + pipeline_parallel_degree, + classifier_free_guidance_degree, + data_parallel_degree, + "tp-sp-pp-cfg-dp", + ) + global _DP + assert _DP is None, "data parallel group is already initialized" + _DP = 
init_model_parallel_group( + group_ranks=rank_generator.get_ranks("dp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="data", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + _PP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("pp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="pipeline", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + if vae_parallel_size > 0: + init_vae_group(dit_parallel_size, vae_parallel_size, backend) + init_dit_group(dit_parallel_size, backend) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _DP + if _DP: + _DP.destroy() + _DP = None + + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + global _VAE + if _VAE: + _VAE.destroy() + _VAE = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git 
a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py new file mode 100644 index 00000000000..9e23d2fec51 --- /dev/null +++ b/vllm_omni/diffusion/envs.py @@ -0,0 +1,269 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/xdit-project/xDiT/blob/main/xfuser/envs.py +import os +import torch +import diffusers +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from packaging import version + +try: + import torch_musa +except ModuleNotFoundError: + pass + +from xfuser.logger import init_logger + +logger = init_logger(__name__) + +if TYPE_CHECKING: + MASTER_ADDR: str = "" + MASTER_PORT: Optional[int] = None + CUDA_HOME: Optional[str] = None + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + CUDA_VERSION: version.Version + TORCH_VERSION: version.Version + + +environment_variables: Dict[str, Callable[[], Any]] = { + # ================== Runtime Env Vars ================== + # used in distributed environment to determine the master address + "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""), + # used in distributed environment to manually set the communication port + "MASTER_PORT": lambda: ( + int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None + ), + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. 
+ "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), +} + + +def _is_hip(): + has_rocm = torch.version.hip is not None + return has_rocm + + +def _is_cuda(): + has_cuda = torch.version.cuda is not None + return has_cuda + + +def _is_musa(): + try: + if hasattr(torch, "musa") and torch.musa.is_available(): + return True + except ModuleNotFoundError: + return False + + +def _is_mps(): + return torch.backends.mps.is_available() + + +def _is_npu(): + try: + if hasattr(torch, "npu") and torch.npu.is_available(): + return True + except ModuleNotFoundError: + return False + + +def get_device(local_rank: int) -> torch.device: + if _is_cuda() or _is_hip(): + return torch.device("cuda", local_rank) + elif _is_musa(): + return torch.device("musa", local_rank) + elif _is_mps(): + return torch.device("mps") + elif _is_npu(): + return torch.device("npu", local_rank) + else: + return torch.device("cpu") + + +def get_device_name() -> str: + if _is_cuda() or _is_hip(): + return "cuda" + elif _is_musa(): + return "musa" + elif _is_mps(): + return "mps" + elif _is_npu(): + return "npu" + else: + return "cpu" + + +def get_device_version(): + if _is_hip(): + hip_version = torch.version.hip + hip_version = hip_version.split("-")[0] + return hip_version + elif _is_cuda(): + return torch.version.cuda + elif _is_musa(): + return torch.version.musa + elif _is_mps(): + return None + elif _is_npu(): + return None + else: + raise NotImplementedError( + "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available" + ) + + +def get_torch_distributed_backend() -> str: + if _is_cuda() or _is_hip(): + return "nccl" + elif _is_musa(): + return "mccl" + elif _is_mps(): + return 
"gloo" + elif _is_npu(): + return "hccl" + else: + raise NotImplementedError( + "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available" + ) + + +variables: Dict[str, Callable[[], Any]] = { + # ================== Other Vars ================== + # used in version checking + "CUDA_VERSION": lambda: version.parse(get_device_version() or "0.0"), + "TORCH_VERSION": lambda: version.parse( + version.parse(torch.__version__).base_version + ), +} + + +def _setup_musa(environment_variables, variables): + musa = getattr(torch, "musa", None) + if musa is None: + return + try: + if musa.is_available(): + environment_variables["MUSA_HOME"] = lambda: os.environ.get( + "MUSA_HOME", None + ) + environment_variables["MUSA_VISIBLE_DEVICES"] = lambda: os.environ.get( + "MUSA_VISIBLE_DEVICES", None + ) + musa_ver = getattr(getattr(torch, "version", None), "musa", None) + if musa_ver: + variables["MUSA_VERSION"] = lambda: version.parse(musa_ver) + except Exception: + pass + + +try: + _setup_musa(environment_variables, variables) +except (AttributeError, ModuleNotFoundError): + pass + + +class PackagesEnvChecker: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(PackagesEnvChecker, cls).__new__(cls) + cls._instance.initialize() + return cls._instance + + def initialize(self): + packages_info = {} + packages_info["has_aiter"] = self.check_aiter() + packages_info["has_flash_attn"] = self.check_flash_attn(packages_info) + packages_info["diffusers_version"] = self.check_diffusers_version() + self.packages_info = packages_info + + def check_aiter(self): + """ + Checks whether ROCm AITER library is installed + """ + try: + import aiter + logger.info("Using AITER as the attention library") + return True + except: + if _is_hip(): + logger.warning( + f'Using AMD GPUs, but library "aiter" is not installed, ' + 'defaulting to other attention mechanisms' + ) + return False + + + def check_flash_attn(self, packages_info): + if not 
torch.cuda.is_available(): + return False + + # Check if torch_npu is available + if _is_npu(): + logger.info("falsh_attn is not ready on torch_npu for now") + return False + + if _is_musa(): + logger.info( + "Flash Attention library is not supported on MUSA for the moment." + ) + return False + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + gpu_name = torch.cuda.get_device_name(device) + if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: + return False + else: + from flash_attn import flash_attn_func + from flash_attn import __version__ + + if __version__ < "2.6.0": + raise ImportError(f"install flash_attn >= 2.6.0") + return True + except ImportError: + if not packages_info.get("has_aiter", False): + logger.warning( + f'Flash Attention library "flash_attn" not found, ' + f"using pytorch attention implementation" + ) + return False + + + + def check_diffusers_version(self): + if version.parse( + version.parse(diffusers.__version__).base_version + ) < version.parse("0.30.0"): + raise RuntimeError( + f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," + f"please upgrade to version > 0.30.0" + ) + return version.parse(version.parse(diffusers.__version__).base_version) + + def get_packages_info(self): + return self.packages_info + + +PACKAGES_CHECKER = PackagesEnvChecker() + + +def __getattr__(name): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + if name in variables: + return variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 86861184149..e20c4f92e92 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -8,10 +8,8 @@ import zmq from 
vllm.config import LoadConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue -from vllm.distributed.parallel_state import ( - init_distributed_environment, - initialize_model_parallel, -) +from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel + from vllm.logger import init_logger from vllm.utils import DeviceMemoryProfiler, GiB_bytes @@ -65,8 +63,17 @@ def init_device_and_model(self) -> None: set_current_vllm_config(vllm_config) init_distributed_environment(world_size=world_size, rank=rank) - initialize_model_parallel(tensor_model_parallel_size=world_size) logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_degree=parallel_config.data_parallel_size, + classifier_free_guidance_degree=parallel_config.cfg_parallel_size, + sequence_parallel_degree=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tensor_parallel_size, + pipeline_parallel_degree=parallel_config.pipeline_parallel_size, + ) load_config = LoadConfig() model_loader = DiffusersPipelineLoader(load_config) From 15de6d01d4141b63106b9da003906aa17386a925 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:35:06 +0800 Subject: [PATCH 03/85] update usp Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 39 +++++--------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index c141a8b7df9..c4b28dfa873 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -13,10 +13,10 @@ 
AttentionMetadata, ) from vllm_omni.diffusion.attention.selector import get_attn_backend - +from vllm_omni.utils.platform_utils import is_npu from typing import Any from torch import Tensor -from yunchang.kernels import AttnType, select_flash_attn_impl # FIXME: replace it with vllm-omni attention + import torch.distributed as dist from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D @@ -71,7 +71,6 @@ def __init__( scatter_idx: int = 2, gather_idx: int = 1, use_sync: bool = False, - attn_type : AttnType = AttnType.FA, ) -> None: super(UlyssesAttention, self).__init__() @@ -79,31 +78,14 @@ def __init__( self.scatter_idx = scatter_idx self.gather_idx = gather_idx self.use_sync = use_sync - self.attn_type = attn_type - - try: - import torch_npu - device = torch.device("npu") - except: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - gpu_name = torch.cuda.get_device_name(device) - if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: - self.attn_type = AttnType.TORCH - self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd") def forward( self, + attn: Attention, query: Tensor, key: Tensor, value: Tensor, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, + attn_metadata: AttentionMetadata = None, *args: Any ) -> Tensor: """forward @@ -130,8 +112,8 @@ def forward( if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 - if self.attn_type is AttnType.NPU: - context_layer = self.attn_fn( + if is_npu(): + context_layer = attn( q, k, v, @@ -147,14 +129,7 @@ def forward( q, k, v, - dropout_p=dropout_p, - softmax_scale = softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, + attn_metadata=attn_metadata, ) if isinstance(context_layer, tuple): From 
5e4f8654afc224230202af7936160b72b6d9aafb Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:37:32 +0800 Subject: [PATCH 04/85] set omni diffusion config Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 1 + vllm_omni/diffusion/data.py | 46 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index c4b28dfa873..6049b2dcd0a 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -14,6 +14,7 @@ ) from vllm_omni.diffusion.attention.selector import get_attn_backend from vllm_omni.utils.platform_utils import is_npu +from vllm_omni.diffusion.data import get_current_omni_diffusion_config from typing import Any from torch import Tensor diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 36bf24b60a8..309c5d41882 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -9,6 +9,8 @@ from pydantic import Field, model_validator from typing import Any +from contextlib import contextmanager +from functools import lru_cache from typing_extensions import Self import torch from vllm.logger import init_logger @@ -384,6 +386,50 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none" return cls(**kwargs) +_current_omni_diffusion_config: OmniDiffusionConfig | None = None +_current_prefix: str | None = None +@contextmanager +def set_current_vllm_config( + omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None +): + """ + Temporarily set the current vLLM config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. 
+ """ + global _current_omni_diffusion_config, _current_prefix + old_omni_diffusion_config = _current_omni_diffusion_config + old_prefix = _current_prefix + from vllm.compilation.counter import compilation_counter + + num_models_seen = compilation_counter.num_models_seen + try: + _current_omni_diffusion_config = omni_diffusion_config + _current_prefix = prefix + yield + except Exception: + raise + else: + if check_compile: + raise RuntimeError("Compilation is not yet supported for OmniDiffusion") + finally: + _current_omni_diffusion_config = old_omni_diffusion_config + _current_prefix = old_prefix + # Clear the compilation config cache when context changes + get_cached_compilation_config.cache_clear() + +@lru_cache(maxsize=1) +def get_cached_compilation_config(): + """Cache config to avoid repeated calls to get_current_omni_diffusion_config()""" + return get_current_omni_diffusion_config().compilation_config + +def get_current_omni_diffusion_config() -> OmniDiffusionConfig: + if _current_omni_diffusion_config is None: + logger.warning("Current OmniDiffusionConfig is not set.") + return OmniDiffusionConfig() + return _current_omni_diffusion_config @dataclass class DiffusionOutput: From c104832afbfb29063a3980d08f6011969574811c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:39:50 +0800 Subject: [PATCH 05/85] ulysses Attention Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 119 ++++++++++++------------- 1 file changed, 56 insertions(+), 63 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 6049b2dcd0a..6cc159ab342 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -15,11 +15,12 @@ from vllm_omni.diffusion.attention.selector import get_attn_backend from vllm_omni.utils.platform_utils import is_npu from vllm_omni.diffusion.data import 
get_current_omni_diffusion_config -from typing import Any +from typing import Optional from torch import Tensor import torch.distributed as dist from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D +from vllm_omni.diffusion.distributed.parallel_state import get_sp_group, get_sequence_parallel_world_size class Attention(nn.Module): def __init__( @@ -30,6 +31,10 @@ def __init__( softmax_scale: float, num_kv_heads: int | None = None, prefix: str = "", + # ulysses attention + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, ): super().__init__() self.attn_backend = get_attn_backend(-1) @@ -42,6 +47,26 @@ def __init__( num_kv_heads=num_kv_heads, ) + self.softmax_scale = softmax_scale + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.use_sync = use_sync + self.sequence_process_group: Optional[dist.ProcessGroup] = None + config = get_current_omni_diffusion_config() + + try: + if config.parallel_config.ulysses_degree > 1: + self.use_ulysses = True + # Get sequence parallel process group + try: + sp_group = get_sp_group() + self.sequence_process_group = sp_group.device_group + assert get_sequence_parallel_world_size() > 1, "Sequence parallel world size must be > 1" + except (AssertionError, RuntimeError): + self.use_ulysses = False + except Exception: + self.use_ulysses = False + def forward( self, query: torch.Tensor, @@ -49,84 +74,51 @@ def forward( value: torch.Tensor, attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: - # shape: (batch_size, seq_len, num_heads, head_size) - attn_output = self.attention.forward(query, key, value, attn_metadata) - return attn_output - - -class UlyssesAttention(torch.nn.Module): - """Initialization. 
- - Arguments: - local_attention (Module): local attention with q,k,v - sequence_process_group (ProcessGroup): sequence parallel process group - scatter_idx (int): scatter_idx for all2all comm - gather_idx (int): gather_idx for all2all comm - use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed. - attn_type (AttnType): attention type enum - """ - - def __init__( - self, - sequence_process_group: dist.ProcessGroup = None, - scatter_idx: int = 2, - gather_idx: int = 1, - use_sync: bool = False, - ) -> None: - - super(UlyssesAttention, self).__init__() - self.spg = sequence_process_group - self.scatter_idx = scatter_idx - self.gather_idx = gather_idx - self.use_sync = use_sync - - def forward( + if self.use_ulysses: + return self._forward_ulysses(query, key, value, attn_metadata) + else: + # shape: (batch_size, seq_len, num_heads, head_size) + attn_output = self.attention.forward(query, key, value, attn_metadata) + return attn_output + + def _forward_ulysses( self, - attn: Attention, query: Tensor, key: Tensor, value: Tensor, attn_metadata: AttentionMetadata = None, - *args: Any ) -> Tensor: - """forward - - Arguments: - query (Tensor): query input to the layer - key (Tensor): key input to the layer - value (Tensor): value input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ - # TODO Merge three alltoall calls into one - # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
- # in shape : e.g., [s/p:h:] - # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) - + """Ulysses attention forward pass with sequence parallelism.""" # scatter 2, gather 1 - q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync) - k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync) - v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync) + # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) + q = SeqAllToAll4D.apply( + self.sequence_process_group, query, self.scatter_idx, self.gather_idx, self.use_sync + ) + k = SeqAllToAll4D.apply( + self.sequence_process_group, key, self.scatter_idx, self.gather_idx, self.use_sync + ) + v = SeqAllToAll4D.apply( + self.sequence_process_group, value, self.scatter_idx, self.gather_idx, self.use_sync + ) + softmax_scale = self.softmax_scale if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 if is_npu(): - context_layer = attn( + context_layer = self.attention( q, k, v, - num_heads = q.shape[-2], - input_layout = "BSND", - scale = softmax_scale, - softmax_lse_flag = True, - pre_tokens=65535, + num_heads=q.shape[-2], + input_layout="BSND", + scale=softmax_scale, + softmax_lse_flag=True, + pre_tokens=65535, next_tokens=65535, ) else: - context_layer = self.attn_fn( + context_layer = self.attention.forward( q, k, v, @@ -139,8 +131,9 @@ def forward( # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) # scatter 1, gather 2 output = SeqAllToAll4D.apply( - self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync + self.sequence_process_group, context_layer, self.gather_idx, self.scatter_idx, self.use_sync ) - # out e.g., [s/p::h] return output + + From 2657aad4c292df4e883259688d721573d69b1aef Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:55:16 +0800 Subject: 
[PATCH 06/85] impr Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 31 +- vllm_omni/diffusion/data.py | 26 +- vllm_omni/diffusion/distributed/comm.py | 41 +-- .../distributed/group_coordinator.py | 290 +++++------------- .../diffusion/distributed/parallel_state.py | 66 ++-- vllm_omni/diffusion/envs.py | 56 +--- vllm_omni/diffusion/worker/gpu_worker.py | 3 +- 7 files changed, 161 insertions(+), 352 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 6cc159ab342..b9bb0882c5c 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -6,21 +6,22 @@ # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py +from typing import Optional + import torch +import torch.distributed as dist import torch.nn as nn +from torch import Tensor from vllm_omni.diffusion.attention.backends.abstract import ( AttentionMetadata, ) from vllm_omni.diffusion.attention.selector import get_attn_backend -from vllm_omni.utils.platform_utils import is_npu from vllm_omni.diffusion.data import get_current_omni_diffusion_config -from typing import Optional -from torch import Tensor - -import torch.distributed as dist from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D -from vllm_omni.diffusion.distributed.parallel_state import get_sp_group, get_sequence_parallel_world_size +from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size, get_sp_group +from vllm_omni.utils.platform_utils import is_npu + class Attention(nn.Module): def __init__( @@ -53,7 +54,7 @@ def __init__( self.use_sync = use_sync self.sequence_process_group: Optional[dist.ProcessGroup] = None config = get_current_omni_diffusion_config() - + try: if config.parallel_config.ulysses_degree > 1: self.use_ulysses = True @@ -80,7 +81,7 @@ def forward( # shape: (batch_size, seq_len, 
num_heads, head_size) attn_output = self.attention.forward(query, key, value, attn_metadata) return attn_output - + def _forward_ulysses( self, query: Tensor, @@ -91,15 +92,9 @@ def _forward_ulysses( """Ulysses attention forward pass with sequence parallelism.""" # scatter 2, gather 1 # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) - q = SeqAllToAll4D.apply( - self.sequence_process_group, query, self.scatter_idx, self.gather_idx, self.use_sync - ) - k = SeqAllToAll4D.apply( - self.sequence_process_group, key, self.scatter_idx, self.gather_idx, self.use_sync - ) - v = SeqAllToAll4D.apply( - self.sequence_process_group, value, self.scatter_idx, self.gather_idx, self.use_sync - ) + q = SeqAllToAll4D.apply(self.sequence_process_group, query, self.scatter_idx, self.gather_idx, self.use_sync) + k = SeqAllToAll4D.apply(self.sequence_process_group, key, self.scatter_idx, self.gather_idx, self.use_sync) + v = SeqAllToAll4D.apply(self.sequence_process_group, value, self.scatter_idx, self.gather_idx, self.use_sync) softmax_scale = self.softmax_scale if softmax_scale is None: @@ -135,5 +130,3 @@ def _forward_ulysses( ) return output - - diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 309c5d41882..4bcaf249393 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -11,15 +11,18 @@ from typing import Any from contextlib import contextmanager from functools import lru_cache -from typing_extensions import Self + import torch -from vllm.logger import init_logger +from pydantic import Field, model_validator +from typing_extensions import Self from vllm.config.utils import config +from vllm.logger import init_logger from vllm_omni.diffusion.utils.network_utils import is_port_available logger = init_logger(__name__) + @config @dataclass class DiffusionParallelConfig: @@ -56,11 +59,13 @@ def _validate_parallel_config(self) -> Self: assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert 
self.ring_degree > 0, "Ring degree must be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" - assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, "Sequence parallel size must be equal to the product of ulysses degree and ring degree, but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" + assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( + "Sequence parallel size must be equal to the product of ulysses degree and ring degree," + f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" + ) return self def __post_init__(self) -> None: - self.world_size = ( self.pipeline_parallel_size * self.data_parallel_size @@ -70,6 +75,7 @@ def __post_init__(self) -> None: * self.cfg_parallel_size ) + @dataclass class TransformerConfig: """Container for raw transformer configuration dictionaries.""" @@ -386,12 +392,13 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none" return cls(**kwargs) + _current_omni_diffusion_config: OmniDiffusionConfig | None = None _current_prefix: str | None = None + + @contextmanager -def set_current_vllm_config( - omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None -): +def set_current_vllm_config(omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None): """ Temporarily set the current vLLM config. Used during model initialization. 
@@ -404,7 +411,7 @@ def set_current_vllm_config( old_prefix = _current_prefix from vllm.compilation.counter import compilation_counter - num_models_seen = compilation_counter.num_models_seen + # num_models_seen = compilation_counter.num_models_seen try: _current_omni_diffusion_config = omni_diffusion_config _current_prefix = prefix @@ -420,17 +427,20 @@ def set_current_vllm_config( # Clear the compilation config cache when context changes get_cached_compilation_config.cache_clear() + @lru_cache(maxsize=1) def get_cached_compilation_config(): """Cache config to avoid repeated calls to get_current_omni_diffusion_config()""" return get_current_omni_diffusion_config().compilation_config + def get_current_omni_diffusion_config() -> OmniDiffusionConfig: if _current_omni_diffusion_config is None: logger.warning("Current OmniDiffusionConfig is not set.") return OmniDiffusionConfig() return _current_omni_diffusion_config + @dataclass class DiffusionOutput: """ diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index 8d031463768..271e22ad24e 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -2,12 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team & Jiarui Fang # from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/all_to_all.py -import torch - from typing import Any, Tuple -from torch import Tensor +import torch import torch.distributed as dist +from torch import Tensor def all_to_all_4D( @@ -26,9 +25,7 @@ def all_to_all_4D( Returns: torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) """ - assert ( - input.dim() == 4 - ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + assert input.dim() == 4, f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" seq_world_size = dist.get_world_size(group) @@ -40,11 +37,7 @@ def all_to_all_4D( # transpose groups of heads with the seq-len parallel dimension, so 
that we can scatter them! # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) - input_t = ( - input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) - .transpose(0, 2) - .contiguous() - ) + input_t = input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous() output = torch.empty_like(input_t) # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single @@ -94,7 +87,7 @@ def all_to_all_4D( # if scattering the seq-dim, transpose the heads back to the original dimension output = output.reshape(hc, shard_seqlen, bs, hs) - # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs) output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) return output @@ -112,7 +105,6 @@ def forward( gather_idx: int, use_sync: bool = False, ) -> Tensor: - ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx @@ -123,9 +115,7 @@ def forward( def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return ( None, - SeqAllToAll4D.apply( - ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync - ), + SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), None, None, None, @@ -149,9 +139,7 @@ def all_to_all_5D( Returns: torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) """ - assert ( - input.dim() == 5 - ), f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" + assert input.dim() == 5, f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" seq_world_size = dist.get_world_size(group) @@ -165,11 +153,7 @@ def all_to_all_5D( # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
# (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs) - input_t = ( - input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs) - .transpose(0, 3) - .contiguous() - ) + input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous() output = torch.empty_like(input_t) # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single @@ -218,7 +202,7 @@ def all_to_all_5D( # if scattering the seq-dim, transpose the heads back to the original dimension output = output.reshape(hc, shard_seqlen, 3, bs, hs) - # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs) output = output.transpose(0, 3).contiguous() return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous() @@ -236,7 +220,6 @@ def forward( gather_idx: int = 1, use_sync: bool = False, ) -> Tensor: - ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx @@ -248,10 +231,8 @@ def forward( def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return ( None, - SeqAllToAll5D.apply( - ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync - ), + SeqAllToAll5D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), None, None, None, - ) \ No newline at end of file + ) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 00a7b8625e5..c4c0add3783 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -3,9 +3,9 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+import pickle from collections import namedtuple from typing import Any, Dict, List, Optional, Tuple, Union -import pickle import torch import torch.distributed @@ -19,6 +19,7 @@ pass from vllm_omni.diffusion.envs import envs + if envs._is_npu(): print("torch.npu synchronize") from torch.npu import synchronize @@ -46,26 +47,19 @@ def _split_tensor_dict( metadata_list: List[Tuple[str, Any]] = [] tensor_list = [] for key, value in tensor_dict.items(): - assert "%" not in key, ( - "Avoid having '%' in key " - "as it is used as a separator for nested entries." - ) + assert "%" not in key, "Avoid having '%' in key as it is used as a separator for nested entries." if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device # index (e.g. "cuda:0"). We only need the device type. # receiving side will set the device index. device = value.device.type - metadata_list.append( - (prefix + key, TensorMetadata(device, value.dtype, value.size())) - ) + metadata_list.append((prefix + key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) elif isinstance(value, dict): if len(value) == 0: metadata_list.append((prefix + key, value)) - inner_metadata_list, inner_tensor_list = _split_tensor_dict( - value, prefix + key + "%" - ) + inner_metadata_list, inner_tensor_list = _split_tensor_dict(value, prefix + key + "%") metadata_list.extend(inner_metadata_list) tensor_list.extend(inner_tensor_list) else: @@ -116,16 +110,13 @@ def __init__( local_rank: int, torch_distributed_backend: Union[str, Backend], ): - self.rank = torch.distributed.get_rank() self.local_rank = local_rank self.device_group = None self.cpu_group = None for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) # a group with `gloo` backend, to allow direct 
coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") @@ -223,32 +214,29 @@ def all_gather( # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() # Allocate output tensor. input_size = list(input_.size()) input_size[0] *= world_size - output_tensor = torch.empty( - input_size, dtype=input_.dtype, device=input_.device - ) + output_tensor = torch.empty(input_size, dtype=input_.dtype, device=input_.device) # All-gather. - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=self.device_group - ) + torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group) if dim != 0: input_size[0] //= world_size - output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.reshape( + [ + world_size, + ] + + input_size + ) output_tensor = output_tensor.movedim(0, dim) if separate_tensors: tensor_list = [ - output_tensor.view(-1) - .narrow(0, input_.numel() * i, input_.numel()) - .view_as(input_) + output_tensor.view(-1).narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size) ] return tensor_list @@ -269,9 +257,7 @@ def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Ten # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. 
dim += input_.dim() @@ -281,9 +267,7 @@ def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Ten else: gather_list = None # Gather. - torch.distributed.gather( - input_, gather_list, dst=self.ranks[dst], group=self.device_group - ) + torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) else: @@ -300,9 +284,7 @@ def broadcast(self, input_: torch.Tensor, src: int = 0): if self.world_size == 1: return input_ # Broadcast. - torch.distributed.broadcast( - input_, src=self.ranks[src], group=self.device_group - ) + torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) return input_ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): @@ -318,20 +300,14 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): assert src == 0, "Shared memory broadcaster only supports src=0" return self.shm_broadcaster.broadcast_object(obj) if self.rank_in_group == src: - torch.distributed.broadcast_object_list( - [obj], src=self.ranks[src], group=self.cpu_group - ) + torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group) return obj else: recv = [None] - torch.distributed.broadcast_object_list( - recv, src=self.ranks[src], group=self.cpu_group - ) + torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) return recv[0] - def broadcast_object_list( - self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None - ): + def broadcast_object_list(self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ @@ -341,9 +317,7 @@ def broadcast_object_list( if self.world_size == 1: return obj_list # Broadcast. 
- torch.distributed.broadcast_object_list( - obj_list, src=self.ranks[src], group=self.device_group - ) + torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group) return obj_list def send_object(self, obj: Any, dst: int) -> None: @@ -352,17 +326,12 @@ def send_object(self, obj: Any, dst: int) -> None: assert dst < self.world_size, f"Invalid dst rank ({dst})" - assert dst != self.rank, ( - "Invalid destination rank. Destination rank is the same " - "as the current rank." - ) + assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank." # Serialize object to tensor and get the size as well object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - size_tensor = torch.tensor( - [object_tensor.numel()], dtype=torch.long, device="cpu" - ) + size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu") # Send object size @@ -379,16 +348,12 @@ def recv_object(self, src: int) -> Any: assert src < self.world_size, f"Invalid src rank ({src})" - assert ( - src != self.rank - ), "Invalid source rank. Source rank is the same as the current rank." + assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank." size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size - rank_size = torch.distributed.recv( - size_tensor, src=self.ranks[src], group=self.cpu_group - ) + rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group) # Tensor to receive serialized objects into. 
object_tensor = torch.empty( # type: ignore[call-overload] @@ -397,13 +362,9 @@ def recv_object(self, src: int) -> Any: device="cpu", ) - rank_object = torch.distributed.recv( - object_tensor, src=self.ranks[src], group=self.cpu_group - ) + rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group) - assert ( - rank_object == rank_size - ), "Received object sender rank does not match the size sender rank." + assert rank_object == rank_size, "Received object sender rank does not match the size sender rank." obj = pickle.loads(object_tensor.numpy().tobytes()) @@ -431,9 +392,7 @@ def broadcast_tensor_dict( rank = self.rank if rank == src: metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, dict - ), f"Expecting a dictionary, got {type(tensor_dict)}" + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. 
# `broadcast_object_list` has serialization & deserialization, @@ -446,14 +405,10 @@ def broadcast_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast( - tensor, src=src, group=metadata_group, async_op=True - ) + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast( - tensor, src=src, group=group, async_op=True - ) + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -464,23 +419,17 @@ def broadcast_tensor_dict( async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty( - value.size, dtype=value.dtype, device=value.device - ) + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. 
_update_nested_dict(tensor_dict, key, tensor) continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast( - tensor, src=src, group=metadata_group, async_op=True - ) + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast( - tensor, src=src, group=group, async_op=True - ) + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) async_handles.append(handle) _update_nested_dict(tensor_dict, key, tensor) else: @@ -509,9 +458,7 @@ def send_tensor_dict( assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, dict - ), f"Expecting a dictionary, got {type(tensor_dict)}" + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, @@ -523,17 +470,13 @@ def send_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.send( - tensor, dst=self.ranks[dst], group=metadata_group - ) + torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group) else: # use group for GPU tensors torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None - def recv_tensor_dict( - self, src: Optional[int] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -559,9 +502,7 @@ def recv_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.recv( - tensor, src=self.ranks[src], group=metadata_group - ) + torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) else: # use group for GPU tensors torch.distributed.recv(tensor, src=self.ranks[src], group=group) @@ -588,16 +529,10 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: torch.distributed.send( tensor, self.ranks[dst], - group=( - self.device_groups[self.rank_in_group % 2] - if self.world_size == 2 - else self.device_group - ), + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), ) - def recv( - self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None - ) -> torch.Tensor: + def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the src rank.""" """NOTE: `src` is the rank_in_group of the source rank.""" if src is None: @@ -607,11 +542,7 @@ def recv( torch.distributed.recv( tensor, self.ranks[src], - ( - self.device_groups[(self.rank_in_group + 1) % 2] - if self.world_size == 2 - else self.device_group - ), + (self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), ) return tensor @@ -657,9 +588,7 @@ def __init__( self.device_groups = [] if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1: for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") @@ -677,12 +606,8 @@ def __init__( # device 0. 
elif len(group_ranks[0]) == 2: for ranks in group_ranks: - device_group_0_1 = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - device_group_1_0 = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + device_group_0_1 = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + device_group_1_0 = torch.distributed.new_group(ranks, backend=torch_distributed_backend) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo") @@ -714,14 +639,10 @@ def __init__( self.skip_tensor_recv_buffer_set: bool = False self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = [] self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] - self.skip_tensor_recv_buffer: Optional[ - Union[List[torch.Tensor], torch.Tensor] - ] = None + self.skip_tensor_recv_buffer: Optional[Union[List[torch.Tensor], torch.Tensor]] = None self.skip_device_group = None for ranks in group_ranks: - skip_device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + skip_device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) if self.rank in ranks: self.skip_device_group = skip_device_group assert self.skip_device_group is not None @@ -748,18 +669,13 @@ def set_recv_buffer( dtype: torch.dtype, ): assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object" - assert ( - isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1 - ), "num_pipefusion_patches must be greater than or equal to 1" + assert isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1, ( + "num_pipefusion_patches must be greater than or equal to 1" + ) self.dtype = dtype self.num_pipefusion_patches = num_pipefusion_patches - self.recv_buffer = [ - torch.zeros(*shape, dtype=self.dtype, device=self.device) - for shape in patches_shape_list - ] - 
self.recv_buffer.append( - torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) - ) + self.recv_buffer = [torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list] + self.recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)) self.recv_buffer_set = True def set_extra_tensors_recv_buffer( @@ -770,8 +686,7 @@ def set_extra_tensors_recv_buffer( dtype: torch.dtype = torch.float16, ): self.extra_tensors_recv_buffer[name] = [ - torch.zeros(*shape, dtype=dtype, device=self.device) - for _ in range(num_buffers) + torch.zeros(*shape, dtype=dtype, device=self.device) for _ in range(num_buffers) ] def _check_shape_and_buffer( @@ -813,18 +728,12 @@ def _check_shape_and_buffer( if self.recv_buffer.get(name, None) is None: self.recv_buffer[name] = { - segment_idx: torch.zeros( - recv_prev_shape, device=self.device, dtype=self.dtype - ) + segment_idx: torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype) } else: if self.recv_buffer[name].get(segment_idx, None) is not None: - logger.warning( - f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..." - ) - self.recv_buffer[name][segment_idx] = torch.zeros( - recv_prev_shape, device=self.device, dtype=self.dtype - ) + logger.warning(f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating...") + self.recv_buffer[name][segment_idx] = torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype) def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): """Communicate tensor shapes between stages. 
Used to communicate @@ -839,9 +748,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): ops = [] if recv_prev: - recv_prev_dim_tensor = torch.empty( - (1), device=self.device, dtype=torch.int64 - ) + recv_prev_dim_tensor = torch.empty((1), device=self.device, dtype=torch.int64) recv_prev_dim_op = torch.distributed.P2POp( torch.distributed.irecv, recv_prev_dim_tensor, @@ -851,9 +758,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): ops.append(recv_prev_dim_op) if tensor_send_to_next is not None: - send_next_dim_tensor = torch.tensor( - tensor_send_to_next.dim(), device=self.device, dtype=torch.int64 - ) + send_next_dim_tensor = torch.tensor(tensor_send_to_next.dim(), device=self.device, dtype=torch.int64) send_next_dim_op = torch.distributed.P2POp( torch.distributed.isend, send_next_dim_tensor, @@ -886,9 +791,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): ops.append(recv_prev_shape_op) if tensor_send_to_next is not None: - send_next_shape_tensor = torch.tensor( - tensor_send_to_next.size(), device=self.device, dtype=torch.int64 - ) + send_next_shape_tensor = torch.tensor(tensor_send_to_next.size(), device=self.device, dtype=torch.int64) send_next_shape_op = torch.distributed.P2POp( torch.distributed.isend, send_next_shape_tensor, @@ -909,22 +812,14 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev_shape = recv_prev_shape_tensor return torch.Size(recv_prev_shape) - def pipeline_send( - self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 - ) -> None: + def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: tensor = tensor.contiguous() - self._check_shape_and_buffer( - tensor_send_to_next=tensor, name=name, segment_idx=segment_idx - ) + self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) self._pipeline_isend(tensor).wait() - def 
pipeline_isend( - self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 - ) -> None: + def pipeline_isend(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: tensor = tensor.contiguous() - self._check_shape_and_buffer( - tensor_send_to_next=tensor, name=name, segment_idx=segment_idx - ) + self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) self._pipeline_isend(tensor) def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: @@ -943,43 +838,27 @@ def recv_next(self): elif len(self.recv_tasks_queue) > 0: name, idx = self.recv_tasks_queue.pop(0) self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) - self.receiving_tasks.append( - (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx) - ) + self.receiving_tasks.append((self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)) - def get_pipeline_recv_data( - self, idx: int = -1, name: str = "latent" - ) -> torch.Tensor: - assert ( - len(self.receiving_tasks) > 0 - ), "No tasks to receive, call add_pipeline_recv_task first" + def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.Tensor: + assert len(self.receiving_tasks) > 0, "No tasks to receive, call add_pipeline_recv_task first" receiving_task = self.receiving_tasks.pop(0) receiving_task[0].wait() - assert ( - receiving_task[1] == name and receiving_task[2] == idx - ), "Received tensor does not match the requested" + assert receiving_task[1] == name and receiving_task[2] == idx, "Received tensor does not match the requested" return self.recv_buffer[name][idx] def _pipeline_irecv(self, tensor: torch.tensor): return torch.distributed.irecv( tensor, src=self.prev_rank, - group=( - self.device_groups[(self.rank_in_group + 1) % 2] - if self.world_size == 2 - else self.device_group - ), + group=(self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), ) def 
_pipeline_isend(self, tensor: torch.tensor): return torch.distributed.isend( tensor, dst=self.next_rank, - group=( - self.device_groups[self.rank_in_group % 2] - if self.world_size == 2 - else self.device_group - ), + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), ) def set_skip_tensor_recv_buffer( @@ -988,12 +867,9 @@ def set_skip_tensor_recv_buffer( feature_map_shape: List[int], ): self.skip_tensor_recv_buffer = [ - torch.zeros(*shape, dtype=self.dtype, device=self.device) - for shape in patches_shape_list + torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list ] - self.skip_tensor_recv_buffer.append( - torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) - ) + self.skip_tensor_recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)) self.skip_tensor_recv_buffer_set = True def pipeline_send_skip(self, tensor: torch.Tensor) -> None: @@ -1012,14 +888,10 @@ def add_pipeline_recv_skip_task(self, idx: int = -1): self.recv_skip_tasks_queue.append(idx) def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor: - assert ( - len(self.receiving_skip_tasks) > 0 - ), "No tasks to receive, call add_pipeline_recv_skip_task first" + assert len(self.receiving_skip_tasks) > 0, "No tasks to receive, call add_pipeline_recv_skip_task first" receiving_skip_task = self.receiving_skip_tasks.pop(0) receiving_skip_task[0].wait() - assert ( - receiving_skip_task[2] == idx - ), "Received tensor does not match the requested" + assert receiving_skip_task[2] == idx, "Received tensor does not match the requested" return self.skip_tensor_recv_buffer[idx] def recv_skip_next(self): @@ -1037,14 +909,10 @@ def recv_skip_next(self): ) def _pipeline_irecv_skip(self, tensor: torch.tensor): - return torch.distributed.irecv( - tensor, src=self.skip_rank, group=self.skip_device_group - ) + return torch.distributed.irecv(tensor, src=self.skip_rank, 
group=self.skip_device_group) def _pipeline_isend_skip(self, tensor: torch.tensor): - return torch.distributed.isend( - tensor, dst=self.skip_rank, group=self.skip_device_group - ) + return torch.distributed.isend(tensor, dst=self.skip_rank, group=self.skip_device_group) class SequenceParallelGroupCoordinator(GroupCoordinator): @@ -1065,18 +933,16 @@ def __init__( ring_group = kwargs.get("ring_group", None) if ulysses_group is None: raise RuntimeError( - f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" + "Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" ) if ring_group is None: raise RuntimeError( - f"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator" + "Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator" ) self.ulysses_group = ulysses_group self.ring_group = ring_group - self.ulysses_world_size = torch.distributed.get_world_size( - self.ulysses_group - ) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) self.ring_world_size = torch.distributed.get_world_size(self.ring_group) self.ring_rank = torch.distributed.get_rank(self.ring_group) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 1057bdaf436..3a4f98843c6 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -1,4 +1,3 @@ - # SPDX-License-Identifier: Apache-2.0 # Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. @@ -29,12 +28,13 @@ If you only need to use the distributed environment without model parallelism, you can skip the model parallel initialization and destruction steps. 
""" -from typing import Any, List, Optional + +from typing import List, Optional import torch import torch.distributed - from vllm.logger import init_logger + from vllm_omni.diffusion.envs import envs from .group_coordinator import ( @@ -45,17 +45,16 @@ try: import torch_musa - from torch_musa.core.device import set_device, device_count + from torch_musa.core.device import device_count, set_device except ModuleNotFoundError: pass try: - from torch.npu import set_device, device_count + from torch.npu import device_count, set_device except ModuleNotFoundError: pass - env_info = envs.PACKAGES_CHECKER.get_packages_info() HAS_FLASH_ATTN = env_info["has_flash_attn"] @@ -76,7 +75,7 @@ def generate_masked_orthogonal_rank_groups( world_size: int, parallel_size: List[int], mask: List[bool] ) -> List[List[int]]: - """Generate orthogonal parallel groups based on the parallel size and mask. + r"""Generate orthogonal parallel groups based on the parallel size and mask. Arguments: world_size (int): world size @@ -151,9 +150,9 @@ def decompose(index, shape, stride=None): idx = [(index // d) % s for s, d in zip(shape, stride)] # stride is a prefix_product result. And the value of stride[-1] # is not used. - assert ( - sum([x * y for x, y in zip(idx, stride[:-1])]) == index - ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + assert sum([x * y for x, y in zip(idx, stride[:-1])]) == index, ( + f"idx {index} with shape {shape} mismatch the return idx {idx}" + ) return idx masked_shape = [s for s, m in zip(parallel_size, mask) if m] @@ -175,14 +174,13 @@ def decompose(index, shape, stride=None): # get indices from masked for rank_in_group. 
decomposed_rank_idx = decompose(rank_in_group, masked_shape) rank.append( - inner_product(decomposed_rank_idx, masked_stride) - + inner_product(decomposed_group_idx, unmasked_stride) + inner_product(decomposed_rank_idx, masked_stride) + inner_product(decomposed_group_idx, unmasked_stride) ) ranks.append(rank) return ranks -class RankGenerator(object): +class RankGenerator: def __init__( self, tp: int, @@ -250,14 +248,14 @@ def get_ranks(self, token): get full DP group. """ mask = self.get_mask(self.order, token) - ranks = generate_masked_orthogonal_rank_groups( - self.world_size, self.ordered_size, mask - ) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) if self.rank_offset > 0: for rank_group in ranks: for i in range(len(rank_group)): rank_group[i] += self.rank_offset return ranks + + # * QUERY def get_world_group() -> GroupCoordinator: assert _WORLD is not None, "world group is not initialized" @@ -340,9 +338,7 @@ def is_pipeline_last_stage(): # CFG def get_cfg_group() -> GroupCoordinator: - assert ( - _CFG is not None - ), "classifier_free_guidance parallel group is not initialized" + assert _CFG is not None, "classifier_free_guidance parallel group is not initialized" return _CFG @@ -376,8 +372,7 @@ def is_dp_last_group(): """Return True if in the last data parallel group, False otherwise.""" return ( get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1) - and get_classifier_free_guidance_rank() - == (get_classifier_free_guidance_world_size() - 1) + and get_classifier_free_guidance_rank() == (get_classifier_free_guidance_world_size() - 1) and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1) ) @@ -412,9 +407,7 @@ def get_vae_parallel_rank(): # * SET -def init_world_group( - ranks: List[int], local_rank: int, backend: str -) -> GroupCoordinator: +def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( 
group_ranks=[ranks], local_rank=local_rank, @@ -432,7 +425,7 @@ def init_distributed_environment( if backend is None: backend = envs.get_torch_distributed_backend() logger.debug( - "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", world_size, rank, local_rank, @@ -441,8 +434,7 @@ def init_distributed_environment( ) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( - "distributed_init_method must be provided when initializing " - "distributed environment" + "distributed_init_method must be provided when initializing distributed environment" ) # this backend is used for WORLD torch.distributed.init_process_group( @@ -467,20 +459,14 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) else: - assert ( - _WORLD.world_size == torch.distributed.get_world_size() - ), "world group already initialized with a different world size" + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size" + ) def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return ( - _DP is not None - and _CFG is not None - and _SP is not None - and _PP is not None - and _TP is not None - ) + return _DP is not None and _CFG is not None and _SP is not None and _PP is not None and _TP is not None def init_model_parallel_group( @@ -523,9 +509,7 @@ def init_dit_group( backend: str, ): global _DIT - _DIT = torch.distributed.new_group( - ranks=list(range(dit_parallel_size)), backend=backend - ) + _DIT = torch.distributed.new_group(ranks=list(range(dit_parallel_size)), backend=backend) def get_dit_group(): @@ -573,7 +557,7 @@ def initialize_model_parallel( Let's say we have a total of 16 GPUs denoted by g0 ... 
g15 and we use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize - splited batch caused by CFG, and 2 GPUs to parallelize sequence. + split batch caused by CFG, and 2 GPUs to parallelize sequence. dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16. diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 9e23d2fec51..5e64474404a 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -2,9 +2,10 @@ # Adapted from # https://github.com/xdit-project/xDiT/blob/main/xfuser/envs.py import os -import torch -import diffusers from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +import diffusers +import torch from packaging import version try: @@ -31,9 +32,7 @@ # used in distributed environment to determine the master address "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""), # used in distributed environment to manually set the communication port - "MASTER_PORT": lambda: ( - int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None - ), + "MASTER_PORT": lambda: (int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None), # path to cudatoolkit home directory, under which should be bin, include, # and lib directories. 
"CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), @@ -115,9 +114,7 @@ def get_device_version(): elif _is_npu(): return None else: - raise NotImplementedError( - "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available" - ) + raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") def get_torch_distributed_backend() -> str: @@ -130,18 +127,14 @@ def get_torch_distributed_backend() -> str: elif _is_npu(): return "hccl" else: - raise NotImplementedError( - "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available" - ) + raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") variables: Dict[str, Callable[[], Any]] = { # ================== Other Vars ================== # used in version checking "CUDA_VERSION": lambda: version.parse(get_device_version() or "0.0"), - "TORCH_VERSION": lambda: version.parse( - version.parse(torch.__version__).base_version - ), + "TORCH_VERSION": lambda: version.parse(version.parse(torch.__version__).base_version), } @@ -151,12 +144,8 @@ def _setup_musa(environment_variables, variables): return try: if musa.is_available(): - environment_variables["MUSA_HOME"] = lambda: os.environ.get( - "MUSA_HOME", None - ) - environment_variables["MUSA_VISIBLE_DEVICES"] = lambda: os.environ.get( - "MUSA_VISIBLE_DEVICES", None - ) + environment_variables["MUSA_HOME"] = lambda: os.environ.get("MUSA_HOME", None) + environment_variables["MUSA_VISIBLE_DEVICES"] = lambda: os.environ.get("MUSA_VISIBLE_DEVICES", None) musa_ver = getattr(getattr(torch, "version", None), "musa", None) if musa_ver: variables["MUSA_VERSION"] = lambda: version.parse(musa_ver) @@ -191,31 +180,26 @@ def check_aiter(self): Checks whether ROCm AITER library is installed """ try: - import aiter logger.info("Using AITER as the attention library") return True except: if _is_hip(): logger.warning( - f'Using AMD GPUs, but library "aiter" is not installed, ' - 
'defaulting to other attention mechanisms' + 'Using AMD GPUs, but library "aiter" is not installed, defaulting to other attention mechanisms' ) return False - def check_flash_attn(self, packages_info): if not torch.cuda.is_available(): return False # Check if torch_npu is available if _is_npu(): - logger.info("falsh_attn is not ready on torch_npu for now") + logger.info("`falsh_attn` is not ready on torch_npu for now") return False if _is_musa(): - logger.info( - "Flash Attention library is not supported on MUSA for the moment." - ) + logger.info("Flash Attention library is not supported on MUSA for the moment.") return False try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -223,26 +207,18 @@ def check_flash_attn(self, packages_info): if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: return False else: - from flash_attn import flash_attn_func - from flash_attn import __version__ + from flash_attn import __version__, flash_attn_func if __version__ < "2.6.0": - raise ImportError(f"install flash_attn >= 2.6.0") + raise ImportError("install flash_attn >= 2.6.0") return True except ImportError: if not packages_info.get("has_aiter", False): - logger.warning( - f'Flash Attention library "flash_attn" not found, ' - f"using pytorch attention implementation" - ) + logger.warning('Flash Attention library "flash_attn" not found, using pytorch attention implementation') return False - - def check_diffusers_version(self): - if version.parse( - version.parse(diffusers.__version__).base_version - ) < version.parse("0.30.0"): + if version.parse(version.parse(diffusers.__version__).base_version) < version.parse("0.30.0"): raise RuntimeError( f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," f"please upgrade to version > 0.30.0" diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index e20c4f92e92..8dd49ad04e2 100644 --- 
a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -8,8 +8,6 @@ import zmq from vllm.config import LoadConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue -from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel - from vllm.logger import init_logger from vllm.utils import DeviceMemoryProfiler, GiB_bytes @@ -19,6 +17,7 @@ DiffusionOutput, OmniDiffusionConfig, ) +from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.request import OmniDiffusionRequest From 60616b608a55440493afb1dc97306c8ce56eed4b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:30:36 +0800 Subject: [PATCH 07/85] updates Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image_usp.py | 107 ++++++++++++++++++ vllm_omni/diffusion/data.py | 2 +- .../diffusion/distributed/parallel_state.py | 34 +++--- .../qwen_image/qwen_image_transformer.py | 19 ++++ 4 files changed, 148 insertions(+), 14 deletions(-) create mode 100644 examples/offline_inference/text_to_image/text_to_image_usp.py diff --git a/examples/offline_inference/text_to_image/text_to_image_usp.py b/examples/offline_inference/text_to_image/text_to_image_usp.py new file mode 100644 index 00000000000..fa9be5ca17e --- /dev/null +++ b/examples/offline_inference/text_to_image/text_to_image_usp.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import time +from pathlib import Path + +import torch + +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_vllm_config +from 
vllm_omni.diffusion.distributed.parallel_state import destroy_distributed_env, get_world_group +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.utils.platform_utils import detect_device_type, is_npu + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") + parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.") + parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") + parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") + parser.add_argument( + "--cfg_scale", + type=float, + default=4.0, + help="True classifier-free guidance scale specific to Qwen-Image.", + ) + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") + parser.add_argument( + "--output", + type=str, + default="qwen_image_output.png", + help="Path to save the generated image (PNG).", + ) + parser.add_argument( + "--num_images_per_prompt", + type=int, + default=1, + help="Number of images to generate for the given prompt.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + device = detect_device_type() + generator = torch.Generator(device=device).manual_seed(args.seed) + local_rank = get_world_group().local_rank + parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + # Enable VAE memory optimizations on NPU + vae_use_slicing = is_npu() + vae_use_tiling = is_npu() + + omni_diffusion_config = OmniDiffusionConfig( + parallel_config=DiffusionParallelConfig(sequence_parallel_size=2, ulysses_degree=2) + ) + with 
set_current_vllm_config(omni_diffusion_config): + omni = Omni( + model=args.model, + od_config=omni_diffusion_config, + vae_use_slicing=vae_use_slicing, + vae_use_tiling=vae_use_tiling, + ) + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + images = omni.generate( + args.prompt, + height=args.height, + width=args.width, + generator=generator, + true_cfg_scale=args.cfg_scale, + num_inference_steps=args.num_inference_steps, + num_images_per_prompt=args.num_images_per_prompt, + num_outputs_per_prompt=args.num_images_per_prompt, + ) + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "qwen_image_output" + if args.num_images_per_prompt <= 1: + images[0].save(output_path) + print(f"Saved generated image to {output_path}") + else: + for idx, img in enumerate(images): + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved generated image to {save_path}") + if get_world_group().rank == get_world_group().world_size - 1: + print( + f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory / 1e9:.2f} GB, memory: {peak_memory / 1e9:.2f} GB" + ) + destroy_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 4bcaf249393..2f85e1c0586 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -409,7 +409,7 @@ def set_current_vllm_config(omni_diffusion_config: OmniDiffusionConfig, check_co global _current_omni_diffusion_config, _current_prefix old_omni_diffusion_config = _current_omni_diffusion_config old_prefix = _current_prefix - from vllm.compilation.counter import compilation_counter + # from vllm.compilation.counter import compilation_counter # 
num_models_seen = compilation_counter.num_models_seen try: diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 3a4f98843c6..8b911cdb2b6 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -29,7 +29,7 @@ you can skip the model parallel initialization and destruction steps. """ -from typing import List, Optional +from typing import Optional import torch import torch.distributed @@ -44,7 +44,6 @@ ) try: - import torch_musa from torch_musa.core.device import device_count, set_device except ModuleNotFoundError: pass @@ -73,19 +72,19 @@ def generate_masked_orthogonal_rank_groups( - world_size: int, parallel_size: List[int], mask: List[bool] -) -> List[List[int]]: + world_size: int, parallel_size: list[int], mask: list[bool] +) -> list[list[int]]: r"""Generate orthogonal parallel groups based on the parallel size and mask. Arguments: world_size (int): world size - parallel_size (List[int]): + parallel_size (list[int]): The parallel size of each orthogonal parallel type. For example, if tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. - mask (List[bool]): + mask (list[bool]): The mask controls which parallel methods the generated groups represent. If mask[i] is True, it means the generated group contains the i-th parallelism method. 
For example, if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then @@ -125,14 +124,14 @@ def generate_masked_orthogonal_rank_groups( dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] """ - def prefix_product(a: List[int], init=1) -> List[int]: + def prefix_product(a: list[int], init=1) -> list[int]: r = [init] for v in a: init = init * v r.append(init) return r - def inner_product(a: List[int], b: List[int]) -> int: + def inner_product(a: list[int], b: list[int]) -> int: return sum([x * y for x, y in zip(a, b)]) def decompose(index, shape, stride=None): @@ -211,7 +210,8 @@ def __init__( for name in self.name_to_size.keys(): if name not in order and self.name_to_size[name] != 1: raise RuntimeError( - f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + f"The size of ({name}) is ({self.name_to_size[name]}), " + f"but you haven't specified the order ({self.order})." ) elif name not in order: order = order + "-" + name @@ -407,7 +407,7 @@ def get_vae_parallel_rank(): # * SET -def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator: +def init_world_group(ranks: list[int], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], local_rank=local_rank, @@ -470,7 +470,7 @@ def model_parallel_is_initialized(): def init_model_parallel_group( - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, backend: str, parallel_mode: str, @@ -548,7 +548,8 @@ def initialize_model_parallel( Arguments: data_parallel_degree: number of data parallelism groups. classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) - sequence_parallel_degree: number of GPUs used for sequence parallelism. sequence_parallel_degree = ulysses_degree * ring_degree + sequence_parallel_degree: number of GPUs used for sequence parallelism. 
+ sequence_parallel_degree = ulysses_degree * ring_degree ulysses_degree: number of GPUs used for ulysses sequence parallelism. ring_degree: number of GPUs used for ring sequence parallelism. tensor_parallel_degree: number of GPUs used for tensor parallelism. @@ -594,7 +595,8 @@ def initialize_model_parallel( if sequence_parallel_degree != ring_degree * ulysses_degree: raise ValueError( - f"sequence_parallel_degree is not equal to ring_degree * ulysses_degree, {sequence_parallel_degree} != {ring_degree} * {ulysses_degree}" + "sequence_parallel_degree is not equal to ring_degree * ulysses_degree," + f" but got {sequence_parallel_degree} != {ring_degree} * {ulysses_degree}" ) # FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch, @@ -719,3 +721,9 @@ def destroy_distributed_environment(): _WORLD = None if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() + + +def destroy_distributed_env(self): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 4db4156bfdb..88784994e7f 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -21,6 +21,11 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) logger = init_logger(__name__) @@ -542,6 +547,7 @@ def __init__( super().__init__() model_config = od_config.tf_model_config num_layers = model_config.num_layers + self.parallel_config = od_config.parallel_config self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads 
* attention_head_dim @@ -616,6 +622,14 @@ def forward( # else: # lora_scale = 1.0 + ############################################################ + # parallel inputs + ############################################################ + if self.parallel_config.sequence_parallel_size > 1: + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[ + get_sequence_parallel_rank() + ] + hidden_states = self.img_in(hidden_states) # Ensure timestep tensor is on the same device and dtype as hidden_states @@ -662,6 +676,11 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) + ############################################################ + # parallel outputs + ############################################################ + if self.parallel_config.sequence_parallel_size > 1: + output = get_sp_group().all_gather(output, dim=-2) return Transformer2DModelOutput(sample=output) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From f590c37f3c1d32e42f242621d1f0d67d9be80048 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:56:02 +0800 Subject: [PATCH 08/85] test script for ulysses sp Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 314 ++++++++++++++++++ tests/utils.py | 231 ++++++++++++- vllm_omni/diffusion/attention/layer.py | 4 +- 3 files changed, 546 insertions(+), 3 deletions(-) create mode 100644 tests/distributed/test_ulysses_sequence_parallel.py diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py new file mode 100644 index 00000000000..71bf7fc70e3 --- /dev/null +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -0,0 +1,314 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import 
vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils.system_utils import update_environment_variables + +from tests.utils import multi_gpu_test +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import ( + DiffusionParallelConfig, + OmniDiffusionConfig, + set_current_vllm_config, +) +from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel + + +class TestAttentionModel(torch.nn.Module): + """Test model using Attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = True, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = hidden_size + self.attention = Attention( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=1.0 / (head_size**0.5), + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + # Linear projection layers for Q, K, V + self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_size) + self.k_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.v_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.o_proj = torch.nn.Linear(num_heads * head_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through attention layer.""" + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (batch_size, seq_len, num_heads, head_size) + q = q.view(batch_size, seq_len, self.num_heads, self.head_size) + k = k.view(batch_size, seq_len, k.shape[-1] // self.head_size, self.head_size) + v = v.view(batch_size, 
seq_len, v.shape[-1] // self.head_size, self.head_size) + + # Apply attention + attn_output = self.attention(q, k, v) + + # Reshape back and project + attn_output = attn_output.view(batch_size, seq_len, -1) + output = self.o_proj(attn_output) + + return output + + +class TestMultiLayerAttentionModel(torch.nn.Module): + """Test model with multiple attention layers.""" + + def __init__( + self, + num_layers: int, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = True, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.layers = torch.nn.ModuleList( + [ + TestAttentionModel( + num_heads=num_heads, + head_size=head_size, + hidden_size=hidden_size, + causal=causal, + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through multiple attention layers.""" + for layer in self.layers: + hidden_states = hidden_states + layer(hidden_states) + return hidden_states + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "test_model_cls", + [ + TestAttentionModel, + TestMultiLayerAttentionModel, + ], +) +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("seq_len", [16, 32]) +@pytest.mark.parametrize("num_heads", [4, 8]) +@pytest.mark.parametrize("head_size", [32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_sync", [True, False]) +@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("use_compile", [False, True]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_ulysses_attention( + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + dtype: 
torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, +): + """Test Ulysses attention with various parameter combinations.""" + num_processes = 2 + ulysses_degree = 2 # Must match num_processes for this test + ring_degree = 1 + sequence_parallel_size = ulysses_degree * ring_degree + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + ulysses_degree, + ring_degree, + sequence_parallel_size, + ), + nprocs=nprocs, + ) + + run_torch_spawn(ulysses_attention_on_test_model, num_processes) + + +def ulysses_attention_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + ulysses_degree: int, + ring_degree: int, + sequence_parallel_size: int, +): + """Run Ulysses attention test on a test model.""" + current_platform.seed_everything(42) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12346", # Different port to avoid conflicts + } + ) + + # Initialize distributed environment + init_distributed_environment() + + # Set up OmniDiffusionConfig with Ulysses parallel config + parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + 
cfg_parallel_size=1, + ) + + od_config = OmniDiffusionConfig( + model="test_model", + dtype=dtype, + parallel_config=parallel_config, + ) + + # Initialize model parallel with Ulysses + initialize_model_parallel( + data_parallel_degree=1, + classifier_free_guidance_degree=1, + sequence_parallel_degree=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + tensor_parallel_degree=1, + pipeline_parallel_degree=1, + ) + + # Set the config so Attention can access it + with set_current_vllm_config(od_config): + # Create model + hidden_size = num_heads * head_size + + # Create model with appropriate parameters + model_kwargs = { + "num_heads": num_heads, + "head_size": head_size, + "hidden_size": hidden_size, + "causal": causal, + "num_kv_heads": None, + "scatter_idx": 2, + "gather_idx": 1, + "use_sync": use_sync, + } + + if test_model_cls == TestMultiLayerAttentionModel: + model_kwargs["num_layers"] = 2 + + model = test_model_cls(**model_kwargs) + + model = model.to(device).to(dtype) + + # Create input + # In sequence parallel, each rank gets seq_len / sequence_parallel_size + local_seq_len = seq_len // sequence_parallel_size + hidden_states = torch.randn( + (batch_size, local_seq_len, hidden_size), + dtype=dtype, + device=device, + ) + + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) + torch._dynamo.mark_dynamic(hidden_states, 1) + + # Compile model if requested + if use_compile: + model = torch.compile(model) + + # Run forward pass + output = model(hidden_states) + + # Verify output shape + assert output.shape == (batch_size, local_seq_len, hidden_size), ( + f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}" + ) + + # Verify that Attention is using Ulysses + if hasattr(model, "attention"): + assert hasattr(model.attention, "use_ulysses"), "Attention should have use_ulysses attribute" + assert model.attention.use_ulysses, "Attention should be using Ulysses" + elif hasattr(model, 
"layers"): + for i, layer in enumerate(model.layers): + assert hasattr(layer.attention, "use_ulysses"), f"Layer {i} attention should have use_ulysses attribute" + assert layer.attention.use_ulysses, f"Layer {i} attention should be using Ulysses" + + # Run backward pass to ensure gradients work + loss = output.sum() + loss.backward() + + print( + f"Rank {local_rank}: Test passed with " + f"batch_size={batch_size}, seq_len={seq_len}, " + f"num_heads={num_heads}, head_size={head_size}, " + f"dtype={dtype}, causal={causal}, use_sync={use_sync}, " + f"dynamic={dynamic}, use_compile={use_compile}" + ) diff --git a/tests/utils.py b/tests/utils.py index aba734501eb..c4b1f7f1441 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import contextlib +import functools import os +import signal +import subprocess +import sys +import tempfile import time -from contextlib import contextmanager +from collections.abc import Callable +from contextlib import ExitStack, contextmanager, suppress +from pathlib import Path +from typing import Any, Literal +import cloudpickle +import pytest +from typing_extensions import ParamSpec from vllm.platforms import current_platform +from vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( @@ -105,3 +117,218 @@ def wait_for_gpu_memory_to_clear( raise ValueError(f"Memory of devices {devices=} not free after {dur_s=:.02f} ({threshold=})") time.sleep(5) + + +VLLM_PATH = Path(__file__).parent.parent +"""Path to root of the vLLM repository.""" + +_P = ParamSpec("_P") + + +def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to fork a new process for each test function. + See https://github.com/vllm-project/vllm/issues/7053 for more details. 
+ """ + + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Make the process the leader of its own process group + # to avoid sending SIGTERM to the parent process + os.setpgrp() + from _pytest.outcomes import Skipped + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. + with ( + tempfile.NamedTemporaryFile( + delete=False, + mode="w+b", + prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", + suffix=".exc", + ) as exc_file, + ExitStack() as delete_after, + ): + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. + delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {"pickled_exception": e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. + exc_to_serialize = { + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": tb_string, + } + try: + with open(exc_file_path, "wb") as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. 
+ print(tb_string) + os._exit(1) + else: + os._exit(0) + else: + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with ( + contextlib.suppress(Exception), + open(exc_file_path, "rb") as f, + ): + exc_info = cloudpickle.load(f) + + if (original_exception := exc_info.get("pickled_exception")) is not None: + # Re-raise the actual exception object if it was + # successfully pickled. + assert isinstance(original_exception, Exception) + raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})" + ) from None + + return wrapper + + +def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function.""" + + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Check if we're already in a subprocess + if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": + # If we are, just run the function directly + return f(*args, **kwargs) + + import torch.multiprocessing as mp + + with suppress(RuntimeError): + mp.set_start_method("spawn") + + # Get the module + module_name = f.__module__ + + # Create a process with environment variable set + env = os.environ.copy() + 
env["RUNNING_IN_SUBPROCESS"] = "1" + + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "new_process.tmp") + + # `cloudpickle` allows pickling complex functions directly + input_bytes = cloudpickle.dumps((f, output_filepath)) + + repo_root = str(VLLM_PATH.resolve()) + + env = dict(env or os.environ) + env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") + + cmd = [sys.executable, "-m", f"{module_name}"] + + returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e + + return wrapper + + +def create_new_process_for_each_test( + method: Literal["spawn", "fork"] | None = None, +) -> Callable[[Callable[_P, None]], Callable[_P, None]]: + """Creates a decorator that runs each test function in a new process. + + Args: + method: The process creation method. Can be either "spawn" or "fork". + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. + + Returns: + A decorator to run test functions in separate processes. 
+ """ + if method is None: + use_spawn = current_platform.is_rocm() or current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" + + assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" + + if method == "fork": + return fork_new_process_for_each_test + + return spawn_new_process_for_each_test + + +def multi_gpu_marks(*, num_gpus: int): + """Get a collection of pytest marks to apply for `@multi_gpu_test`.""" + test_selector = pytest.mark.distributed(num_gpus=num_gpus) + test_skipif = pytest.mark.skipif( + cuda_device_count_stateless() < num_gpus, + reason=f"Need at least {num_gpus} GPUs to run the test.", + ) + + return [test_selector, test_skipif] + + +def multi_gpu_test(*, num_gpus: int): + """ + Decorate a test to be run only when multiple GPUs are available. + """ + marks = multi_gpu_marks(num_gpus=num_gpus) + + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: + func = create_new_process_for_each_test()(f) + for mark in reversed(marks): + func = mark(func) + + return func + + return wrapper diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index b9bb0882c5c..523443ed1ce 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -53,9 +53,10 @@ def __init__( self.gather_idx = gather_idx self.use_sync = use_sync self.sequence_process_group: Optional[dist.ProcessGroup] = None - config = get_current_omni_diffusion_config() + self.use_ulysses = False try: + config = get_current_omni_diffusion_config() if config.parallel_config.ulysses_degree > 1: self.use_ulysses = True # Get sequence parallel process group @@ -64,6 +65,7 @@ def __init__( self.sequence_process_group = sp_group.device_group assert get_sequence_parallel_world_size() > 1, "Sequence parallel world size must be > 1" except (AssertionError, RuntimeError): + # If sequence parallel group is not initialized, disable Ulysses self.use_ulysses = False except Exception: 
self.use_ulysses = False From 883116a21ff77790fabdafdbc631bc63459d9df8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:34:07 +0800 Subject: [PATCH 09/85] updates test Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/distributed/test_ulysses_sequence_parallel.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index 71bf7fc70e3..d641882e748 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -5,7 +5,6 @@ import torch import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils.system_utils import update_environment_variables from tests.utils import multi_gpu_test from vllm_omni.diffusion.attention.layer import Attention @@ -201,17 +200,6 @@ def ulysses_attention_on_test_model( torch.cuda.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) - - update_environment_variables( - { - "RANK": str(local_rank), - "LOCAL_RANK": str(local_rank), - "WORLD_SIZE": str(world_size), - "MASTER_ADDR": "localhost", - "MASTER_PORT": "12346", # Different port to avoid conflicts - } - ) - # Initialize distributed environment init_distributed_environment() From a158c4abcceed48dade3565a3a00cc9a02cc3017 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:50:53 +0800 Subject: [PATCH 10/85] fix import errors Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 4 -- .../distributed/group_coordinator.py | 72 +++++++++---------- .../diffusion/distributed/parallel_state.py | 2 +- vllm_omni/diffusion/envs.py | 16 ++--- 4 files changed, 39 insertions(+), 55 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py 
b/tests/distributed/test_ulysses_sequence_parallel.py index d641882e748..19b702bf018 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -3,10 +3,8 @@ import pytest import torch -import vllm.envs as envs from vllm.platforms import current_platform -from tests.utils import multi_gpu_test from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import ( DiffusionParallelConfig, @@ -114,7 +112,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "test_model_cls", [ @@ -130,7 +127,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("use_sync", [True, False]) @pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.parametrize("use_compile", [False, True]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_ulysses_attention( test_model_cls: type[torch.nn.Module], batch_size: int, diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index c4c0add3783..bb4d1bf93ae 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,20 +5,14 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
import pickle from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch import torch.distributed from torch.cuda import synchronize from torch.distributed import Backend, ProcessGroup -try: - import torch_musa - from torch_musa.core.device import synchronize -except ModuleNotFoundError: - pass - -from vllm_omni.diffusion.envs import envs +from vllm_omni.diffusion import envs if envs._is_npu(): print("torch.npu synchronize") @@ -34,8 +28,8 @@ def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + tensor_dict: dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced by its metadata. @@ -44,7 +38,7 @@ def _split_tensor_dict( If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its metadata will be "key1%key2". """ - metadata_list: List[Tuple[str, Any]] = [] + metadata_list: list[tuple[str, Any]] = [] tensor_list = [] for key, value in tensor_dict.items(): assert "%" not in key, "Avoid having '%' in key as it is used as a separator for nested entries." 
@@ -90,7 +84,7 @@ class GroupCoordinator: # available attributes: rank: int # global rank - ranks: List[int] # global ranks in the group + ranks: list[int] # global ranks in the group world_size: int # size of the group # difference between `local_rank` and `rank_in_group`: # if we have a group of size 4 across two nodes: @@ -106,7 +100,7 @@ class GroupCoordinator: def __init__( self, - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], ): @@ -209,7 +203,7 @@ def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceO def all_gather( self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: @@ -307,7 +301,7 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) return recv[0] - def broadcast_object_list(self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None): + def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ @@ -372,11 +366,11 @@ def recv_object(self, src: int) -> Any: def broadcast_tensor_dict( self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, metadata_group: Optional[ProcessGroup] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -391,7 +385,7 @@ def broadcast_tensor_dict( rank = self.rank if rank == src: - metadata_list: List[Tuple[Any, Any]] = [] + metadata_list: list[tuple[Any, Any]] = [] assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. @@ -440,9 +434,9 @@ def broadcast_tensor_dict( def send_tensor_dict( self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], + tensor_dict: dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ @@ -457,7 +451,7 @@ def send_tensor_dict( dst = self.group_next_rank assert dst < self.world_size, f"Invalid dst rank ({dst})" - metadata_list: List[Tuple[Any, Any]] = [] + metadata_list: list[tuple[Any, Any]] = [] assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. @@ -476,7 +470,7 @@ def send_tensor_dict( torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None - def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -492,7 +486,7 @@ def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[Dict[str, Unio assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) - tensor_dict: Dict[str, Any] = {} + tensor_dict: dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) @@ -559,7 +553,7 @@ class PipelineGroupCoordinator(GroupCoordinator): """ available attributes: rank: int # global rank - ranks: List[int] # global ranks in the group + ranks: list[int] # global ranks in the group world_size: int # size of the group difference between `local_rank` and `rank_in_group`: if we have a group of size 4 across two nodes: @@ -576,7 +570,7 @@ class PipelineGroupCoordinator(GroupCoordinator): def __init__( self, - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], ): @@ -627,19 +621,19 @@ def __init__( self.device = envs.get_device(local_rank) self.recv_buffer_set: bool = False - self.recv_tasks_queue: List[Tuple[str, int]] = [] - self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.recv_tasks_queue: list[tuple[str, int]] = [] + self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = [] self.dtype: Optional[torch.dtype] = None self.num_pipefusion_patches: Optional[int] = None - self.recv_shape: Dict[str, Dict[int, torch.Size]] = {} - self.send_shape: Dict[str, Dict[int, torch.Size]] = {} - self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {} + self.recv_shape: dict[str, dict[int, torch.Size]] = {} + self.send_shape: dict[str, dict[int, torch.Size]] = {} + self.recv_buffer: dict[str, dict[int, torch.Size]] = {} self.skip_tensor_recv_buffer_set: bool = False - self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = [] - self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] 
- self.skip_tensor_recv_buffer: Optional[Union[List[torch.Tensor], torch.Tensor]] = None + self.recv_skip_tasks_queue: list[Union[int, tuple[str, int]]] = [] + self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = [] + self.skip_tensor_recv_buffer: Optional[Union[list[torch.Tensor], torch.Tensor]] = None self.skip_device_group = None for ranks in group_ranks: skip_device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) @@ -664,8 +658,8 @@ def set_config(self, dtype: torch.dtype): def set_recv_buffer( self, num_pipefusion_patches: int, - patches_shape_list: List[List[int]], - feature_map_shape: List[int], + patches_shape_list: list[list[int]], + feature_map_shape: list[int], dtype: torch.dtype, ): assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object" @@ -681,7 +675,7 @@ def set_recv_buffer( def set_extra_tensors_recv_buffer( self, name: str, - shape: List[int], + shape: list[int], num_buffers: int = 1, dtype: torch.dtype = torch.float16, ): @@ -863,8 +857,8 @@ def _pipeline_isend(self, tensor: torch.tensor): def set_skip_tensor_recv_buffer( self, - patches_shape_list: List[List[int]], - feature_map_shape: List[int], + patches_shape_list: list[list[int]], + feature_map_shape: list[int], ): self.skip_tensor_recv_buffer = [ torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list @@ -918,7 +912,7 @@ def _pipeline_isend_skip(self, tensor: torch.tensor): class SequenceParallelGroupCoordinator(GroupCoordinator): def __init__( self, - group_ranks: List[List[int]], + group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], **kwargs, diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 8b911cdb2b6..bea7501ddbc 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -35,7 +35,7 @@ import 
torch.distributed from vllm.logger import init_logger -from vllm_omni.diffusion.envs import envs +from vllm_omni.diffusion import envs from .group_coordinator import ( GroupCoordinator, diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 5e64474404a..5a24d4e9f57 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -2,18 +2,12 @@ # Adapted from # https://github.com/xdit-project/xDiT/blob/main/xfuser/envs.py import os -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional import diffusers import torch from packaging import version - -try: - import torch_musa -except ModuleNotFoundError: - pass - -from xfuser.logger import init_logger +from vllm.logger import init_logger logger = init_logger(__name__) @@ -27,7 +21,7 @@ TORCH_VERSION: version.Version -environment_variables: Dict[str, Callable[[], Any]] = { +environment_variables: dict[str, Callable[[], Any]] = { # ================== Runtime Env Vars ================== # used in distributed environment to determine the master address "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""), @@ -130,7 +124,7 @@ def get_torch_distributed_backend() -> str: raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") -variables: Dict[str, Callable[[], Any]] = { +variables: dict[str, Callable[[], Any]] = { # ================== Other Vars ================== # used in version checking "CUDA_VERSION": lambda: version.parse(get_device_version() or "0.0"), @@ -195,7 +189,7 @@ def check_flash_attn(self, packages_info): # Check if torch_npu is available if _is_npu(): - logger.info("`falsh_attn` is not ready on torch_npu for now") + logger.info("`flash_attn` is not ready on torch_npu for now") return False if _is_musa(): From bd0a8510ab459e1a1598cd045dd9ddda2b3c5083 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:47:10 
+0800 Subject: [PATCH 11/85] update test script Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 42 ++- tests/utils.py | 13 + vllm_omni/utils/system_utils.py | 269 ++++++++++++++++++ 3 files changed, 309 insertions(+), 15 deletions(-) create mode 100644 vllm_omni/utils/system_utils.py diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index 19b702bf018..afdc2067bf1 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -12,6 +12,7 @@ set_current_vllm_config, ) from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel +from vllm_omni.utils.system_utils import update_environment_variables class TestAttentionModel(torch.nn.Module): @@ -22,7 +23,7 @@ def __init__( num_heads: int, head_size: int, hidden_size: int, - causal: bool = True, + causal: bool = False, num_kv_heads: int | None = None, scatter_idx: int = 2, gather_idx: int = 1, @@ -121,23 +122,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) @pytest.mark.parametrize("batch_size", [2, 4]) @pytest.mark.parametrize("seq_len", [16, 32]) -@pytest.mark.parametrize("num_heads", [4, 8]) -@pytest.mark.parametrize("head_size", [32, 64]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_sync", [True, False]) -@pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("use_compile", [False, True]) +# @pytest.mark.parametrize("num_heads", [4, 8]) +# @pytest.mark.parametrize("head_size", [32, 64]) +# @pytest.mark.parametrize("causal", [True, False]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("use_sync", [True, False]) +# @pytest.mark.parametrize("dynamic", [False, True]) +# @pytest.mark.parametrize("use_compile", [False, True]) 
def test_ulysses_attention( test_model_cls: type[torch.nn.Module], batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - causal: bool, - use_sync: bool, - dynamic: bool, - use_compile: bool, + seq_len: int = 16, + num_heads: int = 4, + head_size: int = 32, + dtype: torch.dtype = torch.float16, + causal: bool = False, + use_sync: bool = False, + dynamic: bool = False, + use_compile: bool = False, ): """Test Ulysses attention with various parameter combinations.""" num_processes = 2 @@ -196,6 +198,16 @@ def ulysses_attention_on_test_model( torch.cuda.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # Initialize distributed environment init_distributed_environment() diff --git a/tests/utils.py b/tests/utils.py index c4b1f7f1441..0e2dc3f30f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -332,3 +332,16 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: return func return wrapper + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) + os.environ[k] = v diff --git a/vllm_omni/utils/system_utils.py b/vllm_omni/utils/system_utils.py new file mode 100644 index 00000000000..94460885bc7 --- /dev/null +++ b/vllm_omni/utils/system_utils.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copied from +# https://github.com/vllm-project/vllm/blob/main/vllm/utils/system_utils.py +from __future__ import annotations + +import contextlib +import multiprocessing +import os 
+import signal +import sys +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import TextIO + +import psutil +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.ray.lazy_utils import is_in_ray_actor + +# from .platform_utils import cuda_is_initialized, xpu_is_initialized + +logger = init_logger(__name__) + +CYAN = "\033[0;36m" +RESET = "\033[0;0m" + + +# Environment variable utilities + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) + os.environ[k] = v + + +@contextlib.contextmanager +def set_env_var(key: str, value: str) -> Iterator[None]: + """Temporarily set an environment variable.""" + old = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if old is None: + os.environ.pop(key, None) + else: + os.environ[key] = old + + +@contextlib.contextmanager +def suppress_stdout(): + """ + Suppress stdout from C libraries at the file descriptor level. + + Only suppresses stdout, not stderr, to preserve error messages. + Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG. 
+ + Example: + with suppress_stdout(): + # C library calls that would normally print to stdout + torch.distributed.new_group(ranks, backend="gloo") + """ + # Don't suppress if logging level is DEBUG + if envs.VLLM_LOGGING_LEVEL == "DEBUG": + yield + return + + stdout_fd = sys.stdout.fileno() + stdout_dup = os.dup(stdout_fd) + devnull_fd = os.open(os.devnull, os.O_WRONLY) + + try: + sys.stdout.flush() + os.dup2(devnull_fd, stdout_fd) + yield + finally: + sys.stdout.flush() + os.dup2(stdout_dup, stdout_fd) + os.close(stdout_dup) + os.close(devnull_fd) + + +# File path utilities + + +def unique_filepath(fn: Callable[[int], Path]) -> Path: + """Generate a unique file path by trying incrementing integers. + + Note: This function has a TOCTOU race condition. + Caller should use atomic operations (e.g., open with 'x' mode) + when creating the file to ensure thread safety. + """ + i = 0 + while True: + p = fn(i) + if not p.exists(): + return p + i += 1 + + +# Process management utilities + + +def _maybe_force_spawn(): + """Check if we need to force the use of the `spawn` multiprocessing start + method. + """ + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": + return + + reasons = [] + if is_in_ray_actor(): + # even if we choose to spawn, we need to pass the ray address + # to the subprocess so that it knows how to connect to the ray cluster. + # env vars are inherited by subprocesses, even if we use spawn. + import ray + + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address + reasons.append("In a Ray actor and can only be spawned") + + # if cuda_is_initialized(): + # reasons.append("CUDA is initialized") + # elif xpu_is_initialized(): + # reasons.append("XPU is initialized") + + if reasons: + logger.warning( + "We must use the `spawn` multiprocessing start method. " + "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " + "See https://docs.vllm.ai/en/latest/usage/" + "troubleshooting.html#python-multiprocessing " + "for more information. 
Reasons: %s", + "; ".join(reasons), + ) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def get_mp_context(): + """Get a multiprocessing context with a particular method (spawn or fork). + By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to + determine the multiprocessing method (default is fork). However, under + certain conditions, we may enforce spawn and override the value of + VLLM_WORKER_MULTIPROC_METHOD. + """ + _maybe_force_spawn() + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + return multiprocessing.get_context(mp_method) + + +def set_process_title( + name: str, + suffix: str = "", + prefix: str = envs.VLLM_PROCESS_NAME_PREFIX, +) -> None: + """Set the current process title with optional suffix.""" + try: + import setproctitle + except ImportError: + return + + if suffix: + name = f"{name}_{suffix}" + + setproctitle.setproctitle(f"{prefix}::{name}") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Add colored prefix to file output for log decoration.""" + if envs.NO_COLOR: + prefix = f"({worker_name} pid={pid}) " + else: + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find("\n", idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] + + +def decorate_logs(process_name: str | None = None) -> None: + """Decorate stdout/stderr with process name and PID prefix.""" + # Respect VLLM_CONFIGURE_LOGGING environment variable + if not envs.VLLM_CONFIGURE_LOGGING: + return + + if 
process_name is None: + process_name = get_mp_context().current_process().name + + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + +def kill_process_tree(pid: int): + """ + Kills all descendant processes of the given pid by sending SIGKILL. + + Args: + pid (int): Process ID of the parent process + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Send SIGKILL to all children first + for child in children: + with contextlib.suppress(ProcessLookupError): + os.kill(child.pid, signal.SIGKILL) + + # Finally kill the parent + with contextlib.suppress(ProcessLookupError): + os.kill(pid, signal.SIGKILL) + + +# Resource utilities + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 +def set_ulimit(target_soft_limit: int = 65535): + if sys.platform.startswith("win"): + logger.info("Windows detected, skipping ulimit adjustment.") + return + + import resource + + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + logger.warning( + "Found ulimit of %s and failed to automatically increase " + "with error %s. This can cause fd limit errors like " + "`OSError: [Errno 24] Too many open files`. 
Consider " + "increasing with ulimit -n", + current_soft, + e, + ) From 1b2206131399ef5292cf8f74e2b3c1d291e68608 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:04:48 +0800 Subject: [PATCH 12/85] set ring and ulysses group Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 2 +- .../diffusion/distributed/parallel_state.py | 62 ++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index afdc2067bf1..813258a1f4a 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -121,7 +121,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ], ) @pytest.mark.parametrize("batch_size", [2, 4]) -@pytest.mark.parametrize("seq_len", [16, 32]) +# @pytest.mark.parametrize("seq_len", [16, 32]) # @pytest.mark.parametrize("num_heads", [4, 8]) # @pytest.mark.parametrize("head_size", [32, 64]) # @pytest.mark.parametrize("causal", [True, False]) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index bea7501ddbc..8225eeb0d24 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -33,6 +33,7 @@ import torch import torch.distributed +from torch.cuda import device_count, set_device from vllm.logger import init_logger from vllm_omni.diffusion import envs @@ -529,6 +530,58 @@ def init_vae_group( _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend) +# adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py +def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True): + """ + sp_ulysses_degree x sp_ring_degree = seq_parallel_degree + 
(ulysses_degree, dp_degree) + """ + sp_degree = sp_ring_degree * sp_ulysses_degree + dp_degree = world_size // sp_degree + + assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_ulysses_degree} == 0" + + num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree + num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree + + if use_ulysses_low: + for dp_rank in range(dp_degree): + offset = dp_rank * sp_degree + for i in range(num_ulysses_pgs): + ulysses_ranks = list( + range( + i * sp_ulysses_degree + offset, + (i + 1) * sp_ulysses_degree + offset, + ) + ) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + for i in range(num_ring_pgs): + ring_ranks = list(range(i + offset, sp_degree + offset, num_ring_pgs)) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + + else: + for dp_rank in range(dp_degree): + offset = dp_rank * sp_degree + for i in range(num_ring_pgs): + ring_ranks = list(range(i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset)) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + + for i in range(num_ulysses_pgs): + ulysses_ranks = list(range(i + offset, sp_degree + offset, num_ulysses_pgs)) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + return ulyssess_pg, ring_pg + + def initialize_model_parallel( data_parallel_degree: int = 1, classifier_free_guidance_degree: int = 1, @@ -659,12 +712,19 @@ def initialize_model_parallel( global _SP assert _SP is None, "sequence parallel group is already initialized" - + ulysses_pg, ring_pg = set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=dit_parallel_size, + ) _SP = init_model_parallel_group( group_ranks=rank_generator.get_ranks("sp"), local_rank=get_world_group().local_rank, 
backend=backend, parallel_mode="sequence", + ulysses_group=ulysses_pg, + ring_group=ring_pg, ) global _TP From 69e216b525e4f78ee622190b47fa6337c7e39cb0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:55:25 +0800 Subject: [PATCH 13/85] update pg Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 523443ed1ce..389f48c27e2 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -52,7 +52,8 @@ def __init__( self.scatter_idx = scatter_idx self.gather_idx = gather_idx self.use_sync = use_sync - self.sequence_process_group: Optional[dist.ProcessGroup] = None + self.ring_pg: Optional[dist.ProcessGroup] = None + self.ulysses_pg: Optional[dist.ProcessGroup] = None self.use_ulysses = False try: @@ -62,7 +63,8 @@ def __init__( # Get sequence parallel process group try: sp_group = get_sp_group() - self.sequence_process_group = sp_group.device_group + self.ring_pg = sp_group.ring_group + self.ulysses_pg = sp_group.ulysses_group assert get_sequence_parallel_world_size() > 1, "Sequence parallel world size must be > 1" except (AssertionError, RuntimeError): # If sequence parallel group is not initialized, disable Ulysses @@ -94,9 +96,9 @@ def _forward_ulysses( """Ulysses attention forward pass with sequence parallelism.""" # scatter 2, gather 1 # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) - q = SeqAllToAll4D.apply(self.sequence_process_group, query, self.scatter_idx, self.gather_idx, self.use_sync) - k = SeqAllToAll4D.apply(self.sequence_process_group, key, self.scatter_idx, self.gather_idx, self.use_sync) - v = SeqAllToAll4D.apply(self.sequence_process_group, value, self.scatter_idx, self.gather_idx, self.use_sync) + q 
= SeqAllToAll4D.apply(self.ulysses_pg, query, self.scatter_idx, self.gather_idx, self.use_sync) + k = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx, self.use_sync) + v = SeqAllToAll4D.apply(self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync) softmax_scale = self.softmax_scale if softmax_scale is None: @@ -127,8 +129,6 @@ def _forward_ulysses( # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) # scatter 1, gather 2 - output = SeqAllToAll4D.apply( - self.sequence_process_group, context_layer, self.gather_idx, self.scatter_idx, self.use_sync - ) + output = SeqAllToAll4D.apply(self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync) return output From 2b3b89797c2c56d0b2b386ed7a3a5d088fc32ad5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:00:16 +0800 Subject: [PATCH 14/85] fix test shape Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index 813258a1f4a..691f54befdc 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -120,26 +120,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: TestMultiLayerAttentionModel, ], ) -@pytest.mark.parametrize("batch_size", [2, 4]) -# @pytest.mark.parametrize("seq_len", [16, 32]) -# @pytest.mark.parametrize("num_heads", [4, 8]) -# @pytest.mark.parametrize("head_size", [32, 64]) -# @pytest.mark.parametrize("causal", [True, False]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("use_sync", [True, False]) -# @pytest.mark.parametrize("dynamic", [False, True]) -# 
@pytest.mark.parametrize("use_compile", [False, True]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_sync", [True, False]) +@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("use_compile", [False, True]) def test_ulysses_attention( test_model_cls: type[torch.nn.Module], - batch_size: int, + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + batch_size: int = 2, seq_len: int = 16, - num_heads: int = 4, + num_heads: int = 8, head_size: int = 32, - dtype: torch.dtype = torch.float16, - causal: bool = False, - use_sync: bool = False, - dynamic: bool = False, - use_compile: bool = False, ): """Test Ulysses attention with various parameter combinations.""" num_processes = 2 From 4d3f970616260564c9183f7584960d91542d0d0a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:05:55 +0800 Subject: [PATCH 15/85] destroy comm group Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/distributed/test_ulysses_sequence_parallel.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index 691f54befdc..cbade46771f 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -11,7 +11,11 @@ OmniDiffusionConfig, set_current_vllm_config, ) -from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) from vllm_omni.utils.system_utils import update_environment_variables @@ -304,3 +308,4 @@ def ulysses_attention_on_test_model( f"dtype={dtype}, 
causal={causal}, use_sync={use_sync}, " f"dynamic={dynamic}, use_compile={use_compile}" ) + destroy_distributed_env() From 28502a2c9ab892aff9ef635990d0ff831b94c4b2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:13:30 +0800 Subject: [PATCH 16/85] remove redundant arg Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 8225eeb0d24..2d76be56997 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -783,7 +783,7 @@ def destroy_distributed_environment(): torch.distributed.destroy_process_group() -def destroy_distributed_env(self): +def destroy_distributed_env(): if model_parallel_is_initialized(): destroy_model_parallel() destroy_distributed_environment() From c54be3b86a69b6cc4779a4af5a0bf0ad49b4eb05 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:14:42 +0800 Subject: [PATCH 17/85] new test parameter Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/distributed/test_ulysses_sequence_parallel.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index cbade46771f..f58432fe541 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -124,6 +124,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: TestMultiLayerAttentionModel, ], ) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", 
[32]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_sync", [True, False]) @@ -136,10 +140,10 @@ def test_ulysses_attention( use_sync: bool, dynamic: bool, use_compile: bool, - batch_size: int = 2, - seq_len: int = 16, - num_heads: int = 8, - head_size: int = 32, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, ): """Test Ulysses attention with various parameter combinations.""" num_processes = 2 From f6c28c9004e65c89d18afdfb9d5f9f3f1150930c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:30:26 +0800 Subject: [PATCH 18/85] update test sp Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image_sp.py | 121 ++++++++++++++++++ .../text_to_image/text_to_image_usp.py | 20 ++- 2 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 examples/offline_inference/text_to_image/text_to_image_sp.py diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py new file mode 100644 index 00000000000..ccaece46b4d --- /dev/null +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import time +from pathlib import Path + +import torch + +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_vllm_config +from vllm_omni.diffusion.distributed.parallel_state import destroy_distributed_env, get_world_group +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.utils.platform_utils import detect_device_type, is_npu + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") + 
parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.") + parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") + parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") + parser.add_argument( + "--cfg_scale", + type=float, + default=4.0, + help="True classifier-free guidance scale specific to Qwen-Image.", + ) + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") + parser.add_argument( + "--output", + type=str, + default="qwen_image_output.png", + help="Path to save the generated image (PNG).", + ) + parser.add_argument( + "--num_images_per_prompt", + type=int, + default=1, + help="Number of images to generate for the given prompt.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=2, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + device = detect_device_type() + generator = torch.Generator(device=device).manual_seed(args.seed) + local_rank = get_world_group().local_rank + + # Enable VAE memory optimizations on NPU + vae_use_slicing = is_npu() + vae_use_tiling = is_npu() + sequence_parallel_size = args.ulysses_degree * args.ring_degree + omni_diffusion_config = OmniDiffusionConfig( + parallel_config=DiffusionParallelConfig( + ulysses_degree=args.ulysses_degree, sequence_parallel_size=sequence_parallel_size + ) + ) + with set_current_vllm_config(omni_diffusion_config): + omni = Omni( + model=args.model, + 
od_config=omni_diffusion_config, + vae_use_slicing=vae_use_slicing, + vae_use_tiling=vae_use_tiling, + ) + parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + images = omni.generate( + args.prompt, + height=args.height, + width=args.width, + generator=generator, + true_cfg_scale=args.cfg_scale, + num_inference_steps=args.num_inference_steps, + num_images_per_prompt=args.num_images_per_prompt, + num_outputs_per_prompt=args.num_images_per_prompt, + ) + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "qwen_image_output" + if args.num_images_per_prompt <= 1: + images[0].save(output_path) + print(f"Saved generated image to {output_path}") + else: + for idx, img in enumerate(images): + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved generated image to {save_path}") + if get_world_group().rank == get_world_group().world_size - 1: + print( + f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory / 1e9:.2f} GB, memory: {peak_memory / 1e9:.2f} GB" + ) + destroy_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/text_to_image/text_to_image_usp.py b/examples/offline_inference/text_to_image/text_to_image_usp.py index fa9be5ca17e..ccaece46b4d 100644 --- a/examples/offline_inference/text_to_image/text_to_image_usp.py +++ b/examples/offline_inference/text_to_image/text_to_image_usp.py @@ -44,6 +44,18 @@ def parse_args() -> argparse.Namespace: default=50, help="Number of denoising steps for the diffusion sampler.", ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=2, + help="Number of GPUs used for 
ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) return parser.parse_args() @@ -52,14 +64,15 @@ def main(): device = detect_device_type() generator = torch.Generator(device=device).manual_seed(args.seed) local_rank = get_world_group().local_rank - parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - + sequence_parallel_size = args.ulysses_degree * args.ring_degree omni_diffusion_config = OmniDiffusionConfig( - parallel_config=DiffusionParallelConfig(sequence_parallel_size=2, ulysses_degree=2) + parallel_config=DiffusionParallelConfig( + ulysses_degree=args.ulysses_degree, sequence_parallel_size=sequence_parallel_size + ) ) with set_current_vllm_config(omni_diffusion_config): omni = Omni( @@ -68,6 +81,7 @@ def main(): vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, ) + parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") torch.cuda.reset_peak_memory_stats() start_time = time.time() images = omni.generate( From 1d769255078d4381aac16e25749f37c2f334c1ee Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:38:48 +0800 Subject: [PATCH 19/85] allow sp is None Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../offline_inference/text_to_image/text_to_image_sp.py | 6 ++---- vllm_omni/diffusion/data.py | 4 +++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index ccaece46b4d..fa5129eaa59 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -68,11 +68,9 @@ def main(): # Enable VAE 
memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - sequence_parallel_size = args.ulysses_degree * args.ring_degree + omni_diffusion_config = OmniDiffusionConfig( - parallel_config=DiffusionParallelConfig( - ulysses_degree=args.ulysses_degree, sequence_parallel_size=sequence_parallel_size - ) + parallel_config=DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) ) with set_current_vllm_config(omni_diffusion_config): omni = Omni( diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 2f85e1c0586..49a1b9d5622 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -37,7 +37,7 @@ class DiffusionParallelConfig: tensor_parallel_size: int = 1 """Number of tensor parallel groups.""" - sequence_parallel_size: int = 1 + sequence_parallel_size: int | None = None """Number of sequence parallel groups. sequence_parallel_size = ring_degree * ulysses_degree""" ulysses_degree: int = 1 @@ -55,6 +55,8 @@ def _validate_parallel_config(self) -> Self: assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0" assert self.data_parallel_size > 0, "Data parallel size must be > 0" assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0" + if self.sequence_parallel_size is None: + self.sequence_parallel_size = self.ulysses_degree * self.ring_degree assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0" assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must be > 0" From 19fa17460efb3a466b001a208dd6db6095b39172 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:50:44 +0800 Subject: [PATCH 20/85] default config Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 
49a1b9d5622..a808e0c61f4 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -256,6 +256,8 @@ class OmniDiffusionConfig: trust_remote_code: bool = False revision: str | None = None + num_gpus: int | None = None + hsdp_replicate_dim: int = 1 hsdp_shard_dim: int = -1 dist_timeout: int | None = None # timeout for torch.distributed @@ -377,6 +379,12 @@ def __post_init__(self): # TODO: remove hard code initial_master_port = (self.master_port or 30005) + random.randint(0, 100) self.master_port = self.settle_port(initial_master_port, 37) + if self.num_gpus is None: + self.num_gpus = 1 + if self.num_gpus < self.parallel_config.world_size: + raise ValueError( + f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})" + ) # Convert cache_config dict to DiffusionCacheConfig if needed if isinstance(self.cache_config, dict): From 61e48fe33e0b9d3ba6ce080d7d72c25bc4ebf1a5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:47:35 +0800 Subject: [PATCH 21/85] revert utils changes Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/utils.py | 244 +------------------------------------------------ 1 file changed, 2 insertions(+), 242 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 0e2dc3f30f0..aba734501eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,23 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -import functools + import os -import signal -import subprocess -import sys -import tempfile import time -from collections.abc import Callable -from contextlib import ExitStack, contextmanager, suppress -from pathlib import Path -from typing import Any, Literal +from contextlib import contextmanager -import cloudpickle -import pytest -from typing_extensions import ParamSpec from vllm.platforms import current_platform -from 
vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( @@ -117,231 +105,3 @@ def wait_for_gpu_memory_to_clear( raise ValueError(f"Memory of devices {devices=} not free after {dur_s=:.02f} ({threshold=})") time.sleep(5) - - -VLLM_PATH = Path(__file__).parent.parent -"""Path to root of the vLLM repository.""" - -_P = ParamSpec("_P") - - -def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to fork a new process for each test function. - See https://github.com/vllm-project/vllm/issues/7053 for more details. - """ - - @functools.wraps(func) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Make the process the leader of its own process group - # to avoid sending SIGTERM to the parent process - os.setpgrp() - from _pytest.outcomes import Skipped - - # Create a unique temporary file to store exception info from child - # process. Use test function name and process ID to avoid collisions. - with ( - tempfile.NamedTemporaryFile( - delete=False, - mode="w+b", - prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", - suffix=".exc", - ) as exc_file, - ExitStack() as delete_after, - ): - exc_file_path = exc_file.name - delete_after.callback(os.remove, exc_file_path) - - pid = os.fork() - print(f"Fork a new process to run a test {pid}") - if pid == 0: - # Parent process responsible for deleting, don't delete - # in child. - delete_after.pop_all() - try: - func(*args, **kwargs) - except Skipped as e: - # convert Skipped to exit code 0 - print(str(e)) - os._exit(0) - except Exception as e: - import traceback - - tb_string = traceback.format_exc() - - # Try to serialize the exception object first - exc_to_serialize: dict[str, Any] - try: - # First, try to pickle the actual exception with - # its traceback. 
- exc_to_serialize = {"pickled_exception": e} - # Test if it can be pickled - cloudpickle.dumps(exc_to_serialize) - except (Exception, KeyboardInterrupt): - # Fall back to string-based approach. - exc_to_serialize = { - "exception_type": type(e).__name__, - "exception_msg": str(e), - "traceback": tb_string, - } - try: - with open(exc_file_path, "wb") as f: - cloudpickle.dump(exc_to_serialize, f) - except Exception: - # Fallback: just print the traceback. - print(tb_string) - os._exit(1) - else: - os._exit(0) - else: - pgid = os.getpgid(pid) - _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes - os.killpg(pgid, signal.SIGTERM) - # restore the signal handler - signal.signal(signal.SIGTERM, old_signal_handler) - if _exitcode != 0: - # Try to read the exception from the child process - exc_info = {} - if os.path.exists(exc_file_path): - with ( - contextlib.suppress(Exception), - open(exc_file_path, "rb") as f, - ): - exc_info = cloudpickle.load(f) - - if (original_exception := exc_info.get("pickled_exception")) is not None: - # Re-raise the actual exception object if it was - # successfully pickled. 
- assert isinstance(original_exception, Exception) - raise original_exception - - if (original_tb := exc_info.get("traceback")) is not None: - # Use string-based traceback for fallback case - raise AssertionError( - f"Test {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" - f" (exit code: {_exitcode}):\n{original_tb}" - ) from None - - # Fallback to the original generic error - raise AssertionError( - f"function {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" - f" (exit code: {_exitcode})" - ) from None - - return wrapper - - -def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to spawn a new process for each test function.""" - - @functools.wraps(f) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Check if we're already in a subprocess - if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": - # If we are, just run the function directly - return f(*args, **kwargs) - - import torch.multiprocessing as mp - - with suppress(RuntimeError): - mp.set_start_method("spawn") - - # Get the module - module_name = f.__module__ - - # Create a process with environment variable set - env = os.environ.copy() - env["RUNNING_IN_SUBPROCESS"] = "1" - - with tempfile.TemporaryDirectory() as tempdir: - output_filepath = os.path.join(tempdir, "new_process.tmp") - - # `cloudpickle` allows pickling complex functions directly - input_bytes = cloudpickle.dumps((f, output_filepath)) - - repo_root = str(VLLM_PATH.resolve()) - - env = dict(env or os.environ) - env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") - - cmd = [sys.executable, "-m", f"{module_name}"] - - returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) - - # check if the subprocess is successful - try: - returned.check_returncode() - except Exception as e: - # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in 
subprocess:\n{returned.stderr.decode()}") from e - - return wrapper - - -def create_new_process_for_each_test( - method: Literal["spawn", "fork"] | None = None, -) -> Callable[[Callable[_P, None]], Callable[_P, None]]: - """Creates a decorator that runs each test function in a new process. - - Args: - method: The process creation method. Can be either "spawn" or "fork". - If not specified, it defaults to "spawn" on ROCm and XPU - platforms and "fork" otherwise. - - Returns: - A decorator to run test functions in separate processes. - """ - if method is None: - use_spawn = current_platform.is_rocm() or current_platform.is_xpu() - method = "spawn" if use_spawn else "fork" - - assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" - - if method == "fork": - return fork_new_process_for_each_test - - return spawn_new_process_for_each_test - - -def multi_gpu_marks(*, num_gpus: int): - """Get a collection of pytest marks to apply for `@multi_gpu_test`.""" - test_selector = pytest.mark.distributed(num_gpus=num_gpus) - test_skipif = pytest.mark.skipif( - cuda_device_count_stateless() < num_gpus, - reason=f"Need at least {num_gpus} GPUs to run the test.", - ) - - return [test_selector, test_skipif] - - -def multi_gpu_test(*, num_gpus: int): - """ - Decorate a test to be run only when multiple GPUs are available. 
- """ - marks = multi_gpu_marks(num_gpus=num_gpus) - - def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - func = create_new_process_for_each_test()(f) - for mark in reversed(marks): - func = mark(func) - - return func - - return wrapper - - -def update_environment_variables(envs_dict: dict[str, str]): - """Update multiple environment variables with logging.""" - for k, v in envs_dict.items(): - if k in os.environ and os.environ[k] != v: - logger.warning( - "Overwriting environment variable %s from '%s' to '%s'", - k, - os.environ[k], - v, - ) - os.environ[k] = v From 085ba7f18ace6ea23bd0c37fcdb077ebbee3db3b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:49:26 +0800 Subject: [PATCH 22/85] update env func Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/distributed/test_ulysses_sequence_parallel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index f58432fe541..a97f7734065 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import pytest import torch @@ -19,6 +20,12 @@ from vllm_omni.utils.system_utils import update_environment_variables +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + class TestAttentionModel(torch.nn.Module): """Test model using Attention layer.""" From 8f7a8c21a5e5f14d685c3c6a7dbc6336673a5730 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:52:08 +0800 Subject: [PATCH 23/85] remove redundant Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 1 - vllm_omni/utils/system_utils.py | 269 ------------------ 2 files changed, 270 deletions(-) delete mode 100644 vllm_omni/utils/system_utils.py diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index a97f7734065..d8f2d49a15e 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -17,7 +17,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm_omni.utils.system_utils import update_environment_variables def update_environment_variables(envs_dict: dict[str, str]): diff --git a/vllm_omni/utils/system_utils.py b/vllm_omni/utils/system_utils.py deleted file mode 100644 index 94460885bc7..00000000000 --- a/vllm_omni/utils/system_utils.py +++ /dev/null @@ -1,269 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Copied from -# https://github.com/vllm-project/vllm/blob/main/vllm/utils/system_utils.py -from __future__ import annotations - -import contextlib -import multiprocessing -import os -import signal -import sys -from collections.abc import Callable, Iterator -from pathlib import Path -from typing import TextIO - -import psutil -import vllm.envs as envs -from vllm.logger import init_logger -from vllm.ray.lazy_utils import is_in_ray_actor - -# from .platform_utils import cuda_is_initialized, xpu_is_initialized - -logger = init_logger(__name__) - -CYAN = "\033[0;36m" -RESET = "\033[0;0m" - - -# Environment variable utilities - - -def update_environment_variables(envs_dict: dict[str, str]): - """Update multiple environment variables with logging.""" - for k, v in envs_dict.items(): - if k in os.environ and os.environ[k] != v: - logger.warning( - "Overwriting environment variable %s from '%s' to '%s'", - k, - os.environ[k], - v, - ) - os.environ[k] = v - - 
-@contextlib.contextmanager -def set_env_var(key: str, value: str) -> Iterator[None]: - """Temporarily set an environment variable.""" - old = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - if old is None: - os.environ.pop(key, None) - else: - os.environ[key] = old - - -@contextlib.contextmanager -def suppress_stdout(): - """ - Suppress stdout from C libraries at the file descriptor level. - - Only suppresses stdout, not stderr, to preserve error messages. - Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG. - - Example: - with suppress_stdout(): - # C library calls that would normally print to stdout - torch.distributed.new_group(ranks, backend="gloo") - """ - # Don't suppress if logging level is DEBUG - if envs.VLLM_LOGGING_LEVEL == "DEBUG": - yield - return - - stdout_fd = sys.stdout.fileno() - stdout_dup = os.dup(stdout_fd) - devnull_fd = os.open(os.devnull, os.O_WRONLY) - - try: - sys.stdout.flush() - os.dup2(devnull_fd, stdout_fd) - yield - finally: - sys.stdout.flush() - os.dup2(stdout_dup, stdout_fd) - os.close(stdout_dup) - os.close(devnull_fd) - - -# File path utilities - - -def unique_filepath(fn: Callable[[int], Path]) -> Path: - """Generate a unique file path by trying incrementing integers. - - Note: This function has a TOCTOU race condition. - Caller should use atomic operations (e.g., open with 'x' mode) - when creating the file to ensure thread safety. - """ - i = 0 - while True: - p = fn(i) - if not p.exists(): - return p - i += 1 - - -# Process management utilities - - -def _maybe_force_spawn(): - """Check if we need to force the use of the `spawn` multiprocessing start - method. - """ - if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": - return - - reasons = [] - if is_in_ray_actor(): - # even if we choose to spawn, we need to pass the ray address - # to the subprocess so that it knows how to connect to the ray cluster. - # env vars are inherited by subprocesses, even if we use spawn. 
- import ray - - os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address - reasons.append("In a Ray actor and can only be spawned") - - # if cuda_is_initialized(): - # reasons.append("CUDA is initialized") - # elif xpu_is_initialized(): - # reasons.append("XPU is initialized") - - if reasons: - logger.warning( - "We must use the `spawn` multiprocessing start method. " - "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " - "See https://docs.vllm.ai/en/latest/usage/" - "troubleshooting.html#python-multiprocessing " - "for more information. Reasons: %s", - "; ".join(reasons), - ) - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - -def get_mp_context(): - """Get a multiprocessing context with a particular method (spawn or fork). - By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to - determine the multiprocessing method (default is fork). However, under - certain conditions, we may enforce spawn and override the value of - VLLM_WORKER_MULTIPROC_METHOD. 
- """ - _maybe_force_spawn() - mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD - return multiprocessing.get_context(mp_method) - - -def set_process_title( - name: str, - suffix: str = "", - prefix: str = envs.VLLM_PROCESS_NAME_PREFIX, -) -> None: - """Set the current process title with optional suffix.""" - try: - import setproctitle - except ImportError: - return - - if suffix: - name = f"{name}_{suffix}" - - setproctitle.setproctitle(f"{prefix}::{name}") - - -def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: - """Add colored prefix to file output for log decoration.""" - if envs.NO_COLOR: - prefix = f"({worker_name} pid={pid}) " - else: - prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " - file_write = file.write - - def write_with_prefix(s: str): - if not s: - return - if file.start_new_line: # type: ignore[attr-defined] - file_write(prefix) - idx = 0 - while (next_idx := s.find("\n", idx)) != -1: - next_idx += 1 - file_write(s[idx:next_idx]) - if next_idx == len(s): - file.start_new_line = True # type: ignore[attr-defined] - return - file_write(prefix) - idx = next_idx - file_write(s[idx:]) - file.start_new_line = False # type: ignore[attr-defined] - - file.start_new_line = True # type: ignore[attr-defined] - file.write = write_with_prefix # type: ignore[method-assign] - - -def decorate_logs(process_name: str | None = None) -> None: - """Decorate stdout/stderr with process name and PID prefix.""" - # Respect VLLM_CONFIGURE_LOGGING environment variable - if not envs.VLLM_CONFIGURE_LOGGING: - return - - if process_name is None: - process_name = get_mp_context().current_process().name - - pid = os.getpid() - _add_prefix(sys.stdout, process_name, pid) - _add_prefix(sys.stderr, process_name, pid) - - -def kill_process_tree(pid: int): - """ - Kills all descendant processes of the given pid by sending SIGKILL. 
- - Args: - pid (int): Process ID of the parent process - """ - try: - parent = psutil.Process(pid) - except psutil.NoSuchProcess: - return - - # Get all children recursively - children = parent.children(recursive=True) - - # Send SIGKILL to all children first - for child in children: - with contextlib.suppress(ProcessLookupError): - os.kill(child.pid, signal.SIGKILL) - - # Finally kill the parent - with contextlib.suppress(ProcessLookupError): - os.kill(pid, signal.SIGKILL) - - -# Resource utilities - - -# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 -def set_ulimit(target_soft_limit: int = 65535): - if sys.platform.startswith("win"): - logger.info("Windows detected, skipping ulimit adjustment.") - return - - import resource - - resource_type = resource.RLIMIT_NOFILE - current_soft, current_hard = resource.getrlimit(resource_type) - - if current_soft < target_soft_limit: - try: - resource.setrlimit(resource_type, (target_soft_limit, current_hard)) - except ValueError as e: - logger.warning( - "Found ulimit of %s and failed to automatically increase " - "with error %s. This can cause fd limit errors like " - "`OSError: [Errno 24] Too many open files`. 
Consider " - "increasing with ulimit -n", - current_soft, - e, - ) From bd576f0dd49001c1af877b0f33b8e5239c8d7636 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 10:14:18 +0800 Subject: [PATCH 24/85] fix num_gpus default value Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a808e0c61f4..60b5589eece 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -380,7 +380,11 @@ def __post_init__(self): initial_master_port = (self.master_port or 30005) + random.randint(0, 100) self.master_port = self.settle_port(initial_master_port, 37) if self.num_gpus is None: - self.num_gpus = 1 + if self.parallel_config is not None: + self.num_gpus = self.parallel_config.world_size + else: + self.num_gpus = 1 + if self.num_gpus < self.parallel_config.world_size: raise ValueError( f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})" From aa9826efb13c4b7e470a3926e5db9249b61df407 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 10:26:30 +0800 Subject: [PATCH 25/85] rm redundant package check Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/comm.py | 6 +++--- vllm_omni/diffusion/envs.py | 25 ------------------------- 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index 271e22ad24e..ad847134ece 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team & Jiarui Fang # from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/all_to_all.py -from 
typing import Any, Tuple +from typing import Any import torch import torch.distributed as dist @@ -112,7 +112,7 @@ def forward( return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: return ( None, SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), @@ -228,7 +228,7 @@ def forward( return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: return ( None, SeqAllToAll5D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 5a24d4e9f57..8feb63570fa 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -4,7 +4,6 @@ import os from typing import TYPE_CHECKING, Any, Callable, Optional -import diffusers import torch from packaging import version from vllm.logger import init_logger @@ -164,25 +163,9 @@ def __new__(cls): def initialize(self): packages_info = {} - packages_info["has_aiter"] = self.check_aiter() packages_info["has_flash_attn"] = self.check_flash_attn(packages_info) - packages_info["diffusers_version"] = self.check_diffusers_version() self.packages_info = packages_info - def check_aiter(self): - """ - Checks whether ROCm AITER library is installed - """ - try: - logger.info("Using AITER as the attention library") - return True - except: - if _is_hip(): - logger.warning( - 'Using AMD GPUs, but library "aiter" is not installed, defaulting to other attention mechanisms' - ) - return False - def check_flash_attn(self, packages_info): if not torch.cuda.is_available(): return False @@ -211,14 
+194,6 @@ def check_flash_attn(self, packages_info): logger.warning('Flash Attention library "flash_attn" not found, using pytorch attention implementation') return False - def check_diffusers_version(self): - if version.parse(version.parse(diffusers.__version__).base_version) < version.parse("0.30.0"): - raise RuntimeError( - f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," - f"please upgrade to version > 0.30.0" - ) - return version.parse(version.parse(diffusers.__version__).base_version) - def get_packages_info(self): return self.packages_info From b21646071d892761193ce72a40fbd54db6997aa4 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:15:08 +0800 Subject: [PATCH 26/85] correct e2e Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image_sp.py | 30 ++++++++++++------- vllm_omni/diffusion/data.py | 10 ++++--- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index fa5129eaa59..7972a1c41c7 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -7,8 +7,9 @@ import torch -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_vllm_config +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_omni_diffusion_config from vllm_omni.diffusion.distributed.parallel_state import destroy_distributed_env, get_world_group +from vllm_omni.diffusion.envs import get_device_name from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -63,24 +64,31 @@ def main(): args = parse_args() device = detect_device_type() generator = 
torch.Generator(device=device).manual_seed(args.seed) - local_rank = get_world_group().local_rank - # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() + device_name = get_device_name() + try: + torch_device = getattr(torch, device_name) + except AttributeError: + raise ValueError(f"Device name {device_name} is not supported") + config_kwargs = { + "model": args.model, + "vae_use_slicing": vae_use_slicing, + "vae_use_tiling": vae_use_tiling, + } omni_diffusion_config = OmniDiffusionConfig( - parallel_config=DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) + **config_kwargs, parallel_config=DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) ) - with set_current_vllm_config(omni_diffusion_config): + with set_current_omni_diffusion_config(omni_diffusion_config): omni = Omni( - model=args.model, + **config_kwargs, od_config=omni_diffusion_config, - vae_use_slicing=vae_use_slicing, - vae_use_tiling=vae_use_tiling, ) - parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - torch.cuda.reset_peak_memory_stats() + local_rank = get_world_group().local_rank + parameter_peak_memory = torch_device.max_memory_allocated(device=f"{device_name}:{local_rank}") + torch_device.reset_peak_memory_stats() start_time = time.time() images = omni.generate( args.prompt, @@ -94,7 +102,7 @@ def main(): ) end_time = time.time() elapsed_time = end_time - start_time - peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + peak_memory = torch_device.max_memory_allocated(device=f"{device_name}:{local_rank}") output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 60b5589eece..1fe96169f71 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -412,13 +412,15 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": @contextmanager -def 
set_current_vllm_config(omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None): +def set_current_omni_diffusion_config( + omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None +): """ - Temporarily set the current vLLM config. + Temporarily set the current vLLM-Omni config. Used during model initialization. - We save the current vLLM config in a global variable, + We save the current vLLM-Omni config in a global variable, so that all modules can access it, e.g. custom ops - can access the vLLM config to determine how to dispatch. + can access the vLLM-Omni config to determine how to dispatch. """ global _current_omni_diffusion_config, _current_prefix old_omni_diffusion_config = _current_omni_diffusion_config From d3856a11a7260b97587a6831ab0522888695e4e5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:48:59 +0800 Subject: [PATCH 27/85] replace by vllm groupcoordinator Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/group_coordinator.py | 482 +----------------- 1 file changed, 2 insertions(+), 480 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index bb4d1bf93ae..0f26b220e44 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -3,14 +3,13 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
-import pickle from collections import namedtuple from typing import Any, Optional, Union import torch import torch.distributed from torch.cuda import synchronize -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend from vllm_omni.diffusion import envs @@ -18,6 +17,7 @@ print("torch.npu synchronize") from torch.npu import synchronize +from vllm.distributed.parallel_state import GroupCoordinator from vllm.logger import init_logger logger = init_logger(__name__) @@ -71,484 +71,6 @@ def _update_nested_dict(nested_dict, flattened_key, value): cur_dict[key_splits[-1]] = value -class GroupCoordinator: - """ - PyTorch ProcessGroup wrapper for a group of processes. - PyTorch ProcessGroup is bound to one specific communication backend, - e.g. NCCL, Gloo, MPI, etc. - GroupCoordinator takes charge of all the communication operations among - the processes in the group. It can route the communication to - a specific implementation (e.g. switch allreduce implementation - based on the tensor size and cuda graph mode). 
- """ - - # available attributes: - rank: int # global rank - ranks: list[int] # global ranks in the group - world_size: int # size of the group - # difference between `local_rank` and `rank_in_group`: - # if we have a group of size 4 across two nodes: - # Process | Node | Rank | Local Rank | Rank in Group - # 0 | 0 | 0 | 0 | 0 - # 1 | 0 | 1 | 1 | 1 - # 2 | 1 | 2 | 0 | 2 - # 3 | 1 | 3 | 1 | 3 - local_rank: int # local rank used to assign devices - rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication - - def __init__( - self, - group_ranks: list[list[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - ): - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. 
- cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group - - assert self.cpu_group is not None - assert self.device_group is not None - - self.device = envs.get_device(local_rank) - - @property - def first_rank(self): - """Return the global rank of the first process in the group""" - return self.ranks[0] - - @property - def last_rank(self): - """Return the global rank of the last process in the group""" - return self.ranks[-1] - - @property - def is_first_rank(self): - """Return whether the caller is the first process in the group""" - return self.rank == self.first_rank - - @property - def is_last_rank(self): - """Return whether the caller is the last process in the group""" - return self.rank == self.last_rank - - @property - def next_rank(self): - """Return the global rank of the process that follows the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group + 1) % world_size] - - @property - def prev_rank(self): - """Return the global rank of the process that precedes the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group - 1) % world_size] - - @property - def group_next_rank(self): - """Return the group rank of the process that follows the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return (rank_in_group + 1) % world_size - - @property - def group_prev_rank(self): - """Return the group rank of the process that precedes the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return (rank_in_group - 1) % world_size - - @property - def skip_rank(self): - """Return the global rank of the process that skip connects with the caller""" - rank_in_group = self.rank_in_group - world_size = 
self.world_size - return self.ranks[(world_size - rank_in_group - 1) % world_size] - - @property - def group_skip_rank(self): - """Return the group rank of the process that skip connects with the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return (world_size - rank_in_group - 1) % world_size - - def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: - """ - NOTE: This operation will be applied in-place or out-of-place. - Always assume this function modifies its input, but use the return - value as the output. - """ - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - else: - torch.distributed.all_reduce(input_, op=op, group=self.device_group) - return input_ - - def all_gather( - self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False - ) -> Union[torch.Tensor, list[torch.Tensor]]: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - input_size = list(input_.size()) - input_size[0] *= world_size - output_tensor = torch.empty(input_size, dtype=input_.dtype, device=input_.device) - # All-gather. 
- torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group) - if dim != 0: - input_size[0] //= world_size - output_tensor = output_tensor.reshape( - [ - world_size, - ] - + input_size - ) - output_tensor = output_tensor.movedim(0, dim) - - if separate_tensors: - tensor_list = [ - output_tensor.view(-1).narrow(0, input_.numel() * i, input_.numel()).view_as(input_) - for i in range(world_size) - ] - return tensor_list - else: - input_size = list(input_.size()) - input_size[dim] = input_size[dim] * world_size - # Reshape - output_tensor = output_tensor.reshape(input_size) - return output_tensor - - def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: - """ - NOTE: We assume that the input tensor is on the same device across - all the ranks. - NOTE: `dst` is the local rank of the destination rank. - """ - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - if self.rank_in_group == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor - - def broadcast(self, input_: torch.Tensor, src: int = 0): - """Broadcast the input tensor. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # Broadcast. 
- torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) - return input_ - - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): - """Broadcast the input object. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj - if self.shm_broadcaster is not None: - assert src == 0, "Shared memory broadcaster only supports src=0" - return self.shm_broadcaster.broadcast_object(obj) - if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group) - return obj - else: - recv = [None] - torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) - return recv[0] - - def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None): - """Broadcast the input object list. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj_list - # Broadcast. - torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group) - return obj_list - - def send_object(self, obj: Any, dst: int) -> None: - """Send the input object list to the destination rank.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank." 
- - # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - - size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu") - - # Send object size - - torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) - - # Send object - torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) - - return None - - def recv_object(self, src: int) -> Any: - """Receive the input object list from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - - assert src < self.world_size, f"Invalid src rank ({src})" - - assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank." - - size_tensor = torch.empty(1, dtype=torch.long, device="cpu") - - # Receive object size - rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group) - - # Tensor to receive serialized objects into. - object_tensor = torch.empty( # type: ignore[call-overload] - size_tensor.item(), # type: ignore[arg-type] - dtype=torch.uint8, - device="cpu", - ) - - rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group) - - assert rank_object == rank_size, "Received object sender rank does not match the size sender rank." - - obj = pickle.loads(object_tensor.numpy().tobytes()) - - return obj - - def broadcast_tensor_dict( - self, - tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. 
- if not torch.distributed.is_initialized() or self.world_size == 1: - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - assert src < self.world_size, f"Invalid src rank ({src})" - src = self.ranks[src] - - rank = self.rank - if rank == src: - metadata_list: list[tuple[Any, Any]] = [] - assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.broadcast_object(metadata_list, src=src) - async_handles = [] - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) - async_handles.append(handle) - for async_handle in async_handles: - async_handle.wait() - - else: - metadata_list = self.broadcast_object(None, src=src) - tensor_dict = {} - async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. 
- _update_nested_dict(tensor_dict, key, tensor) - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) - async_handles.append(handle) - _update_nested_dict(tensor_dict, key, tensor) - else: - _update_nested_dict(tensor_dict, key, value) - for async_handle in async_handles: - async_handle.wait() - return tensor_dict - - def send_tensor_dict( - self, - tensor_dict: dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: - """Send the input tensor dictionary. - NOTE: `dst` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - - if dst is None: - dst = self.group_next_rank - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - metadata_list: list[tuple[Any, Any]] = [] - assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `send_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip sending empty tensors. 
- continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.send(tensor, dst=self.ranks[dst], group=group) - return None - - def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[dict[str, Union[torch.Tensor, Any]]]: - """Recv the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return None - - group = self.device_group - metadata_group = self.cpu_group - - if src is None: - src = self.group_prev_rank - assert src < self.world_size, f"Invalid src rank ({src})" - - recv_metadata_list = self.recv_object(src=src) - tensor_dict: dict[str, Any] = {} - for key, value in recv_metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - _update_nested_dict(tensor_dict, key, tensor) - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.recv(tensor, src=self.ranks[src], group=group) - _update_nested_dict(tensor_dict, key, tensor) - else: - _update_nested_dict(tensor_dict, key, value) - return tensor_dict - - def barrier(self): - """Barrier synchronization among the group. - NOTE: don't use `device_group` here! `barrier` in NCCL is - terrible because it is internally a broadcast operation with - secretly created GPU tensors. It is easy to mess up the current - device. Use the CPU group instead. 
- """ - torch.distributed.barrier(group=self.cpu_group) - - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the rank_in_group of the destination rank.""" - if dst is None: - dst = self.group_next_rank - - torch.distributed.send( - tensor, - self.ranks[dst], - group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), - ) - - def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the rank_in_group of the source rank.""" - if src is None: - src = self.group_prev_rank - - tensor = torch.empty(size, dtype=dtype, device=self.device) - torch.distributed.recv( - tensor, - self.ranks[src], - (self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), - ) - return tensor - - def destroy(self): - if self.device_group is not None: - torch.distributed.destroy_process_group(self.device_group) - self.device_group = None - if self.cpu_group is not None: - torch.distributed.destroy_process_group(self.cpu_group) - self.cpu_group = None - - class PipelineGroupCoordinator(GroupCoordinator): """ available attributes: From f368def5203abd5c1316ec817c3c61fa2eb0a244 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:51:19 +0800 Subject: [PATCH 28/85] correct name Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/distributed/test_ulysses_sequence_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/distributed/test_ulysses_sequence_parallel.py index d8f2d49a15e..246ee4da53f 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/distributed/test_ulysses_sequence_parallel.py @@ -10,7 +10,7 @@ from 
vllm_omni.diffusion.data import ( DiffusionParallelConfig, OmniDiffusionConfig, - set_current_vllm_config, + set_current_omni_diffusion_config, ) from vllm_omni.diffusion.distributed.parallel_state import ( destroy_distributed_env, @@ -250,7 +250,7 @@ def ulysses_attention_on_test_model( ) # Set the config so Attention can access it - with set_current_vllm_config(od_config): + with set_current_omni_diffusion_config(od_config): # Create model hidden_size = num_heads * head_size From 21282d3449d3d7d46dbbe92b22f2aecf43349c55 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:21:03 +0800 Subject: [PATCH 29/85] update gpu worker: set tp size and config Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/worker/gpu_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 8dd49ad04e2..85b1c8a760f 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -16,6 +16,7 @@ SHUTDOWN_MESSAGE, DiffusionOutput, OmniDiffusionConfig, + set_current_omni_diffusion_config, ) from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader @@ -58,8 +59,9 @@ def init_device_and_model(self) -> None: # hack vllm_config = VllmConfig() - vllm_config.parallel_config.tensor_parallel_size = self.od_config.num_gpus + vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size set_current_vllm_config(vllm_config) + set_current_omni_diffusion_config(self.od_config) init_distributed_environment(world_size=world_size, rank=rank) logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") From 255a2e269141044ad185435166bf7cfc0363fb68 Mon Sep 17 
00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:21:23 +0800 Subject: [PATCH 30/85] Revert "replace by vllm groupcoordinator" This reverts commit 6a52a2949ee135e05bbe7b5a3831ce9b8cb8f54b. Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/group_coordinator.py | 482 +++++++++++++++++- 1 file changed, 480 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 0f26b220e44..bb4d1bf93ae 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -3,13 +3,14 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import pickle from collections import namedtuple from typing import Any, Optional, Union import torch import torch.distributed from torch.cuda import synchronize -from torch.distributed import Backend +from torch.distributed import Backend, ProcessGroup from vllm_omni.diffusion import envs @@ -17,7 +18,6 @@ print("torch.npu synchronize") from torch.npu import synchronize -from vllm.distributed.parallel_state import GroupCoordinator from vllm.logger import init_logger logger = init_logger(__name__) @@ -71,6 +71,484 @@ def _update_nested_dict(nested_dict, flattened_key, value): cur_dict[key_splits[-1]] = value +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
+ """ + + # available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = envs.get_device(local_rank) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = 
self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, op=op, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, list[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty(input_size, dtype=input_.dtype, device=input_.device) + # All-gather. 
+ torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape( + [ + world_size, + ] + + input_size + ) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1).narrow(0, input_.numel() * i, input_.numel()).view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. 
+ torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank." 
+ + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group) + + assert rank_object == rank_size, "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the destination rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip receiving empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), + ) + + def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + (self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + class PipelineGroupCoordinator(GroupCoordinator): """ available attributes: From d1e4f0b961ef7c1dd0659069273315e375901d46 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:36:38 +0800 Subject: [PATCH 31/85] vllmconfig and omnidiffusion config share dp and tp Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/worker/gpu_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 85b1c8a760f..81e56d1f304 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -60,6 +60,7 @@ def init_device_and_model(self) -> None: # 
hack vllm_config = VllmConfig() vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size + vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size set_current_vllm_config(vllm_config) set_current_omni_diffusion_config(self.od_config) From 34e86a61ef2edc66d796a5e497a228df61e88c7a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:42:00 +0800 Subject: [PATCH 32/85] get tp from vllm parallel_state Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/distributed/parallel_state.py | 41 +++++++------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 2d76be56997..477bb6b9c05 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -33,8 +33,10 @@ import torch import torch.distributed +import vllm.distributed.parallel_state as vllm_parallel_state from torch.cuda import device_count, set_device from vllm.logger import init_logger +from vllm_parallel_state import get_tensor_model_parallel_world_size from vllm_omni.diffusion import envs @@ -63,7 +65,7 @@ _WORLD: Optional[GroupCoordinator] = None -_TP: Optional[GroupCoordinator] = None +# get _TP from vllm.distributed.parallel_state _SP: Optional[SequenceParallelGroupCoordinator] = None _PP: Optional[PipelineGroupCoordinator] = None _CFG: Optional[GroupCoordinator] = None @@ -263,22 +265,6 @@ def get_world_group() -> GroupCoordinator: return _WORLD -# TP -def get_tp_group() -> GroupCoordinator: - assert _TP is not None, "tensor model parallel group is not initialized" - return _TP - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return get_tp_group().world_size - - -def 
get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return get_tp_group().rank_in_group - - # SP def get_sp_group() -> SequenceParallelGroupCoordinator: assert _SP is not None, "pipeline model parallel group is not initialized" @@ -467,7 +453,13 @@ def init_distributed_environment( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return _DP is not None and _CFG is not None and _SP is not None and _PP is not None and _TP is not None + return ( + _DP is not None + and _CFG is not None + and _SP is not None + and _PP is not None + and vllm_parallel_state._TP is not None + ) def init_model_parallel_group( @@ -727,15 +719,13 @@ def initialize_model_parallel( ring_group=ring_pg, ) - global _TP - assert _TP is None, "Tensor parallel group is already initialized" - _TP = init_model_parallel_group( + assert vllm_parallel_state._TP is None, "Tensor parallel group is already initialized" + vllm_parallel_state._TP = init_model_parallel_group( group_ranks=rank_generator.get_ranks("tp"), local_rank=get_world_group().local_rank, backend=backend, parallel_mode="tensor", ) - if vae_parallel_size > 0: init_vae_group(dit_parallel_size, vae_parallel_size, backend) init_dit_group(dit_parallel_size, backend) @@ -758,10 +748,9 @@ def destroy_model_parallel(): _SP.destroy() _SP = None - global _TP - if _TP: - _TP.destroy() - _TP = None + if vllm_parallel_state._TP: + vllm_parallel_state._TP.destroy() + vllm_parallel_state._TP = None global _PP if _PP: From 4939ed2e446103e11f912a37ff2a64a95f83742d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:05:25 +0800 Subject: [PATCH 33/85] sequence_parallel_size updates Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/data.py 
b/vllm_omni/diffusion/data.py index 1fe96169f71..e51a640be3a 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -55,8 +55,6 @@ def _validate_parallel_config(self) -> Self: assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0" assert self.data_parallel_size > 0, "Data parallel size must be > 0" assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0" - if self.sequence_parallel_size is None: - self.sequence_parallel_size = self.ulysses_degree * self.ring_degree assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0" assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must be > 0" @@ -68,6 +66,8 @@ def _validate_parallel_config(self) -> Self: return self def __post_init__(self) -> None: + if self.sequence_parallel_size is None: + self.sequence_parallel_size = self.ulysses_degree * self.ring_degree self.world_size = ( self.pipeline_parallel_size * self.data_parallel_size From fcf97c4db5743ae43d63bc75ba5f9fc72744961d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:06:00 +0800 Subject: [PATCH 34/85] fix vllm.distributed.parallel_state import error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 477bb6b9c05..3936bbd2efb 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -35,8 +35,8 @@ import torch.distributed import vllm.distributed.parallel_state as vllm_parallel_state from torch.cuda import device_count, set_device +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm_parallel_state 
import get_tensor_model_parallel_world_size from vllm_omni.diffusion import envs From a560b82f511c912adf6047389b258c76e608962c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:06:47 +0800 Subject: [PATCH 35/85] remove local rank in sp example Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../offline_inference/text_to_image/text_to_image_sp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index 7972a1c41c7..591226381d1 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -86,8 +86,7 @@ def main(): **config_kwargs, od_config=omni_diffusion_config, ) - local_rank = get_world_group().local_rank - parameter_peak_memory = torch_device.max_memory_allocated(device=f"{device_name}:{local_rank}") + torch_device.reset_peak_memory_stats() start_time = time.time() images = omni.generate( @@ -102,7 +101,6 @@ def main(): ) end_time = time.time() elapsed_time = end_time - start_time - peak_memory = torch_device.max_memory_allocated(device=f"{device_name}:{local_rank}") output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) @@ -117,9 +115,7 @@ def main(): img.save(save_path) print(f"Saved generated image to {save_path}") if get_world_group().rank == get_world_group().world_size - 1: - print( - f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory / 1e9:.2f} GB, memory: {peak_memory / 1e9:.2f} GB" - ) + print(f"epoch time: {elapsed_time:.2f} sec") destroy_distributed_env() From bd2b0b097d5900b1bfde063ff226e73500697c19 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:10:20 +0800 Subject: [PATCH 36/85] split rotary_embed Signed-off-by: Didan 
Deng <33117903+wtomin@users.noreply.github.com> --- .../models/qwen_image/qwen_image_transformer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 88784994e7f..2cdda8b794e 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -659,6 +659,15 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + def get_rotary_emb_chunk(freqs): + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] + return freqs + + if self.parallel_config.sequence_parallel_size > 1: + img_freqs, txt_freqs = image_rotary_emb + img_freqs = get_rotary_emb_chunk(img_freqs) + image_rotary_emb = (img_freqs, txt_freqs) + for index_block, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, From bfd6026eade29c6ba45138e625c6694ec5b40ba0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 15:16:47 +0800 Subject: [PATCH 37/85] correct field Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index e51a640be3a..fdf3dfd958a 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -13,7 +13,7 @@ from functools import lru_cache import torch -from pydantic import Field, model_validator +from pydantic import model_validator from typing_extensions import Self from vllm.config.utils import config from vllm.logger import init_logger @@ -242,7 +242,7 @@ class OmniDiffusionConfig: # Cache strategy (legacy) cache_strategy: str = "none" - parallel_config: 
DiffusionParallelConfig = Field(default_factory=DiffusionParallelConfig) + parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig) # Cache backend configuration (NEW) cache_backend: str = "none" # "tea_cache", "deep_cache", etc. From 43f8933296b2375af67348ca4b5b0b6e53bf0706 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 16:27:16 +0800 Subject: [PATCH 38/85] fix ci Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/comm.py | 9 ++++++--- vllm_omni/diffusion/envs.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index ad847134ece..6db377bef72 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -65,7 +65,8 @@ def all_to_all_4D( seq_world_size = dist.get_world_size(group) # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! - # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> + # (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) input_t = ( input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) .transpose(0, 3) @@ -152,7 +153,8 @@ def all_to_all_5D( shard_hc = hc // seq_world_size # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
- # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> (P, seq_len/P, 3, bs, hc/P, hs) + # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> + # (P, seq_len/P, 3, bs, hc/P, hs) input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous() output = torch.empty_like(input_t) @@ -180,7 +182,8 @@ def all_to_all_5D( seq_world_size = dist.get_world_size(group) # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! - # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) + # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> + # (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) input_t = ( input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs) .transpose(0, 4) diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 8feb63570fa..4b5b6fb041a 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -157,7 +157,7 @@ class PackagesEnvChecker: def __new__(cls): if cls._instance is None: - cls._instance = super(PackagesEnvChecker, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance.initialize() return cls._instance @@ -184,7 +184,7 @@ def check_flash_attn(self, packages_info): if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: return False else: - from flash_attn import __version__, flash_attn_func + from flash_attn import __version__ if __version__ < "2.6.0": raise ImportError("install flash_attn >= 2.6.0") From 36aa2b9438f8b0d8f289f874a0ab3092b5d37cf1 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:34:33 +0800 Subject: [PATCH 39/85] init model and device with context manager Signed-off-by: Didan 
Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/worker/gpu_worker.py | 44 ++++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 81e56d1f304..ffc4775721d 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -62,30 +62,30 @@ def init_device_and_model(self) -> None: vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size set_current_vllm_config(vllm_config) - set_current_omni_diffusion_config(self.od_config) - - init_distributed_environment(world_size=world_size, rank=rank) - logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") - parallel_config = self.od_config.parallel_config - initialize_model_parallel( - data_parallel_degree=parallel_config.data_parallel_size, - classifier_free_guidance_degree=parallel_config.cfg_parallel_size, - sequence_parallel_degree=parallel_config.sequence_parallel_size, - ulysses_degree=parallel_config.ulysses_degree, - ring_degree=parallel_config.ring_degree, - tensor_parallel_degree=parallel_config.tensor_parallel_size, - pipeline_parallel_degree=parallel_config.pipeline_parallel_size, - ) - load_config = LoadConfig() - model_loader = DiffusersPipelineLoader(load_config) - time_before_load = time.perf_counter() - with DeviceMemoryProfiler() as m: - self.pipeline = model_loader.load_model( - od_config=self.od_config, - load_device=f"cuda:{rank}", + with set_current_omni_diffusion_config(self.od_config): + init_distributed_environment(world_size=world_size, rank=rank) + logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + 
data_parallel_degree=parallel_config.data_parallel_size, + classifier_free_guidance_degree=parallel_config.cfg_parallel_size, + sequence_parallel_degree=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tensor_parallel_size, + pipeline_parallel_degree=parallel_config.pipeline_parallel_size, ) - time_after_load = time.perf_counter() + + load_config = LoadConfig() + model_loader = DiffusersPipelineLoader(load_config) + time_before_load = time.perf_counter() + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=f"cuda:{rank}", + ) + time_after_load = time.perf_counter() logger.info( "Model loading took %.4f GiB and %.6f seconds", From 2f23822229750e2d5c9a15463a1d0dca8eb4cf57 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:46:57 +0800 Subject: [PATCH 40/85] remove get_world_size Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../offline_inference/text_to_image/text_to_image_sp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index 591226381d1..87fdf063bd4 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -8,7 +8,6 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_omni_diffusion_config -from vllm_omni.diffusion.distributed.parallel_state import destroy_distributed_env, get_world_group from vllm_omni.diffusion.envs import get_device_name from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -114,9 +113,8 @@ def main(): save_path = 
output_path.parent / f"{stem}_{idx}{suffix}" img.save(save_path) print(f"Saved generated image to {save_path}") - if get_world_group().rank == get_world_group().world_size - 1: - print(f"epoch time: {elapsed_time:.2f} sec") - destroy_distributed_env() + + print(f"epoch time: {elapsed_time:.2f} sec") if __name__ == "__main__": From 2252998c64cf5eb9e97b96b3570914edb3cba50d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:50:10 +0800 Subject: [PATCH 41/85] shutdown device and comm group Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/worker/gpu_worker.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index ffc4775721d..606000f0285 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -18,7 +18,11 @@ OmniDiffusionConfig, set_current_omni_diffusion_config, ) -from vllm_omni.diffusion.distributed.parallel_state import init_distributed_environment, initialize_model_parallel +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -117,12 +121,7 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi return output def shutdown(self) -> None: - if torch.distributed.is_initialized(): - try: - torch.distributed.destroy_process_group() - logger.info("Worker %s: Destroyed process group", self.rank) - except Exception as exc: # pragma: no cover - best effort cleanup - logger.warning("Worker %s: Failed to destroy process group: %s", self.rank, exc) + destroy_distributed_env() class WorkerProc: From 42eeb6a7776316433d78f2edf20215f68e711020 Mon 
Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:18:07 +0800 Subject: [PATCH 42/85] set vllm_config as context manager Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/worker/gpu_worker.py | 44 ++++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 606000f0285..075e6aeffa4 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -65,31 +65,31 @@ def init_device_and_model(self) -> None: vllm_config = VllmConfig() vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size - set_current_vllm_config(vllm_config) with set_current_omni_diffusion_config(self.od_config): - init_distributed_environment(world_size=world_size, rank=rank) - logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") - parallel_config = self.od_config.parallel_config - initialize_model_parallel( - data_parallel_degree=parallel_config.data_parallel_size, - classifier_free_guidance_degree=parallel_config.cfg_parallel_size, - sequence_parallel_degree=parallel_config.sequence_parallel_size, - ulysses_degree=parallel_config.ulysses_degree, - ring_degree=parallel_config.ring_degree, - tensor_parallel_degree=parallel_config.tensor_parallel_size, - pipeline_parallel_degree=parallel_config.pipeline_parallel_size, - ) - - load_config = LoadConfig() - model_loader = DiffusersPipelineLoader(load_config) - time_before_load = time.perf_counter() - with DeviceMemoryProfiler() as m: - self.pipeline = model_loader.load_model( - od_config=self.od_config, - load_device=f"cuda:{rank}", + with set_current_vllm_config(vllm_config): + init_distributed_environment(world_size=world_size, 
rank=rank) + logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_degree=parallel_config.data_parallel_size, + classifier_free_guidance_degree=parallel_config.cfg_parallel_size, + sequence_parallel_degree=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tensor_parallel_size, + pipeline_parallel_degree=parallel_config.pipeline_parallel_size, ) - time_after_load = time.perf_counter() + + load_config = LoadConfig() + model_loader = DiffusersPipelineLoader(load_config) + time_before_load = time.perf_counter() + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=f"cuda:{rank}", + ) + time_after_load = time.perf_counter() logger.info( "Model loading took %.4f GiB and %.6f seconds", From c9d7a14aa08eaf384b0c17085efbc396c6fd1c41 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:19:45 +0800 Subject: [PATCH 43/85] record inference speed Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image.py | 3 ++ .../text_to_image/text_to_image_sp.py | 46 +++++++++---------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 85c84b5961a..fb513d436f1 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -143,6 +143,9 @@ def main(): save_path = output_path.parent / f"{stem}_{idx}{suffix}" img.save(save_path) print(f"Saved generated image to {save_path}") + print( + f"inference time: {elapsed_time:.2f} sec, average time per image: {elapsed_time / 
args.num_images_per_prompt:.2f} sec" + ) if __name__ == "__main__": diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index 87fdf063bd4..d7c976b4711 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -7,8 +7,7 @@ import torch -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_omni_diffusion_config -from vllm_omni.diffusion.envs import get_device_name +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -66,11 +65,7 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - device_name = get_device_name() - try: - torch_device = getattr(torch, device_name) - except AttributeError: - raise ValueError(f"Device name {device_name} is not supported") + config_kwargs = { "model": args.model, "vae_use_slicing": vae_use_slicing, @@ -80,24 +75,23 @@ def main(): omni_diffusion_config = OmniDiffusionConfig( **config_kwargs, parallel_config=DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) ) - with set_current_omni_diffusion_config(omni_diffusion_config): - omni = Omni( - **config_kwargs, - od_config=omni_diffusion_config, - ) - torch_device.reset_peak_memory_stats() - start_time = time.time() - images = omni.generate( - args.prompt, - height=args.height, - width=args.width, - generator=generator, - true_cfg_scale=args.cfg_scale, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.num_images_per_prompt, - num_outputs_per_prompt=args.num_images_per_prompt, - ) + omni = Omni( + **config_kwargs, + od_config=omni_diffusion_config, + ) + + start_time = time.time() + images = omni.generate( + args.prompt, + height=args.height, + 
width=args.width, + generator=generator, + true_cfg_scale=args.cfg_scale, + num_inference_steps=args.num_inference_steps, + num_images_per_prompt=args.num_images_per_prompt, + num_outputs_per_prompt=args.num_images_per_prompt, + ) end_time = time.time() elapsed_time = end_time - start_time @@ -114,7 +108,9 @@ def main(): img.save(save_path) print(f"Saved generated image to {save_path}") - print(f"epoch time: {elapsed_time:.2f} sec") + print( + f"inference time: {elapsed_time:.2f} sec, average time per image: {elapsed_time / args.num_images_per_prompt:.2f} sec" + ) if __name__ == "__main__": From de93d3448d04f602b71ea259a7748eafca980336 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:47:18 +0800 Subject: [PATCH 44/85] different save path Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/text_to_image/text_to_image_sp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index d7c976b4711..e158dca1532 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -28,7 +28,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--output", type=str, - default="qwen_image_output.png", + default="qwen_image_output_sp.png", help="Path to save the generated image (PNG).", ) parser.add_argument( From 19bc2454f6b9d115361f2215041ccc886961e057 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:12:08 +0800 Subject: [PATCH 45/85] ring attention not supported yet Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/text_to_image/text_to_image_sp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py index e158dca1532..90692154ed4 100644 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ b/examples/offline_inference/text_to_image/text_to_image_sp.py @@ -66,6 +66,8 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() + assert args.ring_degree == 1, "Ring attention is not supported yet" + config_kwargs = { "model": args.model, "vae_use_slicing": vae_use_slicing, From 696dc0a0227ecfc9b96de400ddad3b70c0d4f2de Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:49:42 +0800 Subject: [PATCH 46/85] merge two scripts into one Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image.py | 22 +++- .../text_to_image/text_to_image_sp.py | 119 ------------------ 2 files changed, 18 insertions(+), 123 deletions(-) delete mode 100644 examples/offline_inference/text_to_image/text_to_image_sp.py diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index fb513d436f1..a534c9d13f0 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -7,6 +7,7 @@ import torch +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -57,6 +58,18 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." 
), ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=2, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) return parser.parse_args() @@ -98,12 +111,16 @@ def main(): # (e.g., QwenImagePipeline or FluxPipeline) } + + assert args.ring_degree == 1, "Ring attention is not supported yet" + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) omni = Omni( model=args.model, vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + parallel_config=parallel_config, ) # Time profiling for generation @@ -143,10 +160,7 @@ def main(): save_path = output_path.parent / f"{stem}_{idx}{suffix}" img.save(save_path) print(f"Saved generated image to {save_path}") - print( - f"inference time: {elapsed_time:.2f} sec, average time per image: {elapsed_time / args.num_images_per_prompt:.2f} sec" - ) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/offline_inference/text_to_image/text_to_image_sp.py b/examples/offline_inference/text_to_image/text_to_image_sp.py deleted file mode 100644 index 90692154ed4..00000000000 --- a/examples/offline_inference/text_to_image/text_to_image_sp.py +++ /dev/null @@ -1,119 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import time -from pathlib import Path - -import torch - -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig -from vllm_omni.entrypoints.omni import Omni -from vllm_omni.utils.platform_utils import detect_device_type, is_npu - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") - parser.add_argument("--model", 
default="Qwen/Qwen-Image", help="Diffusion model name or local path.") - parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") - parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") - parser.add_argument( - "--cfg_scale", - type=float, - default=4.0, - help="True classifier-free guidance scale specific to Qwen-Image.", - ) - parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") - parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") - parser.add_argument( - "--output", - type=str, - default="qwen_image_output_sp.png", - help="Path to save the generated image (PNG).", - ) - parser.add_argument( - "--num_images_per_prompt", - type=int, - default=1, - help="Number of images to generate for the given prompt.", - ) - parser.add_argument( - "--num_inference_steps", - type=int, - default=50, - help="Number of denoising steps for the diffusion sampler.", - ) - parser.add_argument( - "--ulysses_degree", - type=int, - default=2, - help="Number of GPUs used for ulysses sequence parallelism.", - ) - parser.add_argument( - "--ring_degree", - type=int, - default=1, - help="Number of GPUs used for ring sequence parallelism.", - ) - return parser.parse_args() - - -def main(): - args = parse_args() - device = detect_device_type() - generator = torch.Generator(device=device).manual_seed(args.seed) - # Enable VAE memory optimizations on NPU - vae_use_slicing = is_npu() - vae_use_tiling = is_npu() - - assert args.ring_degree == 1, "Ring attention is not supported yet" - - config_kwargs = { - "model": args.model, - "vae_use_slicing": vae_use_slicing, - "vae_use_tiling": vae_use_tiling, - } - - omni_diffusion_config = OmniDiffusionConfig( - **config_kwargs, parallel_config=DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) - ) - - omni = Omni( - **config_kwargs, - od_config=omni_diffusion_config, - ) - 
- start_time = time.time() - images = omni.generate( - args.prompt, - height=args.height, - width=args.width, - generator=generator, - true_cfg_scale=args.cfg_scale, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.num_images_per_prompt, - num_outputs_per_prompt=args.num_images_per_prompt, - ) - end_time = time.time() - elapsed_time = end_time - start_time - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - suffix = output_path.suffix or ".png" - stem = output_path.stem or "qwen_image_output" - if args.num_images_per_prompt <= 1: - images[0].save(output_path) - print(f"Saved generated image to {output_path}") - else: - for idx, img in enumerate(images): - save_path = output_path.parent / f"{stem}_{idx}{suffix}" - img.save(save_path) - print(f"Saved generated image to {save_path}") - - print( - f"inference time: {elapsed_time:.2f} sec, average time per image: {elapsed_time / args.num_images_per_prompt:.2f} sec" - ) - - -if __name__ == "__main__": - main() From cb0ef5dda33dd386a867baadef95b75598809e66 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:02:20 +0800 Subject: [PATCH 47/85] update test script Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../attention}/test_ulysses_sequence_parallel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) rename tests/{distributed => diffusion/attention}/test_ulysses_sequence_parallel.py (98%) diff --git a/tests/distributed/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py similarity index 98% rename from tests/distributed/test_ulysses_sequence_parallel.py rename to tests/diffusion/attention/test_ulysses_sequence_parallel.py index 246ee4da53f..6f2e5749e0a 100644 --- a/tests/distributed/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -130,6 +130,8 @@ def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: TestMultiLayerAttentionModel, ], ) +@pytest.mark.parametrize("ulysses_degree", [2, 4, 8]) +@pytest.mark.parametrize("ring_degree", [1]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("num_heads", [8]) @@ -140,6 +142,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.parametrize("use_compile", [False, True]) def test_ulysses_attention( + ulysses_degree: int, + ring_degree: int, test_model_cls: type[torch.nn.Module], dtype: torch.dtype, causal: bool, @@ -152,10 +156,8 @@ def test_ulysses_attention( head_size: int, ): """Test Ulysses attention with various parameter combinations.""" - num_processes = 2 - ulysses_degree = 2 # Must match num_processes for this test - ring_degree = 1 sequence_parallel_size = ulysses_degree * ring_degree + num_processes = sequence_parallel_size def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with From cc068b3be615192c70b030251983c8945d54e381 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:11:12 +0800 Subject: [PATCH 48/85] constrain test cases Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/attention/test_ulysses_sequence_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index 6f2e5749e0a..f25cb0e04b7 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -130,17 +130,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: TestMultiLayerAttentionModel, ], ) -@pytest.mark.parametrize("ulysses_degree", [2, 4, 8]) 
+@pytest.mark.parametrize("ulysses_degree", [2, 4]) @pytest.mark.parametrize("ring_degree", [1]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_size", [32]) -@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_sync", [True, False]) -@pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("use_compile", [False, True]) +@pytest.mark.parametrize("dynamic", [False]) +@pytest.mark.parametrize("use_compile", [False]) def test_ulysses_attention( ulysses_degree: int, ring_degree: int, From eb7804ea826c7db5fdf06465f234d370d4e2df61 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:55:32 +0800 Subject: [PATCH 49/85] smaller head size Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/diffusion/attention/test_ulysses_sequence_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index f25cb0e04b7..e28c34fa44d 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -135,7 +135,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("num_heads", [8]) -@pytest.mark.parametrize("head_size", [32]) +@pytest.mark.parametrize("head_size", [8]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_sync", [True, False]) From 6109c1478240a01157abc1bfbe63913ac52a0873 Mon Sep 17 00:00:00 2001 From: 
Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:00:15 +0800 Subject: [PATCH 50/85] fix DOC check error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/mkdocs/hooks/generate_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 37c2f8316b6..4e840280b26 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -195,7 +195,7 @@ def generate(self) -> str: main_file_rel = self.main_file.relative_to(ROOT_DIR) content += f'{code_fence}{self.main_file.suffix[1:]}\n--8<-- "{main_file_rel}"\n{code_fence}\n' else: - with open(self.main_file) as f: + with open(self.main_file, encoding="utf-8") as f: # Skip the title from md snippets as it's been included above main_content = f.readlines()[1:] content += self.fix_relative_links("".join(main_content)) From e68c43472552824571ac490c0cace03490e328c2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:02:05 +0800 Subject: [PATCH 51/85] no backward check Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 4 ---- vllm_omni/diffusion/distributed/comm.py | 20 ------------------- 2 files changed, 24 deletions(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index e28c34fa44d..2e7a2d310ae 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -309,10 +309,6 @@ def ulysses_attention_on_test_model( assert hasattr(layer.attention, "use_ulysses"), f"Layer {i} attention should have use_ulysses attribute" assert layer.attention.use_ulysses, f"Layer {i} attention should be using Ulysses" - # Run backward pass to ensure gradients work - loss = 
output.sum() - loss.backward() - print( f"Rank {local_rank}: Test passed with " f"batch_size={batch_size}, seq_len={seq_len}, " diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index 6db377bef72..2fc00e4e6b8 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -112,16 +112,6 @@ def forward( ctx.use_sync = use_sync return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: - return ( - None, - SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), - None, - None, - None, - ) - def all_to_all_5D( input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False @@ -229,13 +219,3 @@ def forward( ctx.use_sync = use_sync return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: - return ( - None, - SeqAllToAll5D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync), - None, - None, - None, - ) From c1b98647a23a840171fd50fa22baee838422e902 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:03:36 +0800 Subject: [PATCH 52/85] fix logging Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/group_coordinator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index bb4d1bf93ae..02aaa391b1f 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -11,17 +11,16 @@ import torch.distributed from torch.cuda import synchronize from 
torch.distributed import Backend, ProcessGroup +from vllm.logger import init_logger from vllm_omni.diffusion import envs +logger = init_logger(__name__) + if envs._is_npu(): - print("torch.npu synchronize") + logger.info("torch.npu synchronize") from torch.npu import synchronize -from vllm.logger import init_logger - -logger = init_logger(__name__) - TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) env_info = envs.PACKAGES_CHECKER.get_packages_info() From aa7c063a376fc5b587861aa0095c2629957d72dd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:04:13 +0800 Subject: [PATCH 53/85] update qwen_image transformer Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/models/qwen_image/qwen_image_transformer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 2cdda8b794e..547f1714e20 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -622,9 +622,6 @@ def forward( # else: # lora_scale = 1.0 - ############################################################ - # parallel inputs - ############################################################ if self.parallel_config.sequence_parallel_size > 1: hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[ get_sequence_parallel_rank() @@ -685,9 +682,6 @@ def get_rotary_emb_chunk(freqs): hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - ############################################################ - # parallel outputs - ############################################################ if self.parallel_config.sequence_parallel_size > 1: output = get_sp_group().all_gather(output, dim=-2) return 
Transformer2DModelOutput(sample=output) From 0442526fcf26e6a59866a22ffd1c57d2a369b494 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 12 Dec 2025 19:07:47 +0800 Subject: [PATCH 54/85] solve conflicts Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../text_to_image/text_to_image_usp.py | 121 ------------------ 1 file changed, 121 deletions(-) delete mode 100644 examples/offline_inference/text_to_image/text_to_image_usp.py diff --git a/examples/offline_inference/text_to_image/text_to_image_usp.py b/examples/offline_inference/text_to_image/text_to_image_usp.py deleted file mode 100644 index ccaece46b4d..00000000000 --- a/examples/offline_inference/text_to_image/text_to_image_usp.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import time -from pathlib import Path - -import torch - -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig, set_current_vllm_config -from vllm_omni.diffusion.distributed.parallel_state import destroy_distributed_env, get_world_group -from vllm_omni.entrypoints.omni import Omni -from vllm_omni.utils.platform_utils import detect_device_type, is_npu - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") - parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.") - parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") - parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") - parser.add_argument( - "--cfg_scale", - type=float, - default=4.0, - help="True classifier-free guidance scale specific to Qwen-Image.", - ) - parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") - 
parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") - parser.add_argument( - "--output", - type=str, - default="qwen_image_output.png", - help="Path to save the generated image (PNG).", - ) - parser.add_argument( - "--num_images_per_prompt", - type=int, - default=1, - help="Number of images to generate for the given prompt.", - ) - parser.add_argument( - "--num_inference_steps", - type=int, - default=50, - help="Number of denoising steps for the diffusion sampler.", - ) - parser.add_argument( - "--ulysses_degree", - type=int, - default=2, - help="Number of GPUs used for ulysses sequence parallelism.", - ) - parser.add_argument( - "--ring_degree", - type=int, - default=1, - help="Number of GPUs used for ring sequence parallelism.", - ) - return parser.parse_args() - - -def main(): - args = parse_args() - device = detect_device_type() - generator = torch.Generator(device=device).manual_seed(args.seed) - local_rank = get_world_group().local_rank - - # Enable VAE memory optimizations on NPU - vae_use_slicing = is_npu() - vae_use_tiling = is_npu() - sequence_parallel_size = args.ulysses_degree * args.ring_degree - omni_diffusion_config = OmniDiffusionConfig( - parallel_config=DiffusionParallelConfig( - ulysses_degree=args.ulysses_degree, sequence_parallel_size=sequence_parallel_size - ) - ) - with set_current_vllm_config(omni_diffusion_config): - omni = Omni( - model=args.model, - od_config=omni_diffusion_config, - vae_use_slicing=vae_use_slicing, - vae_use_tiling=vae_use_tiling, - ) - parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - torch.cuda.reset_peak_memory_stats() - start_time = time.time() - images = omni.generate( - args.prompt, - height=args.height, - width=args.width, - generator=generator, - true_cfg_scale=args.cfg_scale, - num_inference_steps=args.num_inference_steps, - num_images_per_prompt=args.num_images_per_prompt, - num_outputs_per_prompt=args.num_images_per_prompt, - ) - 
end_time = time.time() - elapsed_time = end_time - start_time - peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - suffix = output_path.suffix or ".png" - stem = output_path.stem or "qwen_image_output" - if args.num_images_per_prompt <= 1: - images[0].save(output_path) - print(f"Saved generated image to {output_path}") - else: - for idx, img in enumerate(images): - save_path = output_path.parent / f"{stem}_{idx}{suffix}" - img.save(save_path) - print(f"Saved generated image to {save_path}") - if get_world_group().rank == get_world_group().world_size - 1: - print( - f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory / 1e9:.2f} GB, memory: {peak_memory / 1e9:.2f} GB" - ) - destroy_distributed_env() - - -if __name__ == "__main__": - main() From f050c0b4746e573b51e182b98f7cafdffd9876fa Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:31:57 +0800 Subject: [PATCH 55/85] update test ut Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 264 +++++++++++++++--- 1 file changed, 228 insertions(+), 36 deletions(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index 2e7a2d310ae..a423c41edf7 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +import pickle +import tempfile import pytest import torch @@ -155,17 +157,53 @@ def test_ulysses_attention( num_heads: int, head_size: int, ): - """Test Ulysses attention with various parameter combinations.""" + """Test Ulysses attention by comparing with and without SP 
enabled.""" sequence_parallel_size = ulysses_degree * ring_degree - num_processes = sequence_parallel_size - def run_torch_spawn(fn, nprocs): - # need to use torch.mp.spawn otherwise will have problems with - # torch.distributed and cuda + # Create temporary files to share results between processes + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + baseline_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + sp_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + model_state_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + input_data_file = f.name + + try: + # Step 1: Run without SP (baseline with ulysses_degree=1, ring_degree=1) + print("\n[Baseline] Running without SP (ulysses_degree=1, ring_degree=1)...") torch.multiprocessing.spawn( - fn, + ulysses_attention_on_test_model, args=( - num_processes, + 1, # num_processes = 1 for baseline + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + 1, # ulysses_degree = 1 + 1, # ring_degree = 1 + 1, # sequence_parallel_size = 1 + baseline_output_file, + model_state_file, + input_data_file, + True, # is_baseline + ), + nprocs=1, + ) + + # Step 2: Run with SP enabled + print(f"\n[SP Test] Running with SP (ulysses_degree={ulysses_degree}, ring_degree={ring_degree})...") + torch.multiprocessing.spawn( + ulysses_attention_on_test_model, + args=( + sequence_parallel_size, # num_processes test_model_cls, batch_size, seq_len, @@ -179,11 +217,82 @@ def run_torch_spawn(fn, nprocs): ulysses_degree, ring_degree, sequence_parallel_size, + sp_output_file, + model_state_file, + input_data_file, + False, # is_baseline ), - nprocs=nprocs, + nprocs=sequence_parallel_size, + ) + + # Step 3: Verify input consistency and compare outputs + print(f"\n{'=' * 80}") + print("Verifying input data consistency...") + with 
open(input_data_file, "rb") as f: + input_data = pickle.load(f) + input_checksum = hash(input_data.tobytes()) + print(f" Input data shape: {input_data.shape}") + print(f" Input data checksum: {input_checksum}") + print(" ✓ Both baseline and SP used the same input data") + + print(f"\n{'=' * 80}") + print("Comparing outputs between baseline and SP...") + with open(baseline_output_file, "rb") as f: + baseline_output = pickle.load(f) + with open(sp_output_file, "rb") as f: + sp_output = pickle.load(f) + + # Convert to tensors for comparison + baseline_tensor = torch.tensor(baseline_output) + sp_tensor = torch.tensor(sp_output) + + print(f" Baseline output shape: {baseline_tensor.shape}") + print(f" SP output shape: {sp_tensor.shape}") + assert baseline_tensor.shape == sp_tensor.shape, "Output shapes must match!" + + # Calculate differences + abs_diff = torch.abs(baseline_tensor - sp_tensor) + max_abs_diff = abs_diff.max().item() + mean_abs_diff = abs_diff.mean().item() + + # Calculate relative difference (avoid division by zero) + baseline_abs = torch.abs(baseline_tensor) + relative_diff = abs_diff / (baseline_abs + 1e-8) + max_relative_diff = relative_diff.max().item() + mean_relative_diff = relative_diff.mean().item() + + print(f"\n{'=' * 80}") + print("Output Difference Analysis:") + print(f" - Max absolute difference: {max_abs_diff:.6e}") + print(f" - Mean absolute difference: {mean_abs_diff:.6e}") + print(f" - Max relative difference: {max_relative_diff:.6e}") + print(f" - Mean relative difference: {mean_relative_diff:.6e}") + print(f" - Baseline output range: [{baseline_tensor.min().item():.6e}, {baseline_tensor.max().item():.6e}]") + print(f" - SP output range: [{sp_tensor.min().item():.6e}, {sp_tensor.max().item():.6e}]") + print(f"{'=' * 80}\n") + + # Assert that differences are within acceptable tolerance + # For FP16/BF16, we expect some numerical differences due to different computation order + if dtype == torch.float16: + atol, rtol = 1e-4, 1e-2 + elif 
dtype == torch.bfloat16: + atol, rtol = 1e-4, 1e-2 + else: + atol, rtol = 1e-5, 1e-3 + + assert max_abs_diff < atol or max_relative_diff < rtol, ( + f"Output difference too large: max_abs_diff={max_abs_diff:.6e}, " + f"max_relative_diff={max_relative_diff:.6e}, " + f"tolerance: atol={atol}, rtol={rtol}" ) - run_torch_spawn(ulysses_attention_on_test_model, num_processes) + print("✓ Test passed: SP output matches baseline within tolerance") + + finally: + # Clean up temporary files + for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file]: + if os.path.exists(f): + os.remove(f) def ulysses_attention_on_test_model( @@ -202,9 +311,18 @@ def ulysses_attention_on_test_model( ulysses_degree: int, ring_degree: int, sequence_parallel_size: int, + output_file: str, + model_state_file: str, + input_data_file: str, + is_baseline: bool, ): - """Run Ulysses attention test on a test model.""" - current_platform.seed_everything(42) + """Run Ulysses attention test on a test model and save results for comparison.""" + # Use fixed seed for reproducibility across baseline and SP runs + RANDOM_SEED = 42 + current_platform.seed_everything(RANDOM_SEED) + + mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}") device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) @@ -223,7 +341,7 @@ def ulysses_attention_on_test_model( # Initialize distributed environment init_distributed_environment() - # Set up OmniDiffusionConfig with Ulysses parallel config + # Set up OmniDiffusionConfig with parallel config parallel_config = DiffusionParallelConfig( pipeline_parallel_size=1, data_parallel_size=1, @@ -240,7 +358,7 @@ def ulysses_attention_on_test_model( parallel_config=parallel_config, ) - # Initialize model parallel with Ulysses + # Initialize model parallel initialize_model_parallel( data_parallel_degree=1, 
classifier_free_guidance_degree=1, @@ -272,16 +390,56 @@ def ulysses_attention_on_test_model( model_kwargs["num_layers"] = 2 model = test_model_cls(**model_kwargs) - model = model.to(device).to(dtype) - # Create input - # In sequence parallel, each rank gets seq_len / sequence_parallel_size + # For baseline: Generate and save model state and input data + # This ensures both baseline and SP use exactly the same initialization + if is_baseline and local_rank == 0: + # Save model state for reuse (before any computation) + model_state = {k: v.cpu() for k, v in model.state_dict().items()} + with open(model_state_file, "wb") as f: + pickle.dump(model_state, f) + + # Generate and save full input data with fixed seed + # Reinitialize RNG to ensure reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + full_hidden_states = torch.randn( + (batch_size, seq_len, hidden_size), + dtype=dtype, + device="cpu", + ) + with open(input_data_file, "wb") as f: + pickle.dump(full_hidden_states.numpy(), f) + + print("[Baseline] Saved model state and input data") + + # Synchronize to ensure baseline has saved data before SP loads it + if world_size > 1: + torch.distributed.barrier() + + # IMPORTANT: Both baseline and SP load the same model state and input data + # This ensures exact same initialization and input for fair comparison + with open(model_state_file, "rb") as f: + model_state = pickle.load(f) + model.load_state_dict({k: v.to(device).to(dtype) for k, v in model_state.items()}) + + with open(input_data_file, "rb") as f: + full_hidden_states_np = pickle.load(f) + full_hidden_states = torch.from_numpy(full_hidden_states_np).to(device).to(dtype) + + print(f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}") + + # Split input sequence according to sequence parallel BEFORE model forward + # Each rank gets a contiguous chunk of the sequence dimension local_seq_len = seq_len // sequence_parallel_size - hidden_states = 
torch.randn( - (batch_size, local_seq_len, hidden_size), - dtype=dtype, - device=device, + start_idx = local_rank * local_seq_len + end_idx = start_idx + local_seq_len + hidden_states = full_hidden_states[:, start_idx:end_idx, :].contiguous() + + print( + f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " + f"indices=[{start_idx}:{end_idx}], local_shape={hidden_states.shape}" ) if dynamic: @@ -292,28 +450,62 @@ def ulysses_attention_on_test_model( if use_compile: model = torch.compile(model) - # Run forward pass + # Run forward pass with local sequence chunk + print(f"[Rank {local_rank}] Running forward pass...") output = model(hidden_states) + print(f"[Rank {local_rank}] Forward pass completed, output shape: {output.shape}") # Verify output shape assert output.shape == (batch_size, local_seq_len, hidden_size), ( f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}" ) - # Verify that Attention is using Ulysses - if hasattr(model, "attention"): - assert hasattr(model.attention, "use_ulysses"), "Attention should have use_ulysses attribute" - assert model.attention.use_ulysses, "Attention should be using Ulysses" - elif hasattr(model, "layers"): - for i, layer in enumerate(model.layers): - assert hasattr(layer.attention, "use_ulysses"), f"Layer {i} attention should have use_ulysses attribute" - assert layer.attention.use_ulysses, f"Layer {i} attention should be using Ulysses" + # Verify SP usage for non-baseline runs + if not is_baseline: + if hasattr(model, "attention"): + assert hasattr(model.attention, "use_ulysses"), "Attention should have use_ulysses attribute" + assert model.attention.use_ulysses, "Attention should be using Ulysses" + elif hasattr(model, "layers"): + for i, layer in enumerate(model.layers): + assert hasattr(layer.attention, "use_ulysses"), ( + f"Layer {i} attention should have use_ulysses attribute" + ) + assert layer.attention.use_ulysses, f"Layer {i} attention should be using 
Ulysses" + + # Gather outputs from all ranks AFTER computation + if world_size > 1: + print(f"[Rank {local_rank}] Gathering outputs from all {world_size} ranks...") + # Gather all outputs to rank 0 + gathered_outputs = [torch.zeros_like(output) for _ in range(world_size)] + torch.distributed.all_gather(gathered_outputs, output) + if local_rank == 0: + # Concatenate along sequence dimension to reconstruct full sequence + full_output = torch.cat(gathered_outputs, dim=1) + print(f"[Rank 0] Gathered and concatenated outputs: {full_output.shape}") + # Verify the full output shape matches expected + assert full_output.shape == (batch_size, seq_len, hidden_size), ( + f"Gathered output shape mismatch: expected {(batch_size, seq_len, hidden_size)}, " + f"got {full_output.shape}" + ) + else: + full_output = None + else: + # For baseline (world_size=1), output is already complete + full_output = output + print(f"[Rank 0] No gather needed (world_size=1), output shape: {full_output.shape}") + + # Save output from rank 0 for comparison + if local_rank == 0: + output_np = full_output.cpu().numpy() + with open(output_file, "wb") as f: + pickle.dump(output_np, f) + + mode_str = "baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print( + f"\n[{mode_str}] ✓ Saved output with shape {full_output.shape}:\n" + f" - batch_size={batch_size}, seq_len={seq_len}\n" + f" - num_heads={num_heads}, head_size={head_size}\n" + f" - dtype={dtype}, causal={causal}, use_sync={use_sync}\n" + ) - print( - f"Rank {local_rank}: Test passed with " - f"batch_size={batch_size}, seq_len={seq_len}, " - f"num_heads={num_heads}, head_size={head_size}, " - f"dtype={dtype}, causal={causal}, use_sync={use_sync}, " - f"dynamic={dynamic}, use_compile={use_compile}" - ) destroy_distributed_env() From 01a20ea72ebd45bcba8c9cd646a737e3bde6b2a3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:44:30 +0800 Subject: 
[PATCH 56/85] fix ut and adapt to npu Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index a423c41edf7..aea03809daf 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -19,6 +19,15 @@ init_distributed_environment, initialize_model_parallel, ) +from vllm_omni.utils.platform_utils import detect_device_type + +device_type = detect_device_type() +if device_type == "cuda": + torch_device = torch.cuda +elif device_type == "npu": + torch_device = torch.npu +else: + raise ValueError(f"Unsupported device type: {device_type} for this test script! Expected GPU or NPU.") def update_environment_variables(envs_dict: dict[str, str]): @@ -324,8 +333,8 @@ def ulysses_attention_on_test_model( mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}") - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) @@ -403,14 +412,14 @@ def ulysses_attention_on_test_model( # Generate and save full input data with fixed seed # Reinitialize RNG to ensure reproducibility torch.manual_seed(42) - torch.cuda.manual_seed_all(42) + torch_device.manual_seed_all(42) full_hidden_states = torch.randn( (batch_size, seq_len, hidden_size), dtype=dtype, device="cpu", ) with open(input_data_file, "wb") as f: - pickle.dump(full_hidden_states.numpy(), f) + pickle.dump(full_hidden_states.detach().cpu().float().numpy(), f) print("[Baseline] Saved model 
state and input data") @@ -496,7 +505,7 @@ def ulysses_attention_on_test_model( # Save output from rank 0 for comparison if local_rank == 0: - output_np = full_output.cpu().numpy() + output_np = full_output.detach().cpu().float().numpy() with open(output_file, "wb") as f: pickle.dump(output_np, f) From 253e442493a742943d96a9e80e0f731f13911ee5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:10:44 +0800 Subject: [PATCH 57/85] update device_count and set_device Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/parallel_state.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 3936bbd2efb..fa375b90d96 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -34,7 +34,6 @@ import torch import torch.distributed import vllm.distributed.parallel_state as vllm_parallel_state -from torch.cuda import device_count, set_device from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -46,15 +45,12 @@ SequenceParallelGroupCoordinator, ) -try: - from torch_musa.core.device import device_count, set_device -except ModuleNotFoundError: - pass - -try: +if envs._is_npu(): from torch.npu import device_count, set_device -except ModuleNotFoundError: - pass +elif envs._is_musa(): + from torch_musa.core.device import device_count, set_device +else: + from torch.cuda import device_count, set_device env_info = envs.PACKAGES_CHECKER.get_packages_info() From 9cd123db5a0627549ccc894ebdfd8a6d5fb5ee7a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:33:32 +0800 Subject: [PATCH 58/85] test comm Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- tests/diffusion/distributed/test_comm.py | 294 +++++++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 tests/diffusion/distributed/test_comm.py diff --git a/tests/diffusion/distributed/test_comm.py b/tests/diffusion/distributed/test_comm.py new file mode 100644 index 00000000000..44539022ea0 --- /dev/null +++ b/tests/diffusion/distributed/test_comm.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for SeqAllToAll4D and SeqAllToAll5D communication primitives.""" + +import os + +import pytest +import torch + +from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D, SeqAllToAll5D +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_sp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.utils.platform_utils import detect_device_type + +device_type = detect_device_type() +if device_type == "cuda": + torch_device = torch.cuda +elif device_type == "npu": + torch_device = torch.npu +else: + raise ValueError(f"Unsupported device type: {device_type} for this test script! 
Expected GPU or NPU.") + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len_per_rank", [8]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [32]) +@pytest.mark.parametrize("use_sync", [False, True]) +def test_4d_identity( + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + """Test that two consecutive all-to-all operations return the original input.""" + # Ensure num_heads is divisible by world_size + if num_heads % world_size != 0: + pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})") + + # Run test with multiprocessing spawn + torch.multiprocessing.spawn( + _test_4d_identity_worker, + args=( + world_size, + dtype, + batch_size, + seq_len_per_rank, + num_heads, + head_size, + use_sync, + ), + nprocs=world_size, + ) + + +def _test_4d_identity_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + """Worker function for test_4d_identity.""" + # Set device + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + + # Set environment variables for distributed training + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29500", + } + ) + + # Initialize distributed environment + init_distributed_environment() + initialize_model_parallel(sequence_parallel_size=world_size) + + sp_group = get_sp_group() + + # Create input 
tensor: (bs, seqlen/P, hc, hs) + torch.manual_seed(42 + local_rank) + input_tensor = torch.randn( + batch_size, + seq_len_per_rank, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + + # Save original input for comparison + original_input = input_tensor.clone() + + # First all-to-all: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs) + intermediate = SeqAllToAll4D.apply( + sp_group, + input_tensor, + scatter_idx=2, # scatter head dimension + gather_idx=1, # gather sequence dimension + use_sync=use_sync, + ) + + # Verify intermediate shape + expected_shape = ( + batch_size, + seq_len_per_rank * world_size, + num_heads // world_size, + head_size, + ) + assert intermediate.shape == expected_shape, ( + f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}" + ) + + # Second all-to-all: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs) + output = SeqAllToAll4D.apply( + sp_group, + intermediate, + scatter_idx=1, # scatter sequence dimension + gather_idx=2, # gather head dimension + use_sync=use_sync, + ) + + # Verify output shape matches input + assert output.shape == original_input.shape, ( + f"Output shape mismatch: expected {original_input.shape}, got {output.shape}" + ) + + # Verify output matches original input + torch.testing.assert_close( + output, + original_input, + rtol=1e-5, + atol=1e-5, + msg="Output does not match original input after two all-to-all operations", + ) + + # Cleanup distributed environment + destroy_distributed_env() + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len_per_rank", [8]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [32]) +@pytest.mark.parametrize("use_sync", [False, True]) +def test_5d_identity( + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: 
int, + use_sync: bool, +): + """Test that two consecutive all-to-all operations return the original input.""" + # Ensure num_heads is divisible by world_size + if num_heads % world_size != 0: + pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})") + + # Run test with multiprocessing spawn + torch.multiprocessing.spawn( + _test_5d_identity_worker, + args=( + world_size, + dtype, + batch_size, + seq_len_per_rank, + num_heads, + head_size, + use_sync, + ), + nprocs=world_size, + ) + + +def _test_5d_identity_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + """Worker function for test_5d_identity.""" + # Set device + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + + # Set environment variables for distributed training + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29500", + } + ) + + # Initialize distributed environment + init_distributed_environment() + initialize_model_parallel(sequence_parallel_size=world_size) + + sp_group = get_sp_group() + + # Create input tensor: (bs, seqlen/P, 3, hc, hs) + # The '3' dimension is for Q, K, V + torch.manual_seed(42 + local_rank) + input_tensor = torch.randn( + batch_size, + seq_len_per_rank, + 3, # Q, K, V + num_heads, + head_size, + dtype=dtype, + device=device, + ) + + # Save original input for comparison + original_input = input_tensor.clone() + + # First all-to-all: (bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs) + intermediate = SeqAllToAll5D.apply( + sp_group, + input_tensor, + scatter_idx=3, # scatter head dimension + gather_idx=1, # gather sequence dimension + use_sync=use_sync, + ) + + # Verify intermediate shape + expected_shape = ( + batch_size, + seq_len_per_rank * world_size, + 3, + num_heads 
// world_size, + head_size, + ) + assert intermediate.shape == expected_shape, ( + f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}" + ) + + # Second all-to-all: (bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs) + output = SeqAllToAll5D.apply( + sp_group, + intermediate, + scatter_idx=1, # scatter sequence dimension + gather_idx=3, # gather head dimension + use_sync=use_sync, + ) + + # Verify output shape matches input + assert output.shape == original_input.shape, ( + f"Output shape mismatch: expected {original_input.shape}, got {output.shape}" + ) + + # Verify output matches original input + torch.testing.assert_close( + output, + original_input, + rtol=1e-5, + atol=1e-5, + msg="Output does not match original input after two all-to-all operations", + ) + + # Cleanup distributed environment + destroy_distributed_env() From ccee08cf0ccf3279aea757a7e6ed1dcfadbf5931 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:02:28 +0800 Subject: [PATCH 59/85] fix comm test error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/diffusion/distributed/test_comm.py | 34 +++++++++++------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/diffusion/distributed/test_comm.py b/tests/diffusion/distributed/test_comm.py index 44539022ea0..7bd6386796e 100644 --- a/tests/diffusion/distributed/test_comm.py +++ b/tests/diffusion/distributed/test_comm.py @@ -96,9 +96,8 @@ def _test_4d_identity_worker( # Initialize distributed environment init_distributed_environment() - initialize_model_parallel(sequence_parallel_size=world_size) - - sp_group = get_sp_group() + initialize_model_parallel(ulysses_degree=world_size) # test ulysses sp by default + sp_group = get_sp_group().ulysses_group # get ulysses sp group not ring sp group # Create input tensor: (bs, seqlen/P, hc, hs) torch.manual_seed(42 + local_rank) @@ -118,9 +117,9 @@ def 
_test_4d_identity_worker( intermediate = SeqAllToAll4D.apply( sp_group, input_tensor, - scatter_idx=2, # scatter head dimension - gather_idx=1, # gather sequence dimension - use_sync=use_sync, + 2, # scatter head dimension + 1, # gather sequence dimension + use_sync, ) # Verify intermediate shape @@ -138,9 +137,9 @@ def _test_4d_identity_worker( output = SeqAllToAll4D.apply( sp_group, intermediate, - scatter_idx=1, # scatter sequence dimension - gather_idx=2, # gather head dimension - use_sync=use_sync, + 1, # scatter sequence dimension + 2, # gather head dimension + use_sync, ) # Verify output shape matches input @@ -226,9 +225,8 @@ def _test_5d_identity_worker( # Initialize distributed environment init_distributed_environment() - initialize_model_parallel(sequence_parallel_size=world_size) - - sp_group = get_sp_group() + initialize_model_parallel(ulysses_degree=world_size) # test ulysses sp by default + sp_group = get_sp_group().ulysses_group # get ulysses sp group not ring sp group # Create input tensor: (bs, seqlen/P, 3, hc, hs) # The '3' dimension is for Q, K, V @@ -250,9 +248,9 @@ def _test_5d_identity_worker( intermediate = SeqAllToAll5D.apply( sp_group, input_tensor, - scatter_idx=3, # scatter head dimension - gather_idx=1, # gather sequence dimension - use_sync=use_sync, + 3, # scatter head dimension + 1, # gather sequence dimension + use_sync, ) # Verify intermediate shape @@ -271,9 +269,9 @@ def _test_5d_identity_worker( output = SeqAllToAll5D.apply( sp_group, intermediate, - scatter_idx=1, # scatter sequence dimension - gather_idx=3, # gather head dimension - use_sync=use_sync, + 1, # scatter sequence dimension + 3, # gather head dimension + use_sync, ) # Verify output shape matches input From 54b1c67abf53f75ae402e46e48f1bd5511a5b654 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:44:11 +0800 Subject: [PATCH 60/85] set default sp config Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/text_to_image/text_to_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index a534c9d13f0..9b28de6454f 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -7,7 +7,7 @@ import torch -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -61,7 +61,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--ulysses_degree", type=int, - default=2, + default=1, help="Number of GPUs used for ulysses sequence parallelism.", ) parser.add_argument( @@ -111,7 +111,6 @@ def main(): # (e.g., QwenImagePipeline or FluxPipeline) } - assert args.ring_degree == 1, "Ring attention is not supported yet" parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) omni = Omni( @@ -129,6 +128,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") @@ -163,4 +163,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 3f7050ca6876f08ac6ea12ab9e0431dd1c2d7d6a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:03:50 +0800 Subject: [PATCH 61/85] test pipeline Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- .buildkite/pipeline.yml | 17 +++++++++++++++++ .buildkite/scripts/simple_test.sh | 2 ++ 2 files changed, 19 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 952954e81d2..fa47c07a649 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -49,6 +49,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Ulysses Sequence Parallelism Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - pytest -s -v tests/diffusion/attention/test_ulysses_sequence_parallel.py tests/diffusion/distributed/test_comm.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Omni Model Test" timeout_in_minutes: 15 depends_on: image-build diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh index 157ad9ca09c..bbd9be2f643 100755 --- a/.buildkite/scripts/simple_test.sh +++ b/.buildkite/scripts/simple_test.sh @@ -51,3 +51,5 @@ VENV_PYTHON="${VENV_DIR}/bin/python" "${UV_BIN}" pip install --python "${VENV_PYTHON}" -e ".[dev]" "${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/ "${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/ +"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/attention/ +"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/distributed/ From 07fe23deaeef8d12ede8eaf18345e1c909f4bb31 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:32:07 +0800 Subject: [PATCH 62/85] cache & parallel support: qwen-image edit Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../image_to_image/image_edit.py | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git 
a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 44051c19831..c4830c9e03f 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -24,6 +24,7 @@ import torch from PIL import Image +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -65,6 +66,8 @@ def parse_args() -> argparse.Namespace: default=4.0, help="True classifier-free guidance scale specific to Qwen-Image-Edit.", ) + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") parser.add_argument( "--output", type=str, @@ -94,6 +97,18 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." ), ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) return parser.parse_args() @@ -114,7 +129,6 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - # Configure cache based on backend type cache_config = None if args.cache_backend == "cache_dit": @@ -131,30 +145,39 @@ def main(): "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) "taylorseer_order": 1, # TaylorSeer polynomial order # SCM (Step Computation Masking) parameters [cache-dit only] - "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_mask_policy": "fast", # SCM mask policy: "slow", "medium", "fast", "ultra" "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or 
"static" } elif args.cache_backend == "tea_cache": - raise ValueError("TeaCache is not supported for image-to-image generation.") + # TeaCache configuration + # All parameters marked with [tea_cache only] in DiffusionCacheConfig + cache_config = { + # TeaCache parameters [tea_cache only] + "rel_l1_thresh": 0.2, # Threshold for accumulated relative L1 distance + # Note: coefficients will use model-specific defaults based on model_type + # (e.g., QwenImagePipeline or FluxPipeline) + } - # Initialize Omni with QwenImageEditPipeline + assert args.ring_degree == 1, "Ring attention is not supported yet" + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) omni = Omni( model=args.model, - model_class_name="QwenImageEditPipeline", vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + parallel_config=parallel_config, ) - print("Pipeline loaded") + print("Pipeline loaded") # Time profiling for generation print(f"\n{'=' * 60}") print("Generation Configuration:") print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") - print(f" Input image size: {input_image.size}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") generation_start = time.perf_counter() @@ -165,6 +188,8 @@ def main(): negative_prompt=args.negative_prompt, generator=generator, true_cfg_scale=args.cfg_scale, + height=args.height, + width=args.width, num_inference_steps=args.num_inference_steps, num_outputs_per_prompt=args.num_outputs_per_prompt, ) From c4421d55f1f18bfae65c7618bf3056a91df496de Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:39:54 +0800 Subject: [PATCH 
63/85] correct example script path in doc Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/cache_dit_acceleration.md | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/user_guide/cache_dit_acceleration.md b/docs/user_guide/cache_dit_acceleration.md index 35aeae3c968..d72a1ce7dd7 100644 --- a/docs/user_guide/cache_dit_acceleration.md +++ b/docs/user_guide/cache_dit_acceleration.md @@ -50,6 +50,42 @@ omni = Omni( ) ``` +### Example Script + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example with cache-dit acceleration. + +```bash +# Enable cache-dit with default parameters +cd examples/offline_inference/text_to_image +python text_to_image.py \ + --prompt "a cup of coffee on the table" \ + --enable_cache_dit \ + --num_inference_steps 50 +``` + +The `--enable_cache_dit` flag enables cache-dit acceleration with these customized parameters: + +```python +omni = Omni( + ... + cache_backend="cache_dit" if args.enable_cache_dit else None, + cache_config={ + # Scheme: Hybrid DBCache + SCM + TaylorSeer + # DBCache + "Fn_compute_blocks": 8, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.12, + # TaylorSeer + "enable_taylorseer": True, + "taylorseer_order": 1, + # SCM + "scm_steps_mask_policy": "fast", + "scm_steps_policy": "dynamic", + }, +) + +``` ## Acceleration Methods From 7f683d349ff786e9969556e0daa47546af44e9b0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:58:47 +0800 Subject: [PATCH 64/85] remove cache support Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../image_to_image/image_edit.py | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index c4830c9e03f..3567fa84e36 100644 --- 
a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -86,17 +86,6 @@ def parse_args() -> argparse.Namespace: default=50, help="Number of denoising steps for the diffusion sampler.", ) - parser.add_argument( - "--cache_backend", - type=str, - default=None, - choices=["cache_dit", "tea_cache"], - help=( - "Cache backend to use for acceleration. " - "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " - "Default: None (no cache acceleration)." - ), - ) parser.add_argument( "--ulysses_degree", type=int, @@ -129,34 +118,6 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - # Configure cache based on backend type - cache_config = None - if args.cache_backend == "cache_dit": - # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer - # All parameters marked with [cache-dit only] in DiffusionCacheConfig - cache_config = { - # DBCache parameters [cache-dit only] - "Fn_compute_blocks": 1, # Optimized for single-transformer models - "Bn_compute_blocks": 0, # Number of backward compute blocks - "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) - "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching - "max_continuous_cached_steps": 3, # Limit to prevent precision degradation - # TaylorSeer parameters [cache-dit only] - "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) - "taylorseer_order": 1, # TaylorSeer polynomial order - # SCM (Step Computation Masking) parameters [cache-dit only] - "scm_steps_mask_policy": "fast", # SCM mask policy: "slow", "medium", "fast", "ultra" - "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" - } - elif args.cache_backend == "tea_cache": - # TeaCache configuration - # All parameters marked with [tea_cache only] in DiffusionCacheConfig - cache_config = { - # TeaCache 
parameters [tea_cache only] - "rel_l1_thresh": 0.2, # Threshold for accumulated relative L1 distance - # Note: coefficients will use model-specific defaults based on model_type - # (e.g., QwenImagePipeline or FluxPipeline) - } assert args.ring_degree == 1, "Ring attention is not supported yet" parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) @@ -164,8 +125,6 @@ def main(): model=args.model, vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, - cache_backend=args.cache_backend, - cache_config=cache_config, parallel_config=parallel_config, ) @@ -175,7 +134,6 @@ def main(): print("Generation Configuration:") print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") - print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") From e2c98f3f3b978f5f80de8b9af32904f0ec279ef9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:17:48 +0800 Subject: [PATCH 65/85] update docs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion_acceleration.md | 63 ++++++++-- docs/user_guide/parallelism_acceleration.md | 127 ++++++++++++++++++++ 2 files changed, 182 insertions(+), 8 deletions(-) create mode 100644 docs/user_guide/parallelism_acceleration.md diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 0e2cd06a0df..34c4c97beb5 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -1,6 +1,6 @@ # Diffusion Acceleration Overview -vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. 
These methods intelligently cache intermediate computations to avoid redundant work across diffusion timesteps. +vLLM-Omni supports various acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. ## Supported Acceleration Methods @@ -14,8 +14,14 @@ vLLM-Omni currently supports two main cache acceleration backends: Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality. +vLLM-Omni also supports sequence parallelism (SP) for diffusion models, which includes: + +1. [Ulysses-SP](parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. + ## Quick Comparison +### Cache Methods + | Method | Configuration | Description | Best For | |--------|--------------|-------------|----------| | **TeaCache** | `cache_backend="tea_cache"` | Simple, adaptive caching with minimal configuration | Quick setup, balanced speed/quality | @@ -23,18 +29,18 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma ## Supported Models -The following table shows which models are currently supported by each cache backend: +The following table shows which models are currently supported by each acceleration method: -| Model | Model Identifier | TeaCache | Cache-DiT | -|-------|-----------------|----------|-----------| -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | +|-------|-----------------|----------|-----------|-----------| +| **Qwen-Image** |
`Qwen/Qwen-Image` | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅* | ✅ | ## Performance Benchmarks -The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Image-Edit** models with 50 inference steps: +The following benchmarks were measured on the **Qwen/Qwen-Image** model generating images (**1024x1024**) with 50 inference steps: !!! note "Benchmark Disclaimer" These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: @@ -56,6 +62,15 @@ The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Im | **Qwen/Qwen-Image-Edit** | Cache-DiT | Default (Fn=1, Bn=0, W=4, TaylorSeer disabled, SCM disabled) | 21.6s | **2.38x** | - | +To measure the parallelism methods, we run benchmarks with the **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is used as the attention backend.
+ +| Configuration | Ulysses degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 112.5s | 1.0x | +| Ulysses-SP | 2 | 65.2s | 1.73x | +| Ulysses-SP | 4 | 39.6s | 2.84x | +| Ulysses-SP | 8 | 30.8s | 3.65x | + ## Quick Start ### Using TeaCache @@ -92,6 +107,38 @@ omni = Omni( outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) ``` +### Using Ulysses-SP + +Run text-to-image: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + + +Run image-to-image: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image-Edit", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="turn this cat to a dog", + pil_image=input_image, num_inference_steps=50) +``` + ## Documentation For detailed information on each acceleration method: diff --git a/docs/user_guide/parallelism_acceleration.md b/docs/user_guide/parallelism_acceleration.md new file mode 100644 index 00000000000..642ddc1eed2 --- /dev/null +++ b/docs/user_guide/parallelism_acceleration.md @@ -0,0 +1,127 @@ +# Parallelism Acceleration Guide + +This guide includes how to use parallelism methods in vLLM-Omni to speed up diffusion model inference as well as reduce the memory requirement on each device. + +## Overview + +The following parallelism methods are currently supported in vLLM-Omni: +1. 
DeepSpeed Ulysses Sequence Parallel (Ulysses-SP) ([paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. + + +The following table shows which models are currently supported by each parallelism method: + + +| Model | Model Identifier | Ulysses-SP | +|-------|-----------------|-----------| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | + +### Sequence Parallelism + +#### Ulysses-SP + +##### Quick Start + +An example of using Ulysses-SP is shown below: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with the **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is used as the attention backend.
+ +| Configuration | Ulysses degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 112.5s | 1.0x | +| Ulysses-SP | 2 | 65.2s | 1.73x | +| Ulysses-SP | 4 | 39.6s | 2.84x | +| Ulysses-SP | 8 | 30.8s | 3.65x | + +##### How to parallelize a new model + +If a diffusion model has been deployed in vLLM-Omni and supports single-card inference, you can refer to the following instruction on how to parallelize this model with Ulysses-SP. + +First, please edit the `TransformerModel`'s `forward` function in the `xxx_model_transformer.py` to make the inputs (image hidden states, positional embeddings, etc.) as chunks separated at the sequence dimension. Taking `qwen_image_transformer.py` as an example: + +```diff +class QwenImageTransformer2DModel(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + ... + ): ++ if self.parallel_config.sequence_parallel_size > 1: ++ hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[ ++ get_sequence_parallel_rank() ++ ] + + hidden_states = self.img_in(hidden_states) + + ... + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + ++ def get_rotary_emb_chunk(freqs): ++ freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] ++ return freqs + ++ if self.parallel_config.sequence_parallel_size > 1: ++ img_freqs, txt_freqs = image_rotary_emb ++ img_freqs = get_rotary_emb_chunk(img_freqs) ++ image_rotary_emb = (img_freqs, txt_freqs) +``` + +Next, at the end of the `forward` function, please call `get_sp_group().all_gather` to gather the chunked outputs across devices, and concatenate them at the sequence dimension. + + +```diff +class QwenImageTransformer2DModel(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + ... 
+ ): + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + ++ if self.parallel_config.sequence_parallel_size > 1: ++ output = get_sp_group().all_gather(output, dim=-2) + return Transformer2DModelOutput(sample=output) +``` + +Finally, you can set the parallel configuration and pass it to `Omni` and start parallel inference with: +```diff +from vllm_omni import Omni ++from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", ++ parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +``` From e2658d0bedd39237e8e16712c4f5f1c1741b6b56 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:21:09 +0800 Subject: [PATCH 66/85] updates Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/.nav.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/.nav.yml b/docs/.nav.yml index 2666f13eeb3..57bd9cb7265 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -21,9 +21,10 @@ nav: - configuration/README.md - configuration/* - Diffusion Acceleration: - - Overview: user_guide/diffusion_acceleration.md - - TeaCache: user_guide/teacache.md - - Cache-DiT: user_guide/cache_dit_acceleration.md + - Overview: user_guide/diffusion_acceleration.md + - TeaCache: user_guide/teacache.md + - Cache-DiT: user_guide/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/parallelism_acceleration.md - Models: - models/supported_models.md - Developer Guide: From 6239b065cf036b1db174a6204c6e3e915640d7ee Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:27:21 +0800 Subject: [PATCH 67/85] e2e test Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- .buildkite/pipeline.yml | 2 +- .../test_sequence_parallel.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/offline_inference/test_sequence_parallel.py diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index fa47c07a649..1ee63e72bf4 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -53,7 +53,7 @@ steps: timeout_in_minutes: 15 depends_on: image-build commands: - - pytest -s -v tests/diffusion/attention/test_ulysses_sequence_parallel.py tests/diffusion/distributed/test_comm.py + - pytest -s -v tests/diffusion/attention/test_ulysses_sequence_parallel.py tests/diffusion/distributed/test_comm.py tests/e2e/offline_inference/test_sequence_parallel.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py new file mode 100644 index 00000000000..7557162e373 --- /dev/null +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +System test for Ulysses sequence parallel backend. + +This test verifies that Ulysses-SP (DeepSpeed Ulysses Sequence Parallel) works +correctly with diffusion models. It uses minimal settings to keep test time +short for CI. 
+""" + +import os +import sys +from pathlib import Path + +import pytest +import torch + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.distributed.parallel_state import device_count +from vllm_omni.diffusion.envs import get_device_name + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +# Use random weights model for testing +models = ["riverclouds/qwen_image_random"] + + +@pytest.mark.parametrize("model_name", models) +@pytest.mark.parametrize("ulysses_degree", [1, 2]) +@pytest.mark.parametrize("ring_degree", [1]) +def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: int): + """Test SP (Ulysses-SP + Ring-SP) backend with diffusion model.""" + # Skip if not enough GPUs available + if device_count() < ulysses_degree: + pytest.skip(f"Test requires {ulysses_degree} GPUs but only {device_count()} available") + + # Configure sequence parallel with DiffusionParallelConfig + parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree) + + m = Omni( + model=model_name, + parallel_config=parallel_config, + ) + + # Use minimal settings for fast testing + height = 256 + width = 256 + num_inference_steps = 4 # Minimal steps for fast test + + images = m.generate( + "a photo of a cat sitting on a laptop keyboard", + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + generator=torch.Generator(get_device_name()).manual_seed(42), + num_outputs_per_prompt=1, # Single output for speed + ) + + # Verify generation succeeded + assert images is not None + assert len(images) == 1 + # Check image size + assert images[0].width == width + assert images[0].height == height From a26dc01289b9a6b71e5702bcdab6b3871ee9a9cf Mon Sep 17 00:00:00 2001 From: Didan Deng 
<33117903+wtomin@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:35:57 +0800 Subject: [PATCH 68/85] fix image edit shape Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/image_to_image/image_edit.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 3567fa84e36..56272683f51 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -66,8 +66,6 @@ def parse_args() -> argparse.Namespace: default=4.0, help="True classifier-free guidance scale specific to Qwen-Image-Edit.", ) - parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") - parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") parser.add_argument( "--output", type=str, @@ -123,6 +121,7 @@ def main(): parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) omni = Omni( model=args.model, + model_class_name="QwenImageEditPipeline", vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, parallel_config=parallel_config, @@ -135,7 +134,6 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") - print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") generation_start = time.perf_counter() @@ -146,8 +144,6 @@ def main(): negative_prompt=args.negative_prompt, generator=generator, true_cfg_scale=args.cfg_scale, - height=args.height, - width=args.width, num_inference_steps=args.num_inference_steps, num_outputs_per_prompt=args.num_outputs_per_prompt, ) From 8c9e400050704956b528abd46a15a43e0a441317 Mon Sep 17 00:00:00 2001 From: Didan Deng 
<33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 09:40:49 +0800 Subject: [PATCH 69/85] fix ci Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index fdf3dfd958a..d486f0d8f86 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -5,11 +5,10 @@ import os import random from collections.abc import Callable -from dataclasses import dataclass, field, fields - -from pydantic import Field, model_validator +from pydantic import model_validator from typing import Any from contextlib import contextmanager +from dataclasses import dataclass, field, fields from functools import lru_cache import torch From 597af26f8b91575dbd64cc12ebf4357812da3f81 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 09:52:02 +0800 Subject: [PATCH 70/85] fix mkdocs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/comm.py | 6 +++--- vllm_omni/diffusion/distributed/parallel_state.py | 7 ------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index 2fc00e4e6b8..b5f7aa32a4f 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -19,7 +19,7 @@ def all_to_all_4D( input (torch.tensor): a tensor sharded along dim scatter dim scatter_idx (int): default 1 gather_idx (int): default 2 - group : torch process group + group (torch.distributed.ProcessGroup): torch process group use_sync (bool): whether to synchronize after all-to-all Returns: @@ -124,8 +124,8 @@ def all_to_all_5D( input (torch.tensor): a tensor sharded along dim scatter dim scatter_idx (int): default 1 gather_idx (int): default 2 - group : torch process group - use_sync: 
whether to synchronize after all-to-all + group (torch.distributed.ProcessGroup): torch process group + use_sync (bool): whether to synchronize after all-to-all Returns: torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index fa375b90d96..387639d1722 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -238,13 +238,6 @@ def get_ranks(self, token): to obtain multiple parallel types, we can use a hyphen '-' to separate them. For example, if we want to obtain the TP_DP group, the token should be 'tp-dp'. - - independent_ep (bool: True): - This flag controls whether we treat EP and DP independently. - EP shares ranks with DP, if we want to get ranks related to - EP, we should set the flag. For example, get_ranks('dp', True) - will get DP modulo EP group, and get_ranks('dp', False) will - get full DP group. """ mask = self.get_mask(self.order, token) ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) From 74da8cbb4e10ccabc4e01b75c77c08dd8d9ef846 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 09:57:54 +0800 Subject: [PATCH 71/85] fix docs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/cache_dit_acceleration.md | 36 ----------------------- 1 file changed, 36 deletions(-) diff --git a/docs/user_guide/cache_dit_acceleration.md b/docs/user_guide/cache_dit_acceleration.md index d72a1ce7dd7..35aeae3c968 100644 --- a/docs/user_guide/cache_dit_acceleration.md +++ b/docs/user_guide/cache_dit_acceleration.md @@ -50,42 +50,6 @@ omni = Omni( ) ``` -### Example Script - -See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example with cache-dit acceleration. 
- -```bash -# Enable cache-dit with default parameters -cd examples/offline_inference/text_to_image -python text_to_image.py \ - --prompt "a cup of coffee on the table" \ - --enable_cache_dit \ - --num_inference_steps 50 -``` - -The `--enable_cache_dit` flag enables cache-dit acceleration with these customized parameters: - -```python -omni = Omni( - ... - cache_backend="cache_dit" if args.enable_cache_dit else None, - cache_config={ - # Scheme: Hybrid DBCache + SCM + TaylorSeer - # DBCache - "Fn_compute_blocks": 8, - "Bn_compute_blocks": 0, - "max_warmup_steps": 4, - "residual_diff_threshold": 0.12, - # TaylorSeer - "enable_taylorseer": True, - "taylorseer_order": 1, - # SCM - "scm_steps_mask_policy": "fast", - "scm_steps_policy": "dynamic", - }, -) - -``` ## Acceleration Methods From 0624d597bdc20b067ad73e180f45363044f31ad0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 10:07:57 +0800 Subject: [PATCH 72/85] fix docs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion_acceleration.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 34c4c97beb5..33e85481b5c 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -1,6 +1,6 @@ # Diffusion Acceleration Overview -vLLM-Omni supports various acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. +vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. 
These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. ## Supported Acceleration Methods @@ -35,12 +35,12 @@ The following table shows which models are currently supported by each accelerat |-------|-----------------|----------|-----------|-----------| | **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ |❌ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅* | ✅ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅ |✅ | ## Performance Benchmarks -The following benchmarks were measured on **Qwen/Qwen-Image** model generating images (**1024x1024**) with 50 inference steps: +The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Image-Edit** models generating 1024x1024 images with 50 inference steps: !!! note "Benchmark Disclaimer" These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: @@ -61,7 +61,6 @@ The following benchmarks were measured on **Qwen/Qwen-Image** model generating i | **Qwen/Qwen-Image-Edit** | None | No acceleration | 51.5s | 1.0x | Baseline (diffusers) | | **Qwen/Qwen-Image-Edit** | Cache-DiT | Default (Fn=1, Bn=0, W=4, TaylorSeer disabled, SCM disabled) | 21.6s | **2.38x** | - | - To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is the attention backends. 
| Configuration | Ulysses degree |Generation Time | Speedup | @@ -145,3 +144,4 @@ For detailed information on each acceleration method: - **[TeaCache Guide](teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Sequence Parallelism](parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. From 8f630d2d357ccf8fb92df85c764e325967ff5b08 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 10:10:43 +0800 Subject: [PATCH 73/85] fix image edit example Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../image_to_image/image_edit.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 56272683f51..5ea76a8217f 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -84,6 +84,17 @@ def parse_args() -> argparse.Namespace: default=50, help="Number of denoising steps for the diffusion sampler.", ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit", "tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " + "Default: None (no cache acceleration)." 
+ ), + ) parser.add_argument( "--ulysses_degree", type=int, @@ -116,24 +127,50 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - assert args.ring_degree == 1, "Ring attention is not supported yet" parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) + # Configure cache based on backend type + cache_config = None + if args.cache_backend == "cache_dit": + # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer + # All parameters marked with [cache-dit only] in DiffusionCacheConfig + cache_config = { + # DBCache parameters [cache-dit only] + "Fn_compute_blocks": 1, # Optimized for single-transformer models + "Bn_compute_blocks": 0, # Number of backward compute blocks + "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching + "max_continuous_cached_steps": 3, # Limit to prevent precision degradation + # TaylorSeer parameters [cache-dit only] + "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) + "taylorseer_order": 1, # TaylorSeer polynomial order + # SCM (Step Computation Masking) parameters [cache-dit only] + "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" + } + elif args.cache_backend == "tea_cache": + raise ValueError("TeaCache is not supported for image-to-image generation.") + + # Initialize Omni with QwenImageEditPipeline omni = Omni( model=args.model, model_class_name="QwenImageEditPipeline", vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, + cache_backend=args.cache_backend, + cache_config=cache_config, parallel_config=parallel_config, ) - print("Pipeline loaded") + # Time profiling for generation print(f"\n{'=' * 60}") print("Generation Configuration:") print(f" 
Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") + print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + print(f" Input image size: {input_image.size}") print(f"{'=' * 60}\n") generation_start = time.perf_counter() From 0ecd31277aa7515519100509f70c70a31e49f716 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:16:36 +0800 Subject: [PATCH 74/85] args name degree to size except for ring&ulysses degrees Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../test_ulysses_sequence_parallel.py | 10 +-- .../diffusion/distributed/parallel_state.py | 88 +++++++++---------- vllm_omni/diffusion/worker/gpu_worker.py | 10 +-- 3 files changed, 52 insertions(+), 56 deletions(-) diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py index aea03809daf..33429142c27 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -369,13 +369,13 @@ def ulysses_attention_on_test_model( # Initialize model parallel initialize_model_parallel( - data_parallel_degree=1, - classifier_free_guidance_degree=1, - sequence_parallel_degree=sequence_parallel_size, + data_parallel_size=1, + cfg_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, ulysses_degree=ulysses_degree, ring_degree=ring_degree, - tensor_parallel_degree=1, - pipeline_parallel_degree=1, + tensor_parallel_size=1, + pipeline_parallel_size=1, ) # Set the config so Attention can access it diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 387639d1722..8dccae68cf4 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ 
b/vllm_omni/diffusion/distributed/parallel_state.py @@ -514,20 +514,20 @@ def init_vae_group( # adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True): """ - sp_ulysses_degree x sp_ring_degree = seq_parallel_degree - (ulysses_degree, dp_degree) + sp_ulysses_degree x sp_ring_degree = seq_parallel_size + (ulysses_degree, dp_size) """ - sp_degree = sp_ring_degree * sp_ulysses_degree - dp_degree = world_size // sp_degree + sp_size = sp_ring_degree * sp_ulysses_degree + dp_size = world_size // sp_size - assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_ulysses_degree} == 0" + assert world_size % sp_size == 0, f"world_size {world_size} % sp_size {sp_ulysses_degree} == 0" num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree if use_ulysses_low: - for dp_rank in range(dp_degree): - offset = dp_rank * sp_degree + for dp_rank in range(dp_size): + offset = dp_rank * sp_size for i in range(num_ulysses_pgs): ulysses_ranks = list( range( @@ -540,14 +540,14 @@ def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use ulyssess_pg = group for i in range(num_ring_pgs): - ring_ranks = list(range(i + offset, sp_degree + offset, num_ring_pgs)) + ring_ranks = list(range(i + offset, sp_size + offset, num_ring_pgs)) group = torch.distributed.new_group(ring_ranks) if rank in ring_ranks: ring_pg = group else: - for dp_rank in range(dp_degree): - offset = dp_rank * sp_degree + for dp_rank in range(dp_size): + offset = dp_rank * sp_size for i in range(num_ring_pgs): ring_ranks = list(range(i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset)) group = torch.distributed.new_group(ring_ranks) @@ -555,7 +555,7 @@ def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use ring_pg = group for i in 
range(num_ulysses_pgs): - ulysses_ranks = list(range(i + offset, sp_degree + offset, num_ulysses_pgs)) + ulysses_ranks = list(range(i + offset, sp_size + offset, num_ulysses_pgs)) group = torch.distributed.new_group(ulysses_ranks) if rank in ulysses_ranks: ulyssess_pg = group @@ -564,13 +564,13 @@ def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use def initialize_model_parallel( - data_parallel_degree: int = 1, - classifier_free_guidance_degree: int = 1, - sequence_parallel_degree: Optional[int] = None, + data_parallel_size: int = 1, + cfg_parallel_size: int = 1, + sequence_parallel_size: Optional[int] = None, ulysses_degree: int = 1, ring_degree: int = 1, - tensor_parallel_degree: int = 1, - pipeline_parallel_degree: int = 1, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, vae_parallel_size: int = 0, backend: Optional[str] = None, ) -> None: @@ -580,21 +580,21 @@ def initialize_model_parallel( Initialize model parallel groups. Arguments: - data_parallel_degree: number of data parallelism groups. - classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) - sequence_parallel_degree: number of GPUs used for sequence parallelism. - sequence_parallel_degree = ulysses_degree * ring_degree + data_parallel_size: number of data parallelism groups. + cfg_parallel_size: number of GPUs used for Classifier Free Guidance (CFG) parallelism. + sequence_parallel_size: number of GPUs used for sequence parallelism. + sequence_parallel_size = ulysses_degree * ring_degree ulysses_degree: number of GPUs used for ulysses sequence parallelism. ring_degree: number of GPUs used for ring sequence parallelism. - tensor_parallel_degree: number of GPUs used for tensor parallelism. - pipeline_parallel_degree: number of GPUs used for pipeline parallelism. + tensor_parallel_size: number of GPUs used for tensor parallelism. + pipeline_parallel_size: number of GPUs used for pipeline parallelism. 
backend: distributed backend of pytorch collective comm. Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize split batch caused by CFG, and 2 GPUs to parallelize sequence. - dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16. + dp_size (2) * cfg_size (2) * sp_size (2) * pp_size (2) = 16. The present function will create 8 data-parallel groups, 8 CFG group, 8 pipeline-parallel group, and @@ -621,48 +621,44 @@ def initialize_model_parallel( world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend(get_world_group().device_group) - if sequence_parallel_degree is None: - sequence_parallel_degree = ring_degree * ulysses_degree + if sequence_parallel_size is None: + sequence_parallel_size = ring_degree * ulysses_degree logger.info( - f"sequence_parallel_degree is not provided, using ring_degree * ulysses_degree = {sequence_parallel_degree}" + f"sequence_parallel_size is not provided, using ring_degree * ulysses_degree = {sequence_parallel_size}" ) - if sequence_parallel_degree != ring_degree * ulysses_degree: + if sequence_parallel_size != ring_degree * ulysses_degree: raise ValueError( - "sequence_parallel_degree is not equal to ring_degree * ulysses_degree," - f" but got {sequence_parallel_degree} != {ring_degree} * {ulysses_degree}" + "sequence_parallel_size is not equal to ring_degree * ulysses_degree," + f" but got {sequence_parallel_size} != {ring_degree} * {ulysses_degree}" ) # FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch, # the pipefusion is not ready for npu yet if envs._is_npu(): - assert pipeline_parallel_degree == 1, "Current pipefusion is not ready for NPU" + assert pipeline_parallel_size == 1, "Current pipefusion is not ready for NPU" dit_parallel_size = ( - data_parallel_degree - * classifier_free_guidance_degree - * sequence_parallel_degree - * pipeline_parallel_degree 
- * tensor_parallel_degree + data_parallel_size * cfg_parallel_size * sequence_parallel_size * pipeline_parallel_size * tensor_parallel_size ) if world_size < dit_parallel_size: raise RuntimeError( f"world_size ({world_size}) is less than " - f"tensor_parallel_degree ({tensor_parallel_degree}) x " - f"pipeline_parallel_degree ({pipeline_parallel_degree}) x" - f"sequence_parallel_degree ({sequence_parallel_degree}) x" - f"classifier_free_guidance_degree " - f"({classifier_free_guidance_degree}) x" - f"data_parallel_degree ({data_parallel_degree})" + f"tensor_parallel_size ({tensor_parallel_size}) x " + f"pipeline_parallel_size ({pipeline_parallel_size}) x" + f"sequence_parallel_size ({sequence_parallel_size}) x" + f"cfg_parallel_size " + f"({cfg_parallel_size}) x" + f"data_parallel_size ({data_parallel_size})" ) rank_generator: RankGenerator = RankGenerator( - tensor_parallel_degree, - sequence_parallel_degree, - pipeline_parallel_degree, - classifier_free_guidance_degree, - data_parallel_degree, + tensor_parallel_size, + sequence_parallel_size, + pipeline_parallel_size, + cfg_parallel_size, + data_parallel_size, "tp-sp-pp-cfg-dp", ) global _DP diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 075e6aeffa4..40f0b7b300f 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -72,13 +72,13 @@ def init_device_and_model(self) -> None: logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") parallel_config = self.od_config.parallel_config initialize_model_parallel( - data_parallel_degree=parallel_config.data_parallel_size, - classifier_free_guidance_degree=parallel_config.cfg_parallel_size, - sequence_parallel_degree=parallel_config.sequence_parallel_size, + data_parallel_size=parallel_config.data_parallel_size, + cfg_parallel_size=parallel_config.cfg_parallel_size, + sequence_parallel_size=parallel_config.sequence_parallel_size, 
ulysses_degree=parallel_config.ulysses_degree, ring_degree=parallel_config.ring_degree, - tensor_parallel_degree=parallel_config.tensor_parallel_size, - pipeline_parallel_degree=parallel_config.pipeline_parallel_size, + tensor_parallel_size=parallel_config.tensor_parallel_size, + pipeline_parallel_size=parallel_config.pipeline_parallel_size, ) load_config = LoadConfig() From ed46752480c6d261360539e33b40931af30c0a65 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:53:47 +0800 Subject: [PATCH 75/85] rm attention npu Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 389f48c27e2..60eb4d57513 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -20,7 +20,6 @@ from vllm_omni.diffusion.data import get_current_omni_diffusion_config from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size, get_sp_group -from vllm_omni.utils.platform_utils import is_npu class Attention(nn.Module): @@ -104,25 +103,12 @@ def _forward_ulysses( if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 - if is_npu(): - context_layer = self.attention( - q, - k, - v, - num_heads=q.shape[-2], - input_layout="BSND", - scale=softmax_scale, - softmax_lse_flag=True, - pre_tokens=65535, - next_tokens=65535, - ) - else: - context_layer = self.attention.forward( - q, - k, - v, - attn_metadata=attn_metadata, - ) + context_layer = self.attention.forward( + q, + k, + v, + attn_metadata=attn_metadata, + ) if isinstance(context_layer, tuple): context_layer = context_layer[0] From e06b049d8ea3497db587839e5f16a4e237a1458a Mon Sep 17 00:00:00 2001 From: Didan Deng 
<33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 17:58:36 +0800 Subject: [PATCH 76/85] fix ci Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 5 ++-- .../diffusion/distributed/parallel_state.py | 23 ++++++++----------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 60eb4d57513..5611d8ee436 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -6,7 +6,6 @@ # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py -from typing import Optional import torch import torch.distributed as dist @@ -51,8 +50,8 @@ def __init__( self.scatter_idx = scatter_idx self.gather_idx = gather_idx self.use_sync = use_sync - self.ring_pg: Optional[dist.ProcessGroup] = None - self.ulysses_pg: Optional[dist.ProcessGroup] = None + self.ring_pg: dist.ProcessGroup | None = None + self.ulysses_pg: dist.ProcessGroup | None = None self.use_ulysses = False try: diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py index 8dccae68cf4..b249515909d 100644 --- a/vllm_omni/diffusion/distributed/parallel_state.py +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -10,7 +10,6 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - """vLLM-Omni distributed state. It takes over the control of the distributed environment from PyTorch. @@ -29,8 +28,6 @@ you can skip the model parallel initialization and destruction steps. 
""" -from typing import Optional - import torch import torch.distributed import vllm.distributed.parallel_state as vllm_parallel_state @@ -60,14 +57,14 @@ logger = init_logger(__name__) -_WORLD: Optional[GroupCoordinator] = None +_WORLD: GroupCoordinator | None = None # get _TP from vllm.distributed.parallel_state -_SP: Optional[SequenceParallelGroupCoordinator] = None -_PP: Optional[PipelineGroupCoordinator] = None -_CFG: Optional[GroupCoordinator] = None -_DP: Optional[GroupCoordinator] = None -_DIT: Optional[GroupCoordinator] = None -_VAE: Optional[GroupCoordinator] = None +_SP: SequenceParallelGroupCoordinator | None = None +_PP: PipelineGroupCoordinator | None = None +_CFG: GroupCoordinator | None = None +_DP: GroupCoordinator | None = None +_DIT: GroupCoordinator | None = None +_VAE: GroupCoordinator | None = None def generate_masked_orthogonal_rank_groups( @@ -396,7 +393,7 @@ def init_distributed_environment( rank: int = -1, distributed_init_method: str = "env://", local_rank: int = -1, - backend: Optional[str] = None, + backend: str | None = None, ): if backend is None: backend = envs.get_torch_distributed_backend() @@ -566,13 +563,13 @@ def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use def initialize_model_parallel( data_parallel_size: int = 1, cfg_parallel_size: int = 1, - sequence_parallel_size: Optional[int] = None, + sequence_parallel_size: int | None = None, ulysses_degree: int = 1, ring_degree: int = 1, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, vae_parallel_size: int = 0, - backend: Optional[str] = None, + backend: str | None = None, ) -> None: if backend is None: backend = envs.get_torch_distributed_backend() From eb6818cd595a047a33731ebe005548baac7ece41 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 18:08:15 +0800 Subject: [PATCH 77/85] rm simple test Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- 
.buildkite/scripts/simple_test.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh index bbd9be2f643..157ad9ca09c 100755 --- a/.buildkite/scripts/simple_test.sh +++ b/.buildkite/scripts/simple_test.sh @@ -51,5 +51,3 @@ VENV_PYTHON="${VENV_DIR}/bin/python" "${UV_BIN}" pip install --python "${VENV_PYTHON}" -e ".[dev]" "${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/ "${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/ -"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/attention/ -"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/distributed/ From 94026b01a97659974326d051aaf6841fe0994bab Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 18:08:56 +0800 Subject: [PATCH 78/85] rm pipeline test Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .buildkite/pipeline.yml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1ee63e72bf4..952954e81d2 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -49,23 +49,6 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" - - label: "Ulysses Sequence Parallelism Test" - timeout_in_minutes: 15 - depends_on: image-build - commands: - - pytest -s -v tests/diffusion/attention/test_ulysses_sequence_parallel.py tests/diffusion/distributed/test_comm.py tests/e2e/offline_inference/test_sequence_parallel.py - agents: - queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU - plugins: - - docker#v5.2.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - always-pull: true - propagate-environment: true - environment: - - "HF_HOME=/fsx/hf_cache" - volumes: - - "/fsx/hf_cache:/fsx/hf_cache" - - label: "Omni Model Test" timeout_in_minutes: 15 depends_on: image-build From db5c134034541aa2fd9fa081419d965555093995 Mon Sep 17 00:00:00 2001 From: Didan Deng 
<33117903+wtomin@users.noreply.github.com> Date: Tue, 16 Dec 2025 18:50:45 +0800 Subject: [PATCH 79/85] fix pre-commit Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/group_coordinator.py | 47 ++++++++++--------- vllm_omni/diffusion/envs.py | 9 ++-- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 02aaa391b1f..8e33d6fb657 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,7 +5,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import pickle from collections import namedtuple -from typing import Any, Optional, Union +from typing import Any import torch import torch.distributed @@ -21,13 +21,14 @@ logger.info("torch.npu synchronize") from torch.npu import synchronize + TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) env_info = envs.PACKAGES_CHECKER.get_packages_info() def _split_tensor_dict( - tensor_dict: dict[str, Union[torch.Tensor, Any]], prefix: str = "" + tensor_dict: dict[str, torch.Tensor | Any], prefix: str = "" ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. 
If the value is a tensor, it is replaced @@ -101,7 +102,7 @@ def __init__( self, group_ranks: list[list[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], + torch_distributed_backend: str | Backend, ): self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -202,7 +203,7 @@ def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceO def all_gather( self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: @@ -280,7 +281,7 @@ def broadcast(self, input_: torch.Tensor, src: int = 0): torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) return input_ - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + def broadcast_object(self, obj: Any | None = None, src: int = 0): """Broadcast the input object. NOTE: `src` is the local rank of the source rank. """ @@ -300,7 +301,7 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) return recv[0] - def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None): + def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. 
""" @@ -365,11 +366,11 @@ def recv_object(self, src: int) -> Any: def broadcast_tensor_dict( self, - tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ @@ -433,9 +434,9 @@ def broadcast_tensor_dict( def send_tensor_dict( self, - tensor_dict: dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + ) -> dict[str, torch.Tensor | Any] | None: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ @@ -469,7 +470,7 @@ def send_tensor_dict( torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None - def recv_tensor_dict(self, src: Optional[int] = None) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + def recv_tensor_dict(self, src: int | None = None) -> dict[str, torch.Tensor | Any] | None: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -513,7 +514,7 @@ def barrier(self): """ torch.distributed.barrier(group=self.cpu_group) - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the rank_in_group of the destination rank.""" if dst is None: @@ -525,7 +526,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), ) - def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: + def recv(self, size: torch.Size, dtype: torch.dtype, src: int | None = None) -> torch.Tensor: """Receives a tensor from the src rank.""" """NOTE: `src` is the rank_in_group of the source rank.""" if src is None: @@ -571,7 +572,7 @@ def __init__( self, group_ranks: list[list[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], + torch_distributed_backend: str | Backend, ): self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -622,17 +623,17 @@ def __init__( self.recv_buffer_set: bool = False self.recv_tasks_queue: list[tuple[str, int]] = [] self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = [] - self.dtype: Optional[torch.dtype] = None - self.num_pipefusion_patches: Optional[int] = None + self.dtype: torch.dtype | None = None + self.num_pipefusion_patches: int | None = None self.recv_shape: dict[str, dict[int, torch.Size]] = {} self.send_shape: dict[str, dict[int, torch.Size]] = {} self.recv_buffer: dict[str, dict[int, torch.Size]] = {} self.skip_tensor_recv_buffer_set: bool = False - self.recv_skip_tasks_queue: list[Union[int, tuple[str, int]]] = [] + self.recv_skip_tasks_queue: list[int | tuple[str, int]] = [] self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = [] - self.skip_tensor_recv_buffer: 
Optional[Union[list[torch.Tensor], torch.Tensor]] = None + self.skip_tensor_recv_buffer: list[torch.Tensor] | torch.Tensor | None = None self.skip_device_group = None for ranks in group_ranks: skip_device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) @@ -686,7 +687,7 @@ def _check_shape_and_buffer( self, tensor_send_to_next=None, recv_prev=False, - name: Optional[str] = None, + name: str | None = None, segment_idx: int = 0, ): send_flag = False @@ -913,7 +914,7 @@ def __init__( self, group_ranks: list[list[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], + torch_distributed_backend: str | Backend, **kwargs, ): super().__init__( diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 4b5b6fb041a..47520f498e9 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -2,7 +2,8 @@ # Adapted from # https://github.com/xdit-project/xDiT/blob/main/xfuser/envs.py import os -from typing import TYPE_CHECKING, Any, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import torch from packaging import version @@ -12,10 +13,10 @@ if TYPE_CHECKING: MASTER_ADDR: str = "" - MASTER_PORT: Optional[int] = None - CUDA_HOME: Optional[str] = None + MASTER_PORT: int | None = None + CUDA_HOME: str | None = None LOCAL_RANK: int = 0 - CUDA_VISIBLE_DEVICES: Optional[str] = None + CUDA_VISIBLE_DEVICES: str | None = None CUDA_VERSION: version.Version TORCH_VERSION: version.Version From 14f17e2e3eeb34b02f799135e19812063c2c66c1 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 09:33:12 +0800 Subject: [PATCH 80/85] fix ci Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/attention/layer.py | 1 - vllm_omni/diffusion/data.py | 3 +-- vllm_omni/diffusion/envs.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git 
a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 5611d8ee436..99d4009c9b4 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -6,7 +6,6 @@ # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py - import torch import torch.distributed as dist import torch.nn as nn diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index d486f0d8f86..cd886f3aeb9 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -5,11 +5,10 @@ import os import random from collections.abc import Callable -from pydantic import model_validator -from typing import Any from contextlib import contextmanager from dataclasses import dataclass, field, fields from functools import lru_cache +from typing import Any import torch from pydantic import model_validator diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py index 47520f498e9..f717f2ca08b 100644 --- a/vllm_omni/diffusion/envs.py +++ b/vllm_omni/diffusion/envs.py @@ -20,7 +20,6 @@ CUDA_VERSION: version.Version TORCH_VERSION: version.Version - environment_variables: dict[str, Callable[[], Any]] = { # ================== Runtime Env Vars ================== # used in distributed environment to determine the master address From 1c5587312595f745b7979e2a06f2c2c4528a9f5b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:11:33 +0800 Subject: [PATCH 81/85] extend time out minutes Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 952954e81d2..6aab81b474f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -16,7 +16,7 @@ steps: queue: "cpu_queue_premerge" - label: "Diffusion Model Test" - timeout_in_minutes: 15 + 
timeout_in_minutes: 20 depends_on: image-build commands: - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py From adf5ae7211538ede6fd784b6c52e322fdae7226d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:03:59 +0800 Subject: [PATCH 82/85] test sp pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .buildkite/pipeline.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6aab81b474f..86eec116387 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -49,6 +49,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Parallelism Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Omni Model Test" timeout_in_minutes: 15 depends_on: image-build From 9d14371d5b9270178e4eb8aa02d039f3176327d0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:31:32 +0800 Subject: [PATCH 83/85] change docs structure Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/.nav.yml | 7 ++++--- docs/configuration/README.md | 4 +++- .../{ => acceleration}/cache_dit_acceleration.md | 0 .../{ => acceleration}/parallelism_acceleration.md | 0 docs/user_guide/{ => acceleration}/teacache.md | 0 docs/user_guide/diffusion_acceleration.md | 12 ++++++------ 6 files changed, 13 insertions(+), 10 deletions(-) rename docs/user_guide/{ => acceleration}/cache_dit_acceleration.md (100%) rename docs/user_guide/{ => 
acceleration}/parallelism_acceleration.md (100%) rename docs/user_guide/{ => acceleration}/teacache.md (100%) diff --git a/docs/.nav.yml b/docs/.nav.yml index 57bd9cb7265..450f932ca9f 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -22,9 +22,10 @@ nav: - configuration/* - Diffusion Acceleration: - Overview: user_guide/diffusion_acceleration.md - - TeaCache: user_guide/teacache.md - - Cache-DiT: user_guide/cache_dit_acceleration.md - - Parallelism Acceleration: user_guide/parallelism_acceleration.md + - Acceleration Methods: + - TeaCache: user_guide/acceleration/teacache.md + - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md - Models: - models/supported_models.md - Developer Guide: diff --git a/docs/configuration/README.md b/docs/configuration/README.md index 34b7c1a27c9..1ceb9f827da 100644 --- a/docs/configuration/README.md +++ b/docs/configuration/README.md @@ -12,4 +12,6 @@ For introduction, please check [Introduction for stage config](./stage_configs.m ## Optimization Features -- **[TeaCache Configuration](../user_guide/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss +- **[TeaCache Configuration](../user_guide/acceleration/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss +- **[Cache-DiT Configuration](../user_guide/acceleration/cache_dit_acceleration.md)** - Enable Cache-DiT as a cache acceleration backend for DiT models +- **[Parallelism Configuration](../user_guide/acceleration/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for DiT models diff --git a/docs/user_guide/cache_dit_acceleration.md b/docs/user_guide/acceleration/cache_dit_acceleration.md similarity index 100% rename from docs/user_guide/cache_dit_acceleration.md rename to docs/user_guide/acceleration/cache_dit_acceleration.md 
diff --git a/docs/user_guide/parallelism_acceleration.md b/docs/user_guide/acceleration/parallelism_acceleration.md similarity index 100% rename from docs/user_guide/parallelism_acceleration.md rename to docs/user_guide/acceleration/parallelism_acceleration.md diff --git a/docs/user_guide/teacache.md b/docs/user_guide/acceleration/teacache.md similarity index 100% rename from docs/user_guide/teacache.md rename to docs/user_guide/acceleration/teacache.md diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 33e85481b5c..95afbd1b0f5 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -6,8 +6,8 @@ vLLM-Omni supports various cache acceleration methods to speed up diffusion mode vLLM-Omni currently supports two main cache acceleration backends: -1. **[TeaCache](teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar -2. **[Cache-DiT](cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: +1. **[TeaCache](acceleration/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar +2. **[Cache-DiT](acceleration/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking @@ -16,7 +16,7 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma vLLM-Omni also supports the sequence parallelism (SP) for the diffusion model, that includes: -1. 
[Ulysses-SP](parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. +1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. ## Quick Comparison @@ -142,6 +142,6 @@ outputs = omni.generate(prompt="turn this cat to a dog", For detailed information on each acceleration method: -- **[TeaCache Guide](teacache.md)** - Complete TeaCache documentation, configuration options, and best practices -- **[Cache-DiT Acceleration Guide](cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters -- **[Sequence Parallelism](parallelism_acceleration.md#sequence-parallelism) ** - Guidance on how to set sequence parallelism with configuration. +- **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices +- **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism) ** - Guidance on how to set sequence parallelism with configuration. 
From 824f45265f6a2ad3c42911cbb785c96b30329092 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:32:50 +0800 Subject: [PATCH 84/85] remove ring degree Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion_acceleration.md | 2 +- .../offline_inference/image_to_image/image_edit.py | 13 ++++--------- .../text_to_image/text_to_image.py | 12 +++--------- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 95afbd1b0f5..33cd8934fe8 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -144,4 +144,4 @@ For detailed information on each acceleration method: - **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters -- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism) ** - Guidance on how to set sequence parallelism with configuration. +- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. 
diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 5ea76a8217f..33a95454d8b 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -101,12 +101,7 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ulysses sequence parallelism.", ) - parser.add_argument( - "--ring_degree", - type=int, - default=1, - help="Number of GPUs used for ring sequence parallelism.", - ) + return parser.parse_args() @@ -127,8 +122,8 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - assert args.ring_degree == 1, "Ring attention is not supported yet" - parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) + + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) # Configure cache based on backend type cache_config = None if args.cache_backend == "cache_dit": @@ -169,7 +164,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") print(f" Input image size: {input_image.size}") print(f"{'=' * 60}\n") diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 9b28de6454f..21e752254b4 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -64,12 +64,7 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ulysses sequence parallelism.", ) - parser.add_argument( - 
"--ring_degree", - type=int, - default=1, - help="Number of GPUs used for ring sequence parallelism.", - ) + return parser.parse_args() @@ -111,8 +106,7 @@ def main(): # (e.g., QwenImagePipeline or FluxPipeline) } - assert args.ring_degree == 1, "Ring attention is not supported yet" - parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) omni = Omni( model=args.model, vae_use_slicing=vae_use_slicing, @@ -128,7 +122,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") From b99f48c9799ceb4d9d11c5b97184cc20eea83afd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:58:49 +0800 Subject: [PATCH 85/85] fix docs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/.nav.yml | 10 +++++----- .../acceleration/parallelism_acceleration.md | 1 + docs/user_guide/diffusion_acceleration.md | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/.nav.yml b/docs/.nav.yml index 450f932ca9f..f534d49cc09 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -21,11 +21,11 @@ nav: - configuration/README.md - configuration/* - Diffusion Acceleration: - - Overview: user_guide/diffusion_acceleration.md - - Acceleration Methods: - - TeaCache: user_guide/acceleration/teacache.md - - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md - - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md + - Overview: 
user_guide/diffusion_acceleration.md + - Acceleration Methods: + - TeaCache: user_guide/acceleration/teacache.md + - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md - Models: - models/supported_models.md - Developer Guide: diff --git a/docs/user_guide/acceleration/parallelism_acceleration.md b/docs/user_guide/acceleration/parallelism_acceleration.md index 642ddc1eed2..0ced0731b26 100644 --- a/docs/user_guide/acceleration/parallelism_acceleration.md +++ b/docs/user_guide/acceleration/parallelism_acceleration.md @@ -5,6 +5,7 @@ This guide includes how to use parallelism methods in vLLM-Omni to speed up diff ## Overview The following parallelism methods are currently supported in vLLM-Omni: + 1. DeepSpeed Ulysses Sequence Parallel (Ulysses-SP) ([paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 33cd8934fe8..220b930d69c 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -8,9 +8,9 @@ vLLM-Omni currently supports two main cache acceleration backends: 1. **[TeaCache](acceleration/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar 2. 
**[Cache-DiT](acceleration/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: - - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences - - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference - - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking + - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences + - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference + - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality.