diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 952954e81d2..86eec116387 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -16,7 +16,7 @@ steps: queue: "cpu_queue_premerge" - label: "Diffusion Model Test" - timeout_in_minutes: 15 + timeout_in_minutes: 20 depends_on: image-build commands: - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py @@ -49,6 +49,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Parallelism Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Omni Model Test" timeout_in_minutes: 15 depends_on: image-build diff --git a/docs/.nav.yml b/docs/.nav.yml index 2666f13eeb3..f534d49cc09 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -22,8 +22,10 @@ nav: - configuration/* - Diffusion Acceleration: - Overview: user_guide/diffusion_acceleration.md - - TeaCache: user_guide/teacache.md - - Cache-DiT: user_guide/cache_dit_acceleration.md + - Acceleration Methods: + - TeaCache: user_guide/acceleration/teacache.md + - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md - Models: - models/supported_models.md - Developer Guide: diff --git a/docs/configuration/README.md b/docs/configuration/README.md index 34b7c1a27c9..1ceb9f827da 100644 --- a/docs/configuration/README.md +++ b/docs/configuration/README.md @@ -12,4 +12,6 @@ For introduction, please check [Introduction for stage config](./stage_configs.m ## Optimization Features -- **[TeaCache Configuration](../user_guide/teacache.md)** - 
Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss +- **[TeaCache Configuration](../user_guide/acceleration/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss +- **[Cache-DiT Configuration](../user_guide/acceleration/cache_dit_acceleration.md)** - Enable Cache-DiT as cache acceleration backends for DiT models +- **[Parallelism Configuration](../user_guide/acceleration/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for DiT models diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 37c2f8316b6..4e840280b26 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -195,7 +195,7 @@ def generate(self) -> str: main_file_rel = self.main_file.relative_to(ROOT_DIR) content += f'{code_fence}{self.main_file.suffix[1:]}\n--8<-- "{main_file_rel}"\n{code_fence}\n' else: - with open(self.main_file) as f: + with open(self.main_file, encoding="utf-8") as f: # Skip the title from md snippets as it's been included above main_content = f.readlines()[1:] content += self.fix_relative_links("".join(main_content)) diff --git a/docs/user_guide/cache_dit_acceleration.md b/docs/user_guide/acceleration/cache_dit_acceleration.md similarity index 100% rename from docs/user_guide/cache_dit_acceleration.md rename to docs/user_guide/acceleration/cache_dit_acceleration.md diff --git a/docs/user_guide/acceleration/parallelism_acceleration.md b/docs/user_guide/acceleration/parallelism_acceleration.md new file mode 100644 index 00000000000..0ced0731b26 --- /dev/null +++ b/docs/user_guide/acceleration/parallelism_acceleration.md @@ -0,0 +1,128 @@ +# Parallelism Acceleration Guide + +This guide includes how to use parallelism methods in vLLM-Omni to speed up diffusion model inference as well as reduce the memory requirement on each device. 
+ +## Overview + +The following parallelism methods are currently supported in vLLM-Omni: + +1. DeepSpeed Ulysses Sequence Parallel (Ulysses-SP) ([paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. + + +The following table shows which models are currently supported by each parallelism method: + + +| Model | Model Identifier | Ulysses-SP | +|-------|-----------------|-----------| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | + +### Sequence Parallelism + +#### Ulysses-SP + +##### Quick Start + +An example of using Ulysses-SP is shown below: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is used as the attention backend. 
+ +| Configuration | Ulysses degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 112.5s | 1.0x | +| Ulysses-SP | 2 | 65.2s | 1.73x | +| Ulysses-SP | 4 | 39.6s | 2.84x | +| Ulysses-SP | 8 | 30.8s | 3.65x | + +##### How to parallelize a new model + +If a diffusion model has been deployed in vLLM-Omni and supports single-card inference, you can refer to the following instruction on how to parallelize this model with Ulysses-SP. + +First, please edit the `TransformerModel`'s `forward` function in the `xxx_model_transformer.py` to make the inputs (image hidden states, positional embeddings, etc.) as chunks separated at the sequence dimension. Taking `qwen_image_transformer.py` as an example: + +```diff +class QwenImageTransformer2DModel(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + ... + ): ++ if self.parallel_config.sequence_parallel_size > 1: ++ hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[ ++ get_sequence_parallel_rank() ++ ] + + hidden_states = self.img_in(hidden_states) + + ... + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + ++ def get_rotary_emb_chunk(freqs): ++ freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] ++ return freqs + ++ if self.parallel_config.sequence_parallel_size > 1: ++ img_freqs, txt_freqs = image_rotary_emb ++ img_freqs = get_rotary_emb_chunk(img_freqs) ++ image_rotary_emb = (img_freqs, txt_freqs) +``` + +Next, at the end of the `forward` function, please call `get_sp_group().all_gather` to gather the chunked outputs across devices, and concatenate them at the sequence dimension. + + +```diff +class QwenImageTransformer2DModel(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + ... 
+ ): + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + ++ if self.parallel_config.sequence_parallel_size > 1: ++ output = get_sp_group().all_gather(output, dim=-2) + return Transformer2DModelOutput(sample=output) +``` + +Finally, you can set the parallel configuration and pass it to `Omni` and start parallel inference with: +```diff +from vllm_omni import Omni ++from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", ++ parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +``` diff --git a/docs/user_guide/teacache.md b/docs/user_guide/acceleration/teacache.md similarity index 100% rename from docs/user_guide/teacache.md rename to docs/user_guide/acceleration/teacache.md diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 0e2cd06a0df..220b930d69c 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -1,21 +1,27 @@ # Diffusion Acceleration Overview -vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods intelligently cache intermediate computations to avoid redundant work across diffusion timesteps. +vLLM-Omni supports various acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. ## Supported Acceleration Methods vLLM-Omni currently supports two main cache acceleration backends: -1. 
**[TeaCache](teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar -2. **[Cache-DiT](cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: - - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences - - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference - - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking +1. **[TeaCache](acceleration/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar +2. **[Cache-DiT](acceleration/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: + - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences + - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference + - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality. +vLLM-Omni also supports sequence parallelism (SP) for diffusion models, which includes: + +1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. 
+ ## Quick Comparison +### Cache Methods + | Method | Configuration | Description | Best For | |--------|--------------|-------------|----------| | **TeaCache** | `cache_backend="tea_cache"` | Simple, adaptive caching with minimal configuration | Quick setup, balanced speed/quality | @@ -23,18 +29,18 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma ## Supported Models -The following table shows which models are currently supported by each cache backend: +The following table shows which models are currently supported by each acceleration method: -| Model | Model Identifier | TeaCache | Cache-DiT | -|-------|-----------------|----------|-----------| -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | +|-------|-----------------|----------|-----------|-----------| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ |❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ❌ | ✅ |✅ | ## Performance Benchmarks -The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Image-Edit** models with 50 inference steps: +The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Image-Edit** models generating 1024x1024 images with 50 inference steps: !!! note "Benchmark Disclaimer" These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. 
Actual performance may vary based on: @@ -55,6 +61,14 @@ The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Im | **Qwen/Qwen-Image-Edit** | None | No acceleration | 51.5s | 1.0x | Baseline (diffusers) | | **Qwen/Qwen-Image-Edit** | Cache-DiT | Default (Fn=1, Bn=0, W=4, TaylorSeer disabled, SCM disabled) | 21.6s | **2.38x** | - | +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is used as the attention backend. + +| Configuration | Ulysses degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 112.5s | 1.0x | +| Ulysses-SP | 2 | 65.2s | 1.73x | +| Ulysses-SP | 4 | 39.6s | 2.84x | +| Ulysses-SP | 8 | 30.8s | 3.65x | ## Quick Start @@ -92,9 +106,42 @@ omni = Omni( outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) ``` +### Using Ulysses-SP + +Run text-to-image: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + + +Run image-to-image: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image-Edit", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate(prompt="turn this cat to a dog", + pil_image=input_image, num_inference_steps=50) +``` + ## Documentation For detailed information on each acceleration method: -- **[TeaCache Guide](teacache.md)** - Complete TeaCache documentation, configuration options, and best practices -- **[Cache-DiT 
Acceleration Guide](cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices +- **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 44051c19831..33a95454d8b 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -24,6 +24,7 @@ import torch from PIL import Image +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -94,6 +95,13 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." 
), ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + return parser.parse_args() @@ -115,6 +123,7 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) # Configure cache based on backend type cache_config = None if args.cache_backend == "cache_dit": @@ -145,6 +154,7 @@ def main(): vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + parallel_config=parallel_config, ) print("Pipeline loaded") @@ -154,6 +164,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") print(f" Input image size: {input_image.size}") print(f"{'=' * 60}\n") diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 85c84b5961a..21e752254b4 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -7,6 +7,7 @@ import torch +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -57,6 +58,13 @@ def parse_args() -> argparse.Namespace: "Default: None (no cache acceleration)." 
), ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + return parser.parse_args() @@ -98,12 +106,14 @@ def main(): # (e.g., QwenImagePipeline or FluxPipeline) } + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) omni = Omni( model=args.model, vae_use_slicing=vae_use_slicing, vae_use_tiling=vae_use_tiling, cache_backend=args.cache_backend, cache_config=cache_config, + parallel_config=parallel_config, ) # Time profiling for generation @@ -112,6 +122,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_ulysses_sequence_parallel.py new file mode 100644 index 00000000000..33429142c27 --- /dev/null +++ b/tests/diffusion/attention/test_ulysses_sequence_parallel.py @@ -0,0 +1,520 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import pickle +import tempfile + +import pytest +import torch +from vllm.platforms import current_platform + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import ( + DiffusionParallelConfig, + OmniDiffusionConfig, + set_current_omni_diffusion_config, +) +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.utils.platform_utils import detect_device_type + +device_type = detect_device_type() +if device_type == "cuda": + torch_device = torch.cuda +elif device_type == "npu": + torch_device = torch.npu +else: + 
raise ValueError(f"Unsupported device type: {device_type} for this test script! Expected GPU or NPU.") + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +class TestAttentionModel(torch.nn.Module): + """Test model using Attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = False, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = hidden_size + self.attention = Attention( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=1.0 / (head_size**0.5), + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + # Linear projection layers for Q, K, V + self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_size) + self.k_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.v_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.o_proj = torch.nn.Linear(num_heads * head_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through attention layer.""" + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (batch_size, seq_len, num_heads, head_size) + q = q.view(batch_size, seq_len, self.num_heads, self.head_size) + k = k.view(batch_size, seq_len, k.shape[-1] // self.head_size, self.head_size) + v = v.view(batch_size, seq_len, v.shape[-1] // self.head_size, self.head_size) + + # Apply attention + attn_output = self.attention(q, k, v) + + # Reshape back and project + attn_output = 
attn_output.view(batch_size, seq_len, -1) + output = self.o_proj(attn_output) + + return output + + +class TestMultiLayerAttentionModel(torch.nn.Module): + """Test model with multiple attention layers.""" + + def __init__( + self, + num_layers: int, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = True, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.layers = torch.nn.ModuleList( + [ + TestAttentionModel( + num_heads=num_heads, + head_size=head_size, + hidden_size=hidden_size, + causal=causal, + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through multiple attention layers.""" + for layer in self.layers: + hidden_states = hidden_states + layer(hidden_states) + return hidden_states + + +@pytest.mark.parametrize( + "test_model_cls", + [ + TestAttentionModel, + TestMultiLayerAttentionModel, + ], +) +@pytest.mark.parametrize("ulysses_degree", [2, 4]) +@pytest.mark.parametrize("ring_degree", [1]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [8]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_sync", [True, False]) +@pytest.mark.parametrize("dynamic", [False]) +@pytest.mark.parametrize("use_compile", [False]) +def test_ulysses_attention( + ulysses_degree: int, + ring_degree: int, + test_model_cls: type[torch.nn.Module], + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, +): + """Test Ulysses attention by comparing with 
and without SP enabled.""" + sequence_parallel_size = ulysses_degree * ring_degree + + # Create temporary files to share results between processes + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + baseline_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + sp_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + model_state_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + input_data_file = f.name + + try: + # Step 1: Run without SP (baseline with ulysses_degree=1, ring_degree=1) + print("\n[Baseline] Running without SP (ulysses_degree=1, ring_degree=1)...") + torch.multiprocessing.spawn( + ulysses_attention_on_test_model, + args=( + 1, # num_processes = 1 for baseline + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + 1, # ulysses_degree = 1 + 1, # ring_degree = 1 + 1, # sequence_parallel_size = 1 + baseline_output_file, + model_state_file, + input_data_file, + True, # is_baseline + ), + nprocs=1, + ) + + # Step 2: Run with SP enabled + print(f"\n[SP Test] Running with SP (ulysses_degree={ulysses_degree}, ring_degree={ring_degree})...") + torch.multiprocessing.spawn( + ulysses_attention_on_test_model, + args=( + sequence_parallel_size, # num_processes + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + ulysses_degree, + ring_degree, + sequence_parallel_size, + sp_output_file, + model_state_file, + input_data_file, + False, # is_baseline + ), + nprocs=sequence_parallel_size, + ) + + # Step 3: Verify input consistency and compare outputs + print(f"\n{'=' * 80}") + print("Verifying input data consistency...") + with open(input_data_file, "rb") as f: + input_data = pickle.load(f) + input_checksum = hash(input_data.tobytes()) + print(f" Input data shape: {input_data.shape}") + 
print(f" Input data checksum: {input_checksum}") + print(" ✓ Both baseline and SP used the same input data") + + print(f"\n{'=' * 80}") + print("Comparing outputs between baseline and SP...") + with open(baseline_output_file, "rb") as f: + baseline_output = pickle.load(f) + with open(sp_output_file, "rb") as f: + sp_output = pickle.load(f) + + # Convert to tensors for comparison + baseline_tensor = torch.tensor(baseline_output) + sp_tensor = torch.tensor(sp_output) + + print(f" Baseline output shape: {baseline_tensor.shape}") + print(f" SP output shape: {sp_tensor.shape}") + assert baseline_tensor.shape == sp_tensor.shape, "Output shapes must match!" + + # Calculate differences + abs_diff = torch.abs(baseline_tensor - sp_tensor) + max_abs_diff = abs_diff.max().item() + mean_abs_diff = abs_diff.mean().item() + + # Calculate relative difference (avoid division by zero) + baseline_abs = torch.abs(baseline_tensor) + relative_diff = abs_diff / (baseline_abs + 1e-8) + max_relative_diff = relative_diff.max().item() + mean_relative_diff = relative_diff.mean().item() + + print(f"\n{'=' * 80}") + print("Output Difference Analysis:") + print(f" - Max absolute difference: {max_abs_diff:.6e}") + print(f" - Mean absolute difference: {mean_abs_diff:.6e}") + print(f" - Max relative difference: {max_relative_diff:.6e}") + print(f" - Mean relative difference: {mean_relative_diff:.6e}") + print(f" - Baseline output range: [{baseline_tensor.min().item():.6e}, {baseline_tensor.max().item():.6e}]") + print(f" - SP output range: [{sp_tensor.min().item():.6e}, {sp_tensor.max().item():.6e}]") + print(f"{'=' * 80}\n") + + # Assert that differences are within acceptable tolerance + # For FP16/BF16, we expect some numerical differences due to different computation order + if dtype == torch.float16: + atol, rtol = 1e-4, 1e-2 + elif dtype == torch.bfloat16: + atol, rtol = 1e-4, 1e-2 + else: + atol, rtol = 1e-5, 1e-3 + + assert max_abs_diff < atol or max_relative_diff < rtol, ( + f"Output 
difference too large: max_abs_diff={max_abs_diff:.6e}, " + f"max_relative_diff={max_relative_diff:.6e}, " + f"tolerance: atol={atol}, rtol={rtol}" + ) + + print("✓ Test passed: SP output matches baseline within tolerance") + + finally: + # Clean up temporary files + for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file]: + if os.path.exists(f): + os.remove(f) + + +def ulysses_attention_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + ulysses_degree: int, + ring_degree: int, + sequence_parallel_size: int, + output_file: str, + model_state_file: str, + input_data_file: str, + is_baseline: bool, +): + """Run Ulysses attention test on a test model and save results for comparison.""" + # Use fixed seed for reproducibility across baseline and SP runs + RANDOM_SEED = 42 + current_platform.seed_everything(RANDOM_SEED) + + mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}") + + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + # Initialize distributed environment + init_distributed_environment() + + # Set up OmniDiffusionConfig with parallel config + parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + cfg_parallel_size=1, + 
) + + od_config = OmniDiffusionConfig( + model="test_model", + dtype=dtype, + parallel_config=parallel_config, + ) + + # Initialize model parallel + initialize_model_parallel( + data_parallel_size=1, + cfg_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + # Set the config so Attention can access it + with set_current_omni_diffusion_config(od_config): + # Create model + hidden_size = num_heads * head_size + + # Create model with appropriate parameters + model_kwargs = { + "num_heads": num_heads, + "head_size": head_size, + "hidden_size": hidden_size, + "causal": causal, + "num_kv_heads": None, + "scatter_idx": 2, + "gather_idx": 1, + "use_sync": use_sync, + } + + if test_model_cls == TestMultiLayerAttentionModel: + model_kwargs["num_layers"] = 2 + + model = test_model_cls(**model_kwargs) + model = model.to(device).to(dtype) + + # For baseline: Generate and save model state and input data + # This ensures both baseline and SP use exactly the same initialization + if is_baseline and local_rank == 0: + # Save model state for reuse (before any computation) + model_state = {k: v.cpu() for k, v in model.state_dict().items()} + with open(model_state_file, "wb") as f: + pickle.dump(model_state, f) + + # Generate and save full input data with fixed seed + # Reinitialize RNG to ensure reproducibility + torch.manual_seed(42) + torch_device.manual_seed_all(42) + full_hidden_states = torch.randn( + (batch_size, seq_len, hidden_size), + dtype=dtype, + device="cpu", + ) + with open(input_data_file, "wb") as f: + pickle.dump(full_hidden_states.detach().cpu().float().numpy(), f) + + print("[Baseline] Saved model state and input data") + + # Synchronize to ensure baseline has saved data before SP loads it + if world_size > 1: + torch.distributed.barrier() + + # IMPORTANT: Both baseline and SP load the same model state and input data + # This 
ensures exact same initialization and input for fair comparison + with open(model_state_file, "rb") as f: + model_state = pickle.load(f) + model.load_state_dict({k: v.to(device).to(dtype) for k, v in model_state.items()}) + + with open(input_data_file, "rb") as f: + full_hidden_states_np = pickle.load(f) + full_hidden_states = torch.from_numpy(full_hidden_states_np).to(device).to(dtype) + + print(f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}") + + # Split input sequence according to sequence parallel BEFORE model forward + # Each rank gets a contiguous chunk of the sequence dimension + local_seq_len = seq_len // sequence_parallel_size + start_idx = local_rank * local_seq_len + end_idx = start_idx + local_seq_len + hidden_states = full_hidden_states[:, start_idx:end_idx, :].contiguous() + + print( + f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " + f"indices=[{start_idx}:{end_idx}], local_shape={hidden_states.shape}" + ) + + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) + torch._dynamo.mark_dynamic(hidden_states, 1) + + # Compile model if requested + if use_compile: + model = torch.compile(model) + + # Run forward pass with local sequence chunk + print(f"[Rank {local_rank}] Running forward pass...") + output = model(hidden_states) + print(f"[Rank {local_rank}] Forward pass completed, output shape: {output.shape}") + + # Verify output shape + assert output.shape == (batch_size, local_seq_len, hidden_size), ( + f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}" + ) + + # Verify SP usage for non-baseline runs + if not is_baseline: + if hasattr(model, "attention"): + assert hasattr(model.attention, "use_ulysses"), "Attention should have use_ulysses attribute" + assert model.attention.use_ulysses, "Attention should be using Ulysses" + elif hasattr(model, "layers"): + for i, layer in enumerate(model.layers): + assert hasattr(layer.attention, 
"use_ulysses"), ( + f"Layer {i} attention should have use_ulysses attribute" + ) + assert layer.attention.use_ulysses, f"Layer {i} attention should be using Ulysses" + + # Gather outputs from all ranks AFTER computation + if world_size > 1: + print(f"[Rank {local_rank}] Gathering outputs from all {world_size} ranks...") + # Gather all outputs to rank 0 + gathered_outputs = [torch.zeros_like(output) for _ in range(world_size)] + torch.distributed.all_gather(gathered_outputs, output) + if local_rank == 0: + # Concatenate along sequence dimension to reconstruct full sequence + full_output = torch.cat(gathered_outputs, dim=1) + print(f"[Rank 0] Gathered and concatenated outputs: {full_output.shape}") + # Verify the full output shape matches expected + assert full_output.shape == (batch_size, seq_len, hidden_size), ( + f"Gathered output shape mismatch: expected {(batch_size, seq_len, hidden_size)}, " + f"got {full_output.shape}" + ) + else: + full_output = None + else: + # For baseline (world_size=1), output is already complete + full_output = output + print(f"[Rank 0] No gather needed (world_size=1), output shape: {full_output.shape}") + + # Save output from rank 0 for comparison + if local_rank == 0: + output_np = full_output.detach().cpu().float().numpy() + with open(output_file, "wb") as f: + pickle.dump(output_np, f) + + mode_str = "baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print( + f"\n[{mode_str}] ✓ Saved output with shape {full_output.shape}:\n" + f" - batch_size={batch_size}, seq_len={seq_len}\n" + f" - num_heads={num_heads}, head_size={head_size}\n" + f" - dtype={dtype}, causal={causal}, use_sync={use_sync}\n" + ) + + destroy_distributed_env() diff --git a/tests/diffusion/distributed/test_comm.py b/tests/diffusion/distributed/test_comm.py new file mode 100644 index 00000000000..7bd6386796e --- /dev/null +++ b/tests/diffusion/distributed/test_comm.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
def update_environment_variables(envs_dict: dict[str, str]):
    """Copy every key/value pair of ``envs_dict`` into ``os.environ``."""
    os.environ.update(envs_dict)


def _find_free_port() -> str:
    """Return, as a string, a TCP port the OS currently reports as unused.

    The port is released again before it is handed out, so another process
    could still grab it first; this is nevertheless far less collision-prone
    than sharing one fixed port between all parametrized runs.
    """
    import socket  # local import: only this helper needs it

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return str(sock.getsockname()[1])


def _init_ulysses_sp(local_rank: int, world_size: int, master_port: str):
    """Bind this process to its device and set up the Ulysses SP group.

    Args:
        local_rank: Rank of the current worker; also used as device index.
        world_size: Number of spawned workers (used as the Ulysses degree).
        master_port: Rendezvous port shared by all workers of one test run.

    Returns:
        A ``(device, ulysses_group)`` tuple.
    """
    device = torch.device(f"{device_type}:{local_rank}")
    torch_device.set_device(device)

    # Environment variables read by the distributed initialisation.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": master_port,
        }
    )

    init_distributed_environment()
    initialize_model_parallel(ulysses_degree=world_size)  # test ulysses sp by default
    # Use the ulysses sp group, not the ring sp group.
    return device, get_sp_group().ulysses_group


def _assert_round_trip_identity(
    all_to_all,
    sp_group,
    input_tensor: torch.Tensor,
    head_dim: int,
    world_size: int,
    use_sync: bool,
) -> None:
    """Check that scatter-heads/gather-seq followed by its inverse is an identity.

    Args:
        all_to_all: ``SeqAllToAll4D`` or ``SeqAllToAll5D``.
        sp_group: Process group the collective runs on.
        input_tensor: Local shard with the sequence on dim 1 and the heads on
            ``head_dim``.
        head_dim: Index of the head dimension (2 for 4D, 3 for 5D input).
        world_size: Number of ranks in ``sp_group``.
        use_sync: Forwarded to the all-to-all primitive.
    """
    original_input = input_tensor.clone()

    # First all-to-all: sequence is gathered (dim 1), heads are scattered.
    intermediate = all_to_all.apply(sp_group, input_tensor, head_dim, 1, use_sync)

    expected_dims = list(original_input.shape)
    expected_dims[1] *= world_size
    expected_dims[head_dim] //= world_size
    expected_shape = tuple(expected_dims)
    assert intermediate.shape == expected_shape, (
        f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}"
    )

    # Second all-to-all: the inverse resharding.
    output = all_to_all.apply(sp_group, intermediate, 1, head_dim, use_sync)

    assert output.shape == original_input.shape, (
        f"Output shape mismatch: expected {original_input.shape}, got {output.shape}"
    )
    torch.testing.assert_close(
        output,
        original_input,
        rtol=1e-5,
        atol=1e-5,
        msg="Output does not match original input after two all-to-all operations",
    )


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len_per_rank", [8])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [32])
@pytest.mark.parametrize("use_sync", [False, True])
def test_4d_identity(
    world_size: int,
    dtype: torch.dtype,
    batch_size: int,
    seq_len_per_rank: int,
    num_heads: int,
    head_size: int,
    use_sync: bool,
):
    """Test that two consecutive 4D all-to-all operations return the original input."""
    # The head dimension is split evenly across ranks.
    if num_heads % world_size != 0:
        pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})")

    # One process per rank; every run gets its own rendezvous port.
    torch.multiprocessing.spawn(
        _test_4d_identity_worker,
        args=(
            world_size,
            dtype,
            batch_size,
            seq_len_per_rank,
            num_heads,
            head_size,
            use_sync,
            _find_free_port(),
        ),
        nprocs=world_size,
    )


def _test_4d_identity_worker(
    local_rank: int,
    world_size: int,
    dtype: torch.dtype,
    batch_size: int,
    seq_len_per_rank: int,
    num_heads: int,
    head_size: int,
    use_sync: bool,
    master_port: str = "29500",
):
    """Worker function for test_4d_identity (runs once per rank)."""
    device, sp_group = _init_ulysses_sp(local_rank, world_size, master_port)
    try:
        # Input tensor: (bs, seqlen/P, hc, hs); a different seed on every rank.
        torch.manual_seed(42 + local_rank)
        input_tensor = torch.randn(
            batch_size,
            seq_len_per_rank,
            num_heads,
            head_size,
            dtype=dtype,
            device=device,
        )
        # (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)
        _assert_round_trip_identity(SeqAllToAll4D, sp_group, input_tensor, 2, world_size, use_sync)
    finally:
        # Tear down even when an assertion fails so the failure is reported cleanly.
        destroy_distributed_env()


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len_per_rank", [8])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [32])
@pytest.mark.parametrize("use_sync", [False, True])
def test_5d_identity(
    world_size: int,
    dtype: torch.dtype,
    batch_size: int,
    seq_len_per_rank: int,
    num_heads: int,
    head_size: int,
    use_sync: bool,
):
    """Test that two consecutive 5D all-to-all operations return the original input."""
    # The head dimension is split evenly across ranks.
    if num_heads % world_size != 0:
        pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})")

    # One process per rank; every run gets its own rendezvous port.
    torch.multiprocessing.spawn(
        _test_5d_identity_worker,
        args=(
            world_size,
            dtype,
            batch_size,
            seq_len_per_rank,
            num_heads,
            head_size,
            use_sync,
            _find_free_port(),
        ),
        nprocs=world_size,
    )


def _test_5d_identity_worker(
    local_rank: int,
    world_size: int,
    dtype: torch.dtype,
    batch_size: int,
    seq_len_per_rank: int,
    num_heads: int,
    head_size: int,
    use_sync: bool,
    master_port: str = "29500",
):
    """Worker function for test_5d_identity (runs once per rank)."""
    device, sp_group = _init_ulysses_sp(local_rank, world_size, master_port)
    try:
        # Input tensor: (bs, seqlen/P, 3, hc, hs); the '3' dimension is Q, K, V.
        torch.manual_seed(42 + local_rank)
        input_tensor = torch.randn(
            batch_size,
            seq_len_per_rank,
            3,  # Q, K, V
            num_heads,
            head_size,
            dtype=dtype,
            device=device,
        )
        # (bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs)
        _assert_round_trip_identity(SeqAllToAll5D, sp_group, input_tensor, 3, world_size, use_sync)
    finally:
        # Tear down even when an assertion fails so the failure is reported cleanly.
        destroy_distributed_env()
@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize("ulysses_degree", [1, 2])
@pytest.mark.parametrize("ring_degree", [1])
def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: int):
    """Test SP (Ulysses-SP + Ring-SP) backend with diffusion model.

    Generates one small image with the requested sequence-parallel degrees and
    checks that exactly one image of the requested size comes back.
    """
    # The run needs one device per sequence-parallel rank, i.e. the product of
    # both degrees, not just the Ulysses degree.
    required_devices = ulysses_degree * ring_degree
    available_devices = device_count()
    if available_devices < required_devices:
        pytest.skip(f"Test requires {required_devices} GPUs but only {available_devices} available")

    # Configure sequence parallel with DiffusionParallelConfig
    parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)

    m = Omni(
        model=model_name,
        parallel_config=parallel_config,
    )

    # Use minimal settings for fast testing
    height = 256
    width = 256
    num_inference_steps = 4  # Minimal steps for fast test

    images = m.generate(
        "a photo of a cat sitting on a laptop keyboard",
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,
        generator=torch.Generator(get_device_name()).manual_seed(42),
        num_outputs_per_prompt=1,  # Single output for speed
    )

    # Verify generation succeeded
    assert images is not None
    assert len(images) == 1
    # Check image size
    assert images[0].width == width
    assert images[0].height == height
+ sp_group = get_sp_group() + self.ring_pg = sp_group.ring_group + self.ulysses_pg = sp_group.ulysses_group + assert get_sequence_parallel_world_size() > 1, "Sequence parallel world size must be > 1" + except (AssertionError, RuntimeError): + # If sequence parallel group is not initialized, disable Ulysses + self.use_ulysses = False + except Exception: + self.use_ulysses = False + def forward( self, query: torch.Tensor, @@ -38,6 +76,43 @@ def forward( value: torch.Tensor, attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: - # shape: (batch_size, seq_len, num_heads, head_size) - attn_output = self.attention.forward(query, key, value, attn_metadata) - return attn_output + if self.use_ulysses: + return self._forward_ulysses(query, key, value, attn_metadata) + else: + # shape: (batch_size, seq_len, num_heads, head_size) + attn_output = self.attention.forward(query, key, value, attn_metadata) + return attn_output + + def _forward_ulysses( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_metadata: AttentionMetadata = None, + ) -> Tensor: + """Ulysses attention forward pass with sequence parallelism.""" + # scatter 2, gather 1 + # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) + q = SeqAllToAll4D.apply(self.ulysses_pg, query, self.scatter_idx, self.gather_idx, self.use_sync) + k = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx, self.use_sync) + v = SeqAllToAll4D.apply(self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync) + + softmax_scale = self.softmax_scale + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + + context_layer = self.attention.forward( + q, + k, + v, + attn_metadata=attn_metadata, + ) + + if isinstance(context_layer, tuple): + context_layer = context_layer[0] + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = SeqAllToAll4D.apply(self.ulysses_pg, context_layer, 
@config
@dataclass
class DiffusionParallelConfig:
    """Configuration for diffusion model distributed execution.

    ``__post_init__`` derives ``sequence_parallel_size`` when it is left as
    ``None``, validates all sizes and stores the total number of ranks in
    ``world_size``.
    """

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel stages."""

    data_parallel_size: int = 1
    """Number of data parallel groups."""

    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""

    sequence_parallel_size: int | None = None
    """Number of sequence parallel groups.
    sequence_parallel_size = ring_degree * ulysses_degree"""

    ulysses_degree: int = 1
    """Number of GPUs used for ulysses sequence parallelism."""

    ring_degree: int = 1
    """Number of GPUs used for ring sequence parallelism."""

    cfg_parallel_size: int = 1
    """Number of Classifier Free Guidance (CFG) parallel groups."""

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Validates the config relationships among the parallel strategies."""
        self._check_parallel_sizes()
        return self

    def _check_parallel_sizes(self) -> None:
        """Raise ``ValueError`` when a size is not positive or the SP sizes disagree.

        Explicit exceptions are used instead of ``assert`` so the checks are
        still performed when Python runs with ``-O``.
        """
        positive_sizes = (
            ("Pipeline parallel size", self.pipeline_parallel_size),
            ("Data parallel size", self.data_parallel_size),
            ("Tensor parallel size", self.tensor_parallel_size),
            ("Ulysses degree", self.ulysses_degree),
            ("Ring degree", self.ring_degree),
            ("CFG parallel size", self.cfg_parallel_size),
        )
        for label, size in positive_sizes:
            if size <= 0:
                raise ValueError(f"{label} must be > 0")

        if self.sequence_parallel_size is None:
            # Not derived yet; __post_init__ fills it in from the two degrees.
            return
        if self.sequence_parallel_size <= 0:
            raise ValueError("Sequence parallel size must be > 0")
        if self.sequence_parallel_size != self.ulysses_degree * self.ring_degree:
            raise ValueError(
                "Sequence parallel size must be equal to the product of ulysses degree and ring degree,"
                f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}"
            )

    def __post_init__(self) -> None:
        if self.sequence_parallel_size is None:
            self.sequence_parallel_size = self.ulysses_degree * self.ring_degree
        # NOTE(review): the class is built with the standard-library
        # ``dataclass``; whether the pydantic ``model_validator`` hook above is
        # ever triggered depends on what ``@config`` does -- confirm. Running
        # the checks here guarantees they happen on construction either way.
        self._check_parallel_sizes()
        # Total number of ranks: product of every parallel dimension.
        self.world_size = (
            self.pipeline_parallel_size
            * self.data_parallel_size
            * self.tensor_parallel_size
            * self.ulysses_degree
            * self.ring_degree
            * self.cfg_parallel_size
        )
@contextmanager
def set_current_omni_diffusion_config(
    omni_diffusion_config: OmniDiffusionConfig, check_compile=False, prefix: str | None = None
):
    """
    Install ``omni_diffusion_config`` (and ``prefix``) as the module-wide
    current config for the duration of the ``with`` block.

    The previous config and prefix are restored on exit, and the cached
    compilation config is cleared, whether or not the block raised.
    If the block finishes without an exception and ``check_compile`` is
    truthy, a ``RuntimeError`` is raised because compilation is not
    supported yet.
    """
    global _current_omni_diffusion_config, _current_prefix
    saved_config, saved_prefix = _current_omni_diffusion_config, _current_prefix
    _current_omni_diffusion_config = omni_diffusion_config
    _current_prefix = prefix
    try:
        yield
        # Only reached when the managed block completed normally.
        if check_compile:
            raise RuntimeError("Compilation is not yet supported for OmniDiffusion")
    finally:
        _current_omni_diffusion_config = saved_config
        _current_prefix = saved_prefix
        # The cached value belongs to the context that just ended.
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return ``compilation_config`` of the current config, cached after the first call."""
    current = get_current_omni_diffusion_config()
    return current.compilation_config


def get_current_omni_diffusion_config() -> OmniDiffusionConfig:
    """Return the config installed by ``set_current_omni_diffusion_config``.

    When none is installed, a warning is logged and a freshly constructed
    default ``OmniDiffusionConfig`` is returned.
    """
    current = _current_omni_diffusion_config
    if current is not None:
        return current
    logger.warning("Current OmniDiffusionConfig is not set.")
    return OmniDiffusionConfig()
def _all_to_all_exchange_4d(chunks: torch.Tensor, world_size: int, group, use_sync: bool) -> torch.Tensor:
    """Exchange the leading-dimension chunks of ``chunks`` between the ranks of ``group``.

    With a single rank nothing is communicated and ``chunks`` is returned as is,
    so no receive buffer is allocated in that case.
    """
    if world_size <= 1:
        return chunks

    received = torch.empty_like(chunks)
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
    dist.all_to_all_single(received, chunks, group=group)
    if use_sync:
        # Synchronize the backend the tensor lives on (e.g. torch.cuda or
        # torch.npu) instead of assuming CUDA.
        backend = getattr(torch, received.device.type, None)
        synchronize = getattr(backend, "synchronize", None)
        if synchronize is None:
            torch.cuda.synchronize()
        else:
            synchronize()
    return received


def all_to_all_4D(
    input: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False
) -> torch.Tensor:
    """
    all-to-all for QKV

    Two reshardings are supported (P = number of ranks in ``group``):

    * ``scatter_idx=2, gather_idx=1``: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs)
    * ``scatter_idx=1, gather_idx=2``: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)

    Args:
        input (torch.Tensor): 4D tensor sharded along the gather dimension
        scatter_idx (int): dimension to scatter, default 2
        gather_idx (int): dimension to gather, default 1
        group (torch.distributed.ProcessGroup): torch process group
        use_sync (bool): whether to synchronize the device after all-to-all

    Returns:
        torch.Tensor: the resharded tensor

    Raises:
        RuntimeError: if the index pair is unsupported or the scattered
            dimension is not divisible by the group size.
    """
    assert input.dim() == 4, f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input: sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        if hc % seq_world_size != 0:
            # Previously surfaced as an opaque reshape error.
            raise RuntimeError(f"head count ({hc}) must be divisible by the group size ({seq_world_size})")
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()

        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        output = _all_to_all_exchange_4d(input_t, seq_world_size, group, use_sync)

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -transpose(0,1)-> (bs, seq_len, hc/P, hs)
        return output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

    if scatter_idx == 1 and gather_idx == 2:
        # input: sharded along dim 2 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        if seqlen % seq_world_size != 0:
            # Previously surfaced as an opaque reshape error.
            raise RuntimeError(f"sequence length ({seqlen}) must be divisible by the group size ({seq_world_size})")
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)->
        # (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (
            input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )

        output = _all_to_all_exchange_4d(input_t, seq_world_size, group, use_sync)

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/P, bs, hs) -transpose(0,2)-> (bs, seqlen/P, hc, hs)
        return output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

    raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")


class SeqAllToAll4D(torch.autograd.Function):
    """Autograd-function wrapper around :func:`all_to_all_4D`.

    Only ``forward`` is defined; the arguments are recorded on ``ctx``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
        use_sync: bool = False,
    ) -> Tensor:
        """Reshard ``input`` on ``group``; see :func:`all_to_all_4D` for shapes."""
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.use_sync = use_sync
        return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync)
def _all_to_all_exchange_5d(chunks: torch.Tensor, world_size: int, group, use_sync: bool) -> torch.Tensor:
    """Exchange the leading-dimension chunks of ``chunks`` between the ranks of ``group``.

    With a single rank nothing is communicated and ``chunks`` is returned as is,
    so no receive buffer is allocated in that case.
    """
    if world_size <= 1:
        return chunks

    received = torch.empty_like(chunks)
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
    dist.all_to_all_single(received, chunks, group=group)
    if use_sync:
        # Synchronize the backend the tensor lives on (e.g. torch.cuda or
        # torch.npu) instead of assuming CUDA.
        backend = getattr(torch, received.device.type, None)
        synchronize = getattr(backend, "synchronize", None)
        if synchronize is None:
            torch.cuda.synchronize()
        else:
            synchronize()
    return received


def all_to_all_5D(
    input: torch.Tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False
) -> torch.Tensor:
    """
    all-to-all for packed QKV

    Two reshardings are supported (P = number of ranks in ``group``):

    * ``scatter_idx=3, gather_idx=1``: (bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs)
    * ``scatter_idx=1, gather_idx=3``: (bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs)

    Args:
        input (torch.Tensor): 5D tensor sharded along the gather dimension
        scatter_idx (int): dimension to scatter, default 3
        gather_idx (int): dimension to gather, default 1
        group (torch.distributed.ProcessGroup): torch process group
        use_sync (bool): whether to synchronize the device after all-to-all

    Returns:
        torch.Tensor: the resharded tensor

    Raises:
        RuntimeError: if the index pair is unsupported or the scattered
            dimension is not divisible by the group size.
    """
    assert input.dim() == 5, f"input must be 5D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 3 and gather_idx == 1:
        # input: sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs)
        bs, shard_seqlen, t_cnt, hc, hs = input.shape

        assert t_cnt == 3
        if hc % seq_world_size != 0:
            # Previously surfaced as an opaque reshape error.
            raise RuntimeError(f"head count ({hc}) must be divisible by the group size ({seq_world_size})")
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)->
        # (P, seq_len/P, 3, bs, hc/P, hs)
        input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous()

        # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head
        output = _all_to_all_exchange_5d(input_t, seq_world_size, group, use_sync)

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, 3, bs, shard_hc, hs)

        # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs)
        output = output.transpose(0, 2).transpose(1, 2).contiguous()

        return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous()

    if scatter_idx == 1 and gather_idx == 3:
        # input: sharded along dim 3 (bs, seqlen, 3, hc/P, hs) output: (bs, seqlen/P, 3, hc, hs)
        bs, seqlen, _, shard_hc, hs = input.shape
        if seqlen % seq_world_size != 0:
            # Previously surfaced as an opaque reshape error.
            raise RuntimeError(f"sequence length ({seqlen}) must be divisible by the group size ({seq_world_size})")
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)->
        # (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs)
        input_t = (
            input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs)
            .transpose(0, 4)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs)
        )

        output = _all_to_all_exchange_5d(input_t, seq_world_size, group, use_sync)

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, 3, bs, hs)

        # (hc, seqlen/P, 3, bs, hs) -transpose(0,3)-> (bs, seqlen/P, 3, hc, hs)
        output = output.transpose(0, 3).contiguous()

        return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous()

    raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3")


class SeqAllToAll5D(torch.autograd.Function):
    """Autograd-function wrapper around :func:`all_to_all_5D`.

    Only ``forward`` is defined; the arguments are recorded on ``ctx``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int = 3,
        gather_idx: int = 1,
        use_sync: bool = False,
    ) -> Tensor:
        """Reshard ``input`` on ``group``; see :func:`all_to_all_5D` for shapes."""
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.use_sync = use_sync

        return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync)
@@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import pickle +from collections import namedtuple +from typing import Any + +import torch +import torch.distributed +from torch.cuda import synchronize +from torch.distributed import Backend, ProcessGroup +from vllm.logger import init_logger + +from vllm_omni.diffusion import envs + +logger = init_logger(__name__) + +if envs._is_npu(): + logger.info("torch.npu synchronize") + from torch.npu import synchronize + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + +env_info = envs.PACKAGES_CHECKER.get_packages_info() + + +def _split_tensor_dict( + tensor_dict: dict[str, torch.Tensor | Any], prefix: str = "" +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, "Avoid having '%' in key as it is used as a separator for nested entries." + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
+ device = value.device.type + metadata_list.append((prefix + key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict(value, prefix + key + "%") + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
+ """ + + # available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = envs.get_device(local_rank) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = 
self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, op=op, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> torch.Tensor | list[torch.Tensor]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty(input_size, dtype=input_.dtype, device=input_.device) + # All-gather. 
+ torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape( + [ + world_size, + ] + + input_size + ) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1).narrow(0, input_.numel() * i, input_.numel()).view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. 
+ torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) + return input_ + + def broadcast_object(self, obj: Any | None = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank." 
+ + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group) + + assert rank_object == rank_size, "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, + src: int = 0, + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict(self, src: int | None = None) -> dict[str, torch.Tensor | Any] | None: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), + ) + + def recv(self, size: torch.Size, dtype: torch.dtype, src: int | None = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + (self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class PipelineGroupCoordinator(GroupCoordinator): + """ + available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + difference between `local_rank` and `rank_in_group`: + if we have a group of size 4 across two nodes: + Process | Node | Rank | Local Rank | Rank in Group + 0 | 0 | 0 | 0 | 0 + 1 | 0 | 1 | 1 | 1 + 2 | 1 | 2 | 0 | 2 + 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + """ + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | 
class PipelineGroupCoordinator(GroupCoordinator):
    """
    available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    difference between `local_rank` and `rank_in_group`:
    if we have a group of size 4 across two nodes:
    Process | Node | Rank | Local Rank | Rank in Group
      0     |   0  |  0   |     0      |       0
      1     |   0  |  1   |     1      |       1
      2     |   1  |  2   |     0      |       2
      3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    """

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
    ):
        """Create the pipeline process groups and the receive bookkeeping.

        This does not call the base ``__init__``: the groups are created here
        because two-rank pipelines need one group per direction.
        """
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None
        self.cpu_groups = []
        self.device_groups = []
        group_size = len(group_ranks[0])
        if group_size > 2 or group_size == 1:
            for ranks in group_ranks:
                device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_group = device_group
                    self.cpu_group = cpu_group
        # when pipeline parallelism is 2, we need to create two groups to avoid
        # communication stall.
        # *_group_0_1 represents the group for communication from device 0 to
        # device 1.
        # *_group_1_0 represents the group for communication from device 1 to
        # device 0.
        elif group_size == 2:
            for ranks in group_ranks:
                device_group_0_1 = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
                device_group_1_0 = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo")
                cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_groups = [device_group_0_1, device_group_1_0]
                    self.cpu_groups = [cpu_group_0_1, cpu_group_1_0]
                    self.device_group = device_group_0_1
                    self.cpu_group = cpu_group_0_1

        assert self.cpu_group is not None
        assert self.device_group is not None

        self.device = envs.get_device(local_rank)

        self.recv_buffer_set: bool = False
        self.recv_tasks_queue: list[tuple[str, int]] = []
        self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = []
        self.dtype: torch.dtype | None = None
        self.num_pipefusion_patches: int | None = None

        self.recv_shape: dict[str, dict[int, torch.Size]] = {}
        self.send_shape: dict[str, dict[int, torch.Size]] = {}
        # name -> segment index -> receive tensor; `set_recv_buffer` replaces
        # this with a plain list of tensors.
        self.recv_buffer: dict[str, dict[int, torch.Tensor]] | list[torch.Tensor] = {}
        # Filled by `set_extra_tensors_recv_buffer`; must exist before that
        # method indexes into it.
        self.extra_tensors_recv_buffer: dict[str, list[torch.Tensor]] = {}

        self.skip_tensor_recv_buffer_set: bool = False
        self.recv_skip_tasks_queue: list[int | tuple[str, int]] = []
        self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = []
        self.skip_tensor_recv_buffer: list[torch.Tensor] | torch.Tensor | None = None
        self.skip_device_group = None
        # A dedicated device group for skip-connection traffic.
        for ranks in group_ranks:
            skip_device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
            if self.rank in ranks:
                self.skip_device_group = skip_device_group
        assert self.skip_device_group is not None

    def reset_buffer(self):
        """Drop all queued/pending receive tasks, recorded shapes and buffers."""
        self.recv_tasks_queue = []
        self.receiving_tasks = []
        self.recv_shape = {}
        self.send_shape = {}
        self.recv_buffer = {}

        self.recv_skip_tasks_queue = []
        self.receiving_skip_tasks = []
        # NOTE(review): reset to an empty dict although the attribute is
        # otherwise None or a list of tensors; kept as-is for compatibility.
        self.skip_tensor_recv_buffer = {}

    def set_config(self, dtype: torch.dtype):
        """Set the dtype used when allocating receive buffers."""
        self.dtype = dtype

    def set_recv_buffer(
        self,
        num_pipefusion_patches: int,
        patches_shape_list: list[list[int]],
        feature_map_shape: list[int],
        dtype: torch.dtype,
    ):
        """Allocate one zero buffer per patch shape plus one for the feature map.

        After this call ``recv_buffer`` is a list: the patch buffers in order,
        followed by the feature-map buffer.
        """
        assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object"
        assert isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1, (
            "num_pipefusion_patches must be greater than or equal to 1"
        )
        self.dtype = dtype
        self.num_pipefusion_patches = num_pipefusion_patches
        self.recv_buffer = [torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list]
        self.recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device))
        self.recv_buffer_set = True

    def set_extra_tensors_recv_buffer(
        self,
        name: str,
        shape: list[int],
        num_buffers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """Allocate ``num_buffers`` zero tensors of ``shape`` under ``name``."""
        self.extra_tensors_recv_buffer[name] = [
            torch.zeros(*shape, dtype=dtype, device=self.device) for _ in range(num_buffers)
        ]

    def _check_shape_and_buffer(
        self,
        tensor_send_to_next=None,
        recv_prev=False,
        name: str | None = None,
        segment_idx: int = 0,
    ):
        """Exchange tensor shapes with neighbours the first time a slot is used.

        A slot is a ``(name, segment_idx)`` pair. The shape is sent only when
        no send shape is recorded for the slot yet, and received only when no
        receive shape is recorded yet; a receive also (re)allocates the
        matching entry of ``recv_buffer``.
        """
        send_flag = False
        name = name or "latent"
        if tensor_send_to_next is not None:
            shape_list = self.send_shape.get(name, None)
            if shape_list is None:
                self.send_shape[name] = {segment_idx: tensor_send_to_next.shape}
                send_flag = True
            elif shape_list.get(segment_idx, None) is None:
                self.send_shape[name][segment_idx] = tensor_send_to_next.shape
                send_flag = True

        recv_flag = False
        if recv_prev:
            shape_list = self.recv_shape.get(name, None)
            if shape_list is None:
                recv_flag = True
            elif shape_list.get(segment_idx, None) is None:
                recv_flag = True

        recv_prev_shape = self._communicate_shapes(
            tensor_send_to_next=tensor_send_to_next if send_flag else None,
            recv_prev=recv_flag,
        )

        if recv_flag:
            if self.recv_shape.get(name, None) is None:
                self.recv_shape[name] = {segment_idx: recv_prev_shape}
            else:
                self.recv_shape[name][segment_idx] = recv_prev_shape

            if self.recv_buffer.get(name, None) is None:
                self.recv_buffer[name] = {
                    segment_idx: torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype)
                }
            else:
                if self.recv_buffer[name].get(segment_idx, None) is not None:
                    logger.warning(f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating...")
                self.recv_buffer[name][segment_idx] = torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype)

    def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False):
        """Communicate tensor shapes between stages. Used to communicate
        tensor shapes before the actual tensor communication happens.

        The exchange happens in two rounds: first the number of dimensions,
        then the shape itself.

        Args:
            tensor_send_to_next: tensor to send to next rank (no tensor sent if
                set to None).
            recv_prev: boolean for whether tensor should be received from
                previous rank.

        Returns:
            The received shape as ``torch.Size``; ``torch.Size([0, 0, 0])``
            when nothing was received.
        """

        # Round 1: number of dimensions.
        ops = []
        if recv_prev:
            recv_prev_dim_tensor = torch.empty((1), device=self.device, dtype=torch.int64)
            recv_prev_dim_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_dim_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_dim_op)

        if tensor_send_to_next is not None:
            send_next_dim_tensor = torch.tensor(tensor_send_to_next.dim(), device=self.device, dtype=torch.int64)
            send_next_dim_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_dim_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_dim_op)

        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against race condition when using batch_isend_irecv().
        # should take this out once the bug with batch_isend_irecv is resolved.
        synchronize()

        # Round 2: the shape vector, sized by the dimension count from round 1.
        ops = []
        recv_prev_shape_tensor = None
        if recv_prev:
            recv_prev_shape_tensor = torch.empty(
                torch.Size(recv_prev_dim_tensor), device=self.device, dtype=torch.int64
            )
            recv_prev_shape_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_shape_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_shape_op)

        if tensor_send_to_next is not None:
            send_next_shape_tensor = torch.tensor(tensor_send_to_next.size(), device=self.device, dtype=torch.int64)
            send_next_shape_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_shape_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_shape_op)

        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        synchronize()

        recv_prev_shape = [0, 0, 0]
        if recv_prev_shape_tensor is not None:
            recv_prev_shape = recv_prev_shape_tensor
        return torch.Size(recv_prev_shape)

    def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None:
        """Send ``tensor`` to the next rank and wait for the send to finish."""
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx)
        self._pipeline_isend(tensor).wait()

    def pipeline_isend(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None:
        """Start sending ``tensor`` to the next rank without waiting."""
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx)
        self._pipeline_isend(tensor)

    def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor:
        """Receive into the ``(name, idx)`` buffer from the previous rank and return it."""
        name = name or "latent"
        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
        self._pipeline_irecv(self.recv_buffer[name][idx]).wait()
        return self.recv_buffer[name][idx]

    def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"):
        """Queue a receive for the ``(name, idx)`` slot; started by ``recv_next``."""
        name = name or "latent"
        self.recv_tasks_queue.append((name, idx))

    def recv_next(self):
        """Start the oldest queued receive task.

        Raises:
            ValueError: if no task is queued.
        """
        if len(self.recv_tasks_queue) == 0:
            raise ValueError("No more tasks to receive")
        name, idx = self.recv_tasks_queue.pop(0)
        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
        self.receiving_tasks.append((self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx))

    def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.Tensor:
        """Wait for the oldest in-flight receive and return its buffer.

        The in-flight task must be the one for ``(name, idx)``.
        """
        assert len(self.receiving_tasks) > 0, "No tasks to receive, call add_pipeline_recv_task first"
        receiving_task = self.receiving_tasks.pop(0)
        receiving_task[0].wait()
        assert receiving_task[1] == name and receiving_task[2] == idx, "Received tensor does not match the requested"
        return self.recv_buffer[name][idx]

    def _pipeline_irecv(self, tensor: torch.Tensor):
        """Start a non-blocking receive from the previous rank."""
        return torch.distributed.irecv(
            tensor,
            src=self.prev_rank,
            # Two-rank pipelines use one group per direction.
            group=(self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group),
        )

    def _pipeline_isend(self, tensor: torch.Tensor):
        """Start a non-blocking send to the next rank."""
        return torch.distributed.isend(
            tensor,
            dst=self.next_rank,
            # Two-rank pipelines use one group per direction.
            group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group),
        )

    def set_skip_tensor_recv_buffer(
        self,
        patches_shape_list: list[list[int]],
        feature_map_shape: list[int],
    ):
        """Allocate skip-tensor buffers: one per patch shape plus one for the feature map."""
        self.skip_tensor_recv_buffer = [
            torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list
        ]
        self.skip_tensor_recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device))
        self.skip_tensor_recv_buffer_set = True

    def pipeline_send_skip(self, tensor: torch.Tensor) -> None:
        """Send ``tensor`` to the skip-connected rank and wait for completion."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor).wait()

    def pipeline_isend_skip(self, tensor: torch.Tensor) -> None:
        """Start sending ``tensor`` to the skip-connected rank without waiting."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor)

    def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor:
        """Receive into skip buffer ``idx`` from the skip-connected rank and return it."""
        self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait()
        return self.skip_tensor_recv_buffer[idx]

    def add_pipeline_recv_skip_task(self, idx: int = -1):
        """Queue a skip receive for buffer ``idx``; started by ``recv_skip_next``."""
        self.recv_skip_tasks_queue.append(idx)

    def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor:
        """Wait for the oldest in-flight skip receive and return its buffer."""
        assert len(self.receiving_skip_tasks) > 0, "No tasks to receive, call add_pipeline_recv_skip_task first"
        receiving_skip_task = self.receiving_skip_tasks.pop(0)
        receiving_skip_task[0].wait()
        assert receiving_skip_task[2] == idx, "Received tensor does not match the requested"
        return self.skip_tensor_recv_buffer[idx]

    def recv_skip_next(self):
        """Start the oldest queued skip receive task.

        Raises:
            ValueError: if no task is queued.
        """
        if len(self.recv_skip_tasks_queue) == 0:
            raise ValueError("No more tasks to receive")
        idx = self.recv_skip_tasks_queue.pop(0)
        self.receiving_skip_tasks.append(
            (
                self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]),
                None,
                idx,
            )
        )

    def _pipeline_irecv_skip(self, tensor: torch.Tensor):
        """Start a non-blocking receive from the skip-connected rank."""
        return torch.distributed.irecv(tensor, src=self.skip_rank, group=self.skip_device_group)

    def _pipeline_isend_skip(self, tensor: torch.Tensor):
        """Start a non-blocking send to the skip-connected rank."""
        return torch.distributed.isend(tensor, dst=self.skip_rank, group=self.skip_device_group)
self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py new file mode 100644 index 00000000000..b249515909d --- /dev/null +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -0,0 +1,760 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/xdit-project/xDiT/blob/main/xfuser/core/distributed/utils.py +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM-Omni distributed state. + +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model parallelism, + you can skip the model parallel initialization and destruction steps. 
+""" + +import torch +import torch.distributed +import vllm.distributed.parallel_state as vllm_parallel_state +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size +from vllm.logger import init_logger + +from vllm_omni.diffusion import envs + +from .group_coordinator import ( + GroupCoordinator, + PipelineGroupCoordinator, + SequenceParallelGroupCoordinator, +) + +if envs._is_npu(): + from torch.npu import device_count, set_device +elif envs._is_musa(): + from torch_musa.core.device import device_count, set_device +else: + from torch.cuda import device_count, set_device + + +env_info = envs.PACKAGES_CHECKER.get_packages_info() + +HAS_FLASH_ATTN = env_info["has_flash_attn"] + +logger = init_logger(__name__) + + +_WORLD: GroupCoordinator | None = None +# get _TP from vllm.distributed.parallel_state +_SP: SequenceParallelGroupCoordinator | None = None +_PP: PipelineGroupCoordinator | None = None +_CFG: GroupCoordinator | None = None +_DP: GroupCoordinator | None = None +_DIT: GroupCoordinator | None = None +_VAE: GroupCoordinator | None = None + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: list[int], mask: list[bool] +) -> list[list[int]]: + r"""Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (list[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (list[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. 
For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1) + tp_rank \in [0, tp_size) + dp_rank \in [0, dp_size) + pp_rank \in [0, pp_size) + + If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. + For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the + dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) + The tp_rank and pp_rank will be combined to form the `dp_group_index`. + dp_group_index = tp_rank + pp_rank * tp_size (2) + + So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in + range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the + equation (1). + + This function solve this math problem. + + For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], + and the mask = [False, True, False]. Then, + dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2 + dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2 + ... + dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2 + + dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4] + dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5] + ... + dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] + """ + + def prefix_product(a: list[int], init=1) -> list[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: list[int], b: list[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. 
+ This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert sum([x * y for x, y in zip(idx, stride[:-1])]) == index, ( + f"idx {index} with shape {shape} mismatch the return idx {idx}" + ) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. 
class RankGenerator:
    """Enumerate the rank groups of each parallelism type for a given ordering.

    The five sizes (``tp``, ``sp``, ``pp``, ``cfg``, ``dp``) are laid out in the
    hyphen-separated ``order``; :meth:`get_ranks` returns the groups for any
    combination of those names.
    """

    def __init__(
        self,
        tp: int,
        sp: int,
        pp: int,
        cfg: int,
        dp: int,
        order: str,
        rank_offset: int = 0,
    ) -> None:
        """Store the sizes and normalise ``order``.

        Args:
            tp, sp, pp, cfg, dp: Size of each parallelism type.
            order: Hyphen-separated names, e.g. ``"tp-sp-pp-cfg-dp"``. Names
                left out are appended at the end when their size is 1.
            rank_offset: Added to every rank returned by :meth:`get_ranks`
                when positive.

        Raises:
            RuntimeError: If a name is missing from ``order`` while its size
                is not 1.
            KeyError: If ``order`` contains an unknown name.
        """
        self.tp = tp
        self.sp = sp
        self.pp = pp
        self.cfg = cfg
        self.dp = dp
        self.rank_offset = rank_offset
        self.world_size = tp * sp * pp * cfg * dp

        self.name_to_size = {
            "tp": self.tp,
            "sp": self.sp,
            "pp": self.pp,
            "cfg": self.cfg,
            "dp": self.dp,
        }
        order = order.lower()
        # Drop empty tokens so an empty order string does not yield a "" name.
        tokens = [token for token in order.split("-") if token]

        for name, size in self.name_to_size.items():
            if name in tokens:
                continue
            if size != 1:
                # Report the caller's order: self.order is not assigned yet,
                # so formatting it here would raise AttributeError instead.
                raise RuntimeError(
                    f"The size of ({name}) is ({size}), "
                    f"but you haven't specified the order ({order})."
                )
            tokens.append(name)

        self.order = "-".join(tokens)
        self.ordered_size = [self.name_to_size[token] for token in tokens]

    def get_mask(self, order: str, token: str) -> list[bool]:
        """Return one flag per name in ``order``, True where it appears in ``token``.

        Raises:
            ValueError: If a name in ``token`` does not appear in ``order``.
        """
        ordered_token = order.split("-")
        mask = [False] * len(ordered_token)
        for name in token.split("-"):
            mask[ordered_token.index(name)] = True
        return mask

    def get_ranks(self, token: str) -> list[list[int]]:
        """Get rank group by input token.

        Arguments:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.
        """
        mask = self.get_mask(self.order, token)
        ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask)
        if self.rank_offset > 0:
            ranks = [[rank + self.rank_offset for rank in group] for group in ranks]
        return ranks
def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group coordinator.

    Raises:
        AssertionError: If the data parallel group has not been initialized.
    """
    assert _DP is not None, "data parallel group is not initialized"
    return _DP
def model_parallel_is_initialized():
    """Return True only when the DP, CFG, SP, PP and TP groups all exist.

    The tensor parallel group is read from ``vllm.distributed.parallel_state``;
    the other four are this module's own globals.
    """
    required_groups = (_DP, _CFG, _SP, _PP, vllm_parallel_state._TP)
    return all(group is not None for group in required_groups)
# adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py
def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True):
    """Create the Ulysses and ring process groups and return the pair holding ``rank``.

    ``sp_ulysses_degree * sp_ring_degree`` is the sequence parallel size, and
    ``world_size`` is split into ``world_size // sp_size`` consecutive blocks
    of that size. Inside each block:

    - ``use_ulysses_low=True``: Ulysses groups are contiguous rank ranges and
      ring groups are strided across them.
    - ``use_ulysses_low=False``: ring groups are contiguous and Ulysses groups
      are strided.

    Groups are created with ``torch.distributed.new_group`` in a fixed order
    for every block, not only the one containing ``rank``.

    Args:
        sp_ulysses_degree: Number of ranks in each Ulysses group.
        sp_ring_degree: Number of ranks in each ring group.
        rank: Rank whose two groups are returned.
        world_size: Number of ranks to partition.
        use_ulysses_low: Selects which of the two layouts above is used.

    Returns:
        ``(ulysses_pg, ring_pg)`` for ``rank``.

    Raises:
        AssertionError: If ``world_size`` is not a multiple of the sequence
            parallel size.
        RuntimeError: If ``rank`` does not fall in any created group
            (i.e. it is outside ``range(world_size)``).
    """
    sp_size = sp_ring_degree * sp_ulysses_degree
    assert world_size % sp_size == 0, f"world_size {world_size} must be divisible by sp_size {sp_size}"
    dp_size = world_size // sp_size

    num_ulysses_pgs = sp_ring_degree  # world_size // sp_ulysses_degree
    num_ring_pgs = sp_ulysses_degree  # world_size // sp_ring_degree

    # Start from None so a rank outside every group is reported explicitly
    # below instead of failing with UnboundLocalError at the return.
    ulysses_pg = None
    ring_pg = None

    for dp_rank in range(dp_size):
        offset = dp_rank * sp_size
        if use_ulysses_low:
            for i in range(num_ulysses_pgs):
                ulysses_ranks = list(
                    range(
                        i * sp_ulysses_degree + offset,
                        (i + 1) * sp_ulysses_degree + offset,
                    )
                )
                group = torch.distributed.new_group(ulysses_ranks)
                if rank in ulysses_ranks:
                    ulysses_pg = group

            for i in range(num_ring_pgs):
                ring_ranks = list(range(i + offset, sp_size + offset, num_ring_pgs))
                group = torch.distributed.new_group(ring_ranks)
                if rank in ring_ranks:
                    ring_pg = group
        else:
            for i in range(num_ring_pgs):
                ring_ranks = list(range(i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset))
                group = torch.distributed.new_group(ring_ranks)
                if rank in ring_ranks:
                    ring_pg = group

            for i in range(num_ulysses_pgs):
                ulysses_ranks = list(range(i + offset, sp_size + offset, num_ulysses_pgs))
                group = torch.distributed.new_group(ulysses_ranks)
                if rank in ulysses_ranks:
                    ulysses_pg = group

    if ulysses_pg is None or ring_pg is None:
        raise RuntimeError(
            f"rank {rank} is not part of any sequence parallel group for world_size {world_size}"
        )
    return ulysses_pg, ring_pg
+ tensor_parallel_size: number of GPUs used for tensor parallelism. + pipeline_parallel_size: number of GPUs used for pipeline parallelism. + backend: distributed backend of pytorch collective comm. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize + split batch caused by CFG, and 2 GPUs to parallelize sequence. + + dp_size (2) * cfg_size (2) * sp_size (2) * pp_size (2) = 16. + + The present function will create 8 data-parallel groups, + 8 CFG group, 8 pipeline-parallel group, and + 8 sequence-parallel groups: + 8 data-parallel groups: + [g0, g8], [g1, g9], [g2, g10], [g3, g11], + [g4, g12], [g5, g13], [g6, g14], [g7, g15] + 8 CFG-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7], + [g8, g12], [g9, g13], [g10, g14], [g11, g15] + 8 sequence-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], + [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 8 pipeline-parallel groups: + [g0, g2], [g4, g6], [g8, g10], [g12, g14], + [g1, g3], [g5, g7], [g9, g11], [g13, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. 
+ assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if sequence_parallel_size is None: + sequence_parallel_size = ring_degree * ulysses_degree + logger.info( + f"sequence_parallel_size is not provided, using ring_degree * ulysses_degree = {sequence_parallel_size}" + ) + + if sequence_parallel_size != ring_degree * ulysses_degree: + raise ValueError( + "sequence_parallel_size is not equal to ring_degree * ulysses_degree," + f" but got {sequence_parallel_size} != {ring_degree} * {ulysses_degree}" + ) + + # FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch, + # the pipefusion is not ready for npu yet + if envs._is_npu(): + assert pipeline_parallel_size == 1, "Current pipefusion is not ready for NPU" + + dit_parallel_size = ( + data_parallel_size * cfg_parallel_size * sequence_parallel_size * pipeline_parallel_size * tensor_parallel_size + ) + + if world_size < dit_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than " + f"tensor_parallel_size ({tensor_parallel_size}) x " + f"pipeline_parallel_size ({pipeline_parallel_size}) x" + f"sequence_parallel_size ({sequence_parallel_size}) x" + f"cfg_parallel_size " + f"({cfg_parallel_size}) x" + f"data_parallel_size ({data_parallel_size})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_size, + sequence_parallel_size, + pipeline_parallel_size, + cfg_parallel_size, + data_parallel_size, + "tp-sp-pp-cfg-dp", + ) + global _DP + assert _DP is None, "data parallel group is already initialized" + _DP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("dp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="data", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_model_parallel_group( + 
def destroy_model_parallel():
    """Destroy every model parallel group that exists and reset it to None.

    Covers the DP, CFG, SP, TP (held in ``vllm.distributed.parallel_state``),
    PP, VAE and DiT groups, in that order. Groups that are already None are
    skipped, so this is safe to call on a partially initialised state.
    """
    global _DP, _CFG, _SP, _PP, _VAE, _DIT

    def _teardown(group) -> None:
        # Coordinator objects are torn down through their own destroy();
        # the VAE and DiT groups come straight from
        # torch.distributed.new_group, so hand those back to
        # torch.distributed instead of calling a method on them.
        if group is None:
            return
        if hasattr(group, "destroy"):
            group.destroy()
        elif torch.distributed.is_initialized():
            torch.distributed.destroy_process_group(group)

    _teardown(_DP)
    _DP = None

    _teardown(_CFG)
    _CFG = None

    _teardown(_SP)
    _SP = None

    _teardown(vllm_parallel_state._TP)
    vllm_parallel_state._TP = None

    _teardown(_PP)
    _PP = None

    _teardown(_VAE)
    _VAE = None

    # The DiT group is recreated on every initialize_model_parallel call, so
    # release it here as well rather than leaving a stale handle behind.
    _teardown(_DIT)
    _DIT = None
def destroy_distributed_env():
    """Tear down the model parallel groups, then the distributed environment."""
    # destroy_model_parallel() checks each group before destroying it, so call
    # it unconditionally: gating on model_parallel_is_initialized() would skip
    # cleanup whenever only some of the groups had been created.
    destroy_model_parallel()
    destroy_distributed_environment()
+ "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), +} + + +def _is_hip(): + has_rocm = torch.version.hip is not None + return has_rocm + + +def _is_cuda(): + has_cuda = torch.version.cuda is not None + return has_cuda + + +def _is_musa(): + try: + if hasattr(torch, "musa") and torch.musa.is_available(): + return True + except ModuleNotFoundError: + return False + + +def _is_mps(): + return torch.backends.mps.is_available() + + +def _is_npu(): + try: + if hasattr(torch, "npu") and torch.npu.is_available(): + return True + except ModuleNotFoundError: + return False + + +def get_device(local_rank: int) -> torch.device: + if _is_cuda() or _is_hip(): + return torch.device("cuda", local_rank) + elif _is_musa(): + return torch.device("musa", local_rank) + elif _is_mps(): + return torch.device("mps") + elif _is_npu(): + return torch.device("npu", local_rank) + else: + return torch.device("cpu") + + +def get_device_name() -> str: + if _is_cuda() or _is_hip(): + return "cuda" + elif _is_musa(): + return "musa" + elif _is_mps(): + return "mps" + elif _is_npu(): + return "npu" + else: + return "cpu" + + +def get_device_version(): + if _is_hip(): + hip_version = torch.version.hip + hip_version = hip_version.split("-")[0] + return hip_version + elif _is_cuda(): + return torch.version.cuda + elif _is_musa(): + return torch.version.musa + elif _is_mps(): + return None + elif _is_npu(): + return None + else: + raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") + + +def get_torch_distributed_backend() -> str: + if _is_cuda() or _is_hip(): + return "nccl" + elif _is_musa(): + return "mccl" + elif _is_mps(): + return "gloo" + 
elif _is_npu(): + return "hccl" + else: + raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") + + +variables: dict[str, Callable[[], Any]] = { + # ================== Other Vars ================== + # used in version checking + "CUDA_VERSION": lambda: version.parse(get_device_version() or "0.0"), + "TORCH_VERSION": lambda: version.parse(version.parse(torch.__version__).base_version), +} + + +def _setup_musa(environment_variables, variables): + musa = getattr(torch, "musa", None) + if musa is None: + return + try: + if musa.is_available(): + environment_variables["MUSA_HOME"] = lambda: os.environ.get("MUSA_HOME", None) + environment_variables["MUSA_VISIBLE_DEVICES"] = lambda: os.environ.get("MUSA_VISIBLE_DEVICES", None) + musa_ver = getattr(getattr(torch, "version", None), "musa", None) + if musa_ver: + variables["MUSA_VERSION"] = lambda: version.parse(musa_ver) + except Exception: + pass + + +try: + _setup_musa(environment_variables, variables) +except (AttributeError, ModuleNotFoundError): + pass + + +class PackagesEnvChecker: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.initialize() + return cls._instance + + def initialize(self): + packages_info = {} + packages_info["has_flash_attn"] = self.check_flash_attn(packages_info) + self.packages_info = packages_info + + def check_flash_attn(self, packages_info): + if not torch.cuda.is_available(): + return False + + # Check if torch_npu is available + if _is_npu(): + logger.info("`flash_attn` is not ready on torch_npu for now") + return False + + if _is_musa(): + logger.info("Flash Attention library is not supported on MUSA for the moment.") + return False + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + gpu_name = torch.cuda.get_device_name(device) + if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: + return False + else: + from flash_attn 
class PackagesEnvChecker:
    """Singleton that probes optional packages once and caches the results.

    The probe runs on first construction; every later ``PackagesEnvChecker()``
    call returns the same instance without probing again.
    """

    _instance = None
    # Oldest flash_attn release accepted by check_flash_attn.
    _MIN_FLASH_ATTN_VERSION = "2.6.0"

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        """Run all package probes and store the results on the instance."""
        packages_info = {}
        packages_info["has_flash_attn"] = self.check_flash_attn(packages_info)
        self.packages_info = packages_info

    def check_flash_attn(self, packages_info):
        """Return True when a usable ``flash_attn`` (>= 2.6.0) can be imported.

        Returns False when CUDA is unavailable, on NPU or MUSA backends, when
        the GPU name contains "Turing", "Tesla" or "T4", or when ``flash_attn``
        is missing or older than the minimum version.

        Args:
            packages_info: Results gathered so far; only ``has_aiter`` is read,
                to decide whether to warn about the missing library.
        """
        if not torch.cuda.is_available():
            return False

        # Check if torch_npu is available
        if _is_npu():
            logger.info("`flash_attn` is not ready on torch_npu for now")
            return False

        if _is_musa():
            logger.info("Flash Attention library is not supported on MUSA for the moment.")
            return False
        try:
            # CUDA availability was verified above, so no CPU fallback is needed.
            gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
            if any(tag in gpu_name for tag in ("Turing", "Tesla", "T4")):
                return False

            from flash_attn import __version__ as flash_attn_version

            # Compare parsed versions: as plain strings "2.10.0" < "2.6.0".
            if version.parse(flash_attn_version) < version.parse(self._MIN_FLASH_ATTN_VERSION):
                raise ImportError("install flash_attn >= 2.6.0")
            return True
        except ImportError:
            if not packages_info.get("has_aiter", False):
                logger.warning('Flash Attention library "flash_attn" not found, using pytorch attention implementation')
            return False

    def get_packages_info(self):
        """Return the dict of probe results built by ``initialize``."""
        return self.packages_info


PACKAGES_CHECKER = PackagesEnvChecker()


def __getattr__(name):
    """Resolve module attributes lazily from the registered getter tables."""
    # lazy evaluation of environment variables
    if name in environment_variables:
        return environment_variables[name]()
    if name in variables:
        return variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    """List every lazily resolvable name, matching what ``__getattr__`` serves."""
    return [*environment_variables, *variables]
timestep tensor is on the same device and dtype as hidden_states @@ -645,6 +656,15 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + def get_rotary_emb_chunk(freqs): + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] + return freqs + + if self.parallel_config.sequence_parallel_size > 1: + img_freqs, txt_freqs = image_rotary_emb + img_freqs = get_rotary_emb_chunk(img_freqs) + image_rotary_emb = (img_freqs, txt_freqs) + for index_block, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, @@ -662,6 +682,8 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) + if self.parallel_config.sequence_parallel_size > 1: + output = get_sp_group().all_gather(output, dim=-2) return Transformer2DModelOutput(sample=output) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 86861184149..40f0b7b300f 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -8,10 +8,6 @@ import zmq from vllm.config import LoadConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue -from vllm.distributed.parallel_state import ( - init_distributed_environment, - initialize_model_parallel, -) from vllm.logger import init_logger from vllm.utils import DeviceMemoryProfiler, GiB_bytes @@ -20,6 +16,12 @@ SHUTDOWN_MESSAGE, DiffusionOutput, OmniDiffusionConfig, + set_current_omni_diffusion_config, +) +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, ) from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from 
vllm_omni.diffusion.request import OmniDiffusionRequest @@ -61,22 +63,33 @@ def init_device_and_model(self) -> None: # hack vllm_config = VllmConfig() - vllm_config.parallel_config.tensor_parallel_size = self.od_config.num_gpus - set_current_vllm_config(vllm_config) - - init_distributed_environment(world_size=world_size, rank=rank) - initialize_model_parallel(tensor_model_parallel_size=world_size) - logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") - - load_config = LoadConfig() - model_loader = DiffusersPipelineLoader(load_config) - time_before_load = time.perf_counter() - with DeviceMemoryProfiler() as m: - self.pipeline = model_loader.load_model( - od_config=self.od_config, - load_device=f"cuda:{rank}", - ) - time_after_load = time.perf_counter() + vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size + vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size + + with set_current_omni_diffusion_config(self.od_config): + with set_current_vllm_config(vllm_config): + init_distributed_environment(world_size=world_size, rank=rank) + logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_size=parallel_config.data_parallel_size, + cfg_parallel_size=parallel_config.cfg_parallel_size, + sequence_parallel_size=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_size=parallel_config.tensor_parallel_size, + pipeline_parallel_size=parallel_config.pipeline_parallel_size, + ) + + load_config = LoadConfig() + model_loader = DiffusersPipelineLoader(load_config) + time_before_load = time.perf_counter() + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=f"cuda:{rank}", + 
) + time_after_load = time.perf_counter() logger.info( "Model loading took %.4f GiB and %.6f seconds", @@ -108,12 +121,7 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi return output def shutdown(self) -> None: - if torch.distributed.is_initialized(): - try: - torch.distributed.destroy_process_group() - logger.info("Worker %s: Destroyed process group", self.rank) - except Exception as exc: # pragma: no cover - best effort cleanup - logger.warning("Worker %s: Failed to destroy process group: %s", self.rank, exc) + destroy_distributed_env() class WorkerProc: