diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index a0146042487..ca604f11744 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -71,6 +71,7 @@ steps: image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT always-pull: true propagate-environment: true + shm-size: "8gb" environment: - "HF_HOME=/fsx/hf_cache" volumes: diff --git a/docs/user_guide/acceleration/parallelism_acceleration.md b/docs/user_guide/acceleration/parallelism_acceleration.md index 3f2283b3b28..603c165aaf4 100644 --- a/docs/user_guide/acceleration/parallelism_acceleration.md +++ b/docs/user_guide/acceleration/parallelism_acceleration.md @@ -8,27 +8,29 @@ The following parallelism methods are currently supported in vLLM-Omni: 1. DeepSpeed Ulysses Sequence Parallel (DeepSpeed Ulysses-SP) ([arxiv paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. +2. [Ring-Attention](#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded + The following table shows which models are currently supported by parallelism method: ### ImageGen -| Model | Model Identifier | Ulysses-SP | -|-------|------------------|-----------| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | +| Model | Model Identifier | Ulysses-SP | Ring-SP | +|-------|------------------|-----------|---------| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ❌ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ❌ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ### VideoGen -| Model | Model Identifier | Ulysses-SP | -|-------|------------------|-----------| -| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | +| Model | Model Identifier | Ulysses-SP | Ring-SP | +|-------|------------------|-----------|---------| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ❌ | ### Sequence Parallelism @@ -80,6 +82,101 @@ To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** m | Ulysses-SP | 4 | 39.6s | 2.84x | | Ulysses-SP | 8 | 30.8s | 3.65x | +#### Ring-Attention + +Ring-Attention ([arxiv paper](https://arxiv.org/abs/2310.01889)) splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results. Unlike Ulysses-SP which uses all-to-all communication, Ring-Attention keeps the sequence dimension sharded throughout the computation and circulates Key/Value blocks through a ring topology. + +##### Offline Inference + +An example of offline inference script using Ring-Attention is shown below: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ring_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ring_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. + + +##### Online Serving + +You can enable Ring-Attention in online serving for diffusion models via `--ring`: + +```bash +# Text-to-image (requires >= 2 GPUs) +vllm serve Qwen/Qwen-Image --omni --port 8091 --ring 2 +``` + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends. + +| Configuration | Ring degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 45.2s | 1.0x | +| Ring-Attention | 2 | 29.9s | 1.51x | +| Ring-Attention | 4 | 23.3s | 1.94x | + + +#### Hybrid Ulysses + Ring + +You can combine both Ulysses-SP and Ring-Attention for larger scale parallelism. The total sequence parallel size equals `ulysses_degree × ring_degree`. + +##### Offline Inference + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +# Hybrid: 2 Ulysses × 2 Ring = 4 GPUs total +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + +##### Online Serving + +```bash +# Text-to-image (requires >= 4 GPUs) +vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2 --ring 2 +``` + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends. + +| Configuration | Ulysses degree | Ring degree | Generation Time | Speedup | +|---------------|----------------|-------------|-----------------|---------| +| **Baseline (diffusers)** | - | - | 45.2s | 1.0x | +| Hybrid Ulysses + Ring | 2 | 2 | 24.3s | 1.87x | + + ##### How to parallelize a new model If a diffusion model has been deployed in vLLM-Omni and supports single-card inference, you can refer to the following instructions to parallelize it with [Ulysses-SP](https://arxiv.org/pdf/2309.14509). diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 716d196c71f..baa0c81dc4e 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -17,6 +17,7 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma vLLM-Omni also supports the sequence parallelism (SP) for the diffusion model, that includes: 1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. +2. [Ring-Attention](acceleration/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded. ## Quick Comparison @@ -33,6 +34,14 @@ The following table shows which models are currently supported by each accelerat ### ImageGen +<<<<<<< HEAD +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | +|-------|-----------------|----------|-----------|-----------|----------------| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ |❌ | ❌ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ |✅ | - | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ❌ | ✅ |✅ | - | +======= | Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | |-------|------------------|----------|-----------|-----------| | **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | @@ -151,6 +160,22 @@ outputs = omni.generate(prompt="turn this cat to a dog", pil_image=input_image, num_inference_steps=50) ``` +### Using Ring-Attention + +Run text-to-image: +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +ring_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ring_degree=2) +) + +outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +``` + ## Documentation For detailed information on each acceleration method: diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index bf4348e4c5e..c2ee3b4b603 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -159,7 +159,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ulysses sequence parallelism.", ) - + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.") parser.add_argument( "--resolution", @@ -268,8 +273,7 @@ def main(): # Enable VAE memory optimizations on NPU vae_use_slicing = is_npu() vae_use_tiling = is_npu() - - parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) # Configure cache based on backend type cache_config = None if args.cache_backend == "cache_dit": @@ -315,7 +319,7 @@ def main(): print(f" Image {idx + 1} size: {img.size}") else: print(f" Input image size: {input_image.size}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") print(f"{'=' * 60}\n") try: diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index ded976f3d88..b487683ed26 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -66,7 +66,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ulysses sequence parallelism.", ) - + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) return parser.parse_args() @@ -108,7 +113,8 @@ def main(): # (e.g., QwenImagePipeline or FluxPipeline) } - parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree) + # assert args.ring_degree == 1, "Ring attention is not supported yet" + parallel_config = DiffusionParallelConfig(ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree) omni = Omni( model=args.model, vae_use_slicing=vae_use_slicing, @@ -124,7 +130,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") + print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") print(f" Image size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/tests/diffusion/attention/test_ulysses_sequence_parallel.py b/tests/diffusion/attention/test_sequence_parallel.py similarity index 72% rename from tests/diffusion/attention/test_ulysses_sequence_parallel.py rename to tests/diffusion/attention/test_sequence_parallel.py index 0b60a340c9e..9601db44104 100644 --- a/tests/diffusion/attention/test_ulysses_sequence_parallel.py +++ b/tests/diffusion/attention/test_sequence_parallel.py @@ -8,9 +8,6 @@ import torch from vllm.platforms import current_platform -from vllm_omni.diffusion.attention.backends.abstract import ( - AttentionMetadata, -) from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import ( DiffusionParallelConfig, @@ -19,11 +16,9 @@ ) from vllm_omni.diffusion.distributed.parallel_state import ( destroy_distributed_env, - get_sequence_parallel_world_size, init_distributed_environment, initialize_model_parallel, ) -from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.utils.platform_utils import detect_device_type device_type = detect_device_type() @@ -34,9 +29,6 @@ else: raise ValueError(f"Unsupported device type: {device_type} for this test script! Expected GPU or NPU.") -global split_text_embed_in_sp -split_text_embed_in_sp = False - def update_environment_variables(envs_dict: dict[str, str]): """Update multiple environment variables with logging.""" @@ -78,52 +70,26 @@ def __init__( self.v_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) self.o_proj = torch.nn.Linear(num_heads * head_size, hidden_size) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Forward pass through attention layer.""" batch_size, seq_len, _ = hidden_states.shape - # Combine hidden_states and encoder_hidden_states if provided - if encoder_hidden_states is not None: - # Concatenate along sequence dimension - combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - else: - combined_hidden_states = hidden_states - # Project to Q, K, V - q = self.q_proj(combined_hidden_states) - k = self.k_proj(combined_hidden_states) - v = self.v_proj(combined_hidden_states) - - # Reshape to (batch_size, total_seq_len, num_heads, head_size) - total_seq_len = combined_hidden_states.shape[1] - q = q.view(batch_size, total_seq_len, self.num_heads, self.head_size) - k = k.view(batch_size, total_seq_len, k.shape[-1] // self.head_size, self.head_size) - v = v.view(batch_size, total_seq_len, v.shape[-1] // self.head_size, self.head_size) - - # Apply attention with split logic - if get_sequence_parallel_world_size() > 1 and not split_text_embed_in_sp and encoder_hidden_states is not None: - q_encoder, q_hidden = torch.split(q, [total_seq_len - seq_len, seq_len], dim=1) - k_encoder, k_hidden = torch.split(k, [total_seq_len - seq_len, seq_len], dim=1) - v_encoder, v_hidden = torch.split(v, [total_seq_len - seq_len, seq_len], dim=1) - - # Use hidden_states part as main attention, encoder part as joint - attn_output = self.attention( - q_hidden, - k_hidden, - v_hidden, - AttentionMetadata( - joint_query=q_encoder, - joint_key=k_encoder, - joint_value=v_encoder, - ), - ) - else: - attn_output = self.attention(q, k, v) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (batch_size, seq_len, num_heads, head_size) + q = q.view(batch_size, seq_len, self.num_heads, self.head_size) + k = k.view(batch_size, seq_len, k.shape[-1] // self.head_size, self.head_size) + v = v.view(batch_size, seq_len, v.shape[-1] // self.head_size, self.head_size) + + # Apply attention + attn_output = self.attention(q, k, v) # Reshape back and project - attn_output = attn_output.view(batch_size, total_seq_len, -1) + attn_output = attn_output.view(batch_size, seq_len, -1) output = self.o_proj(attn_output) - output = output[:, encoder_hidden_states.shape[1] :, :] if encoder_hidden_states is not None else output return output @@ -161,33 +127,32 @@ def __init__( ] ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Forward pass through multiple attention layers.""" for layer in self.layers: - hidden_states = hidden_states + layer(hidden_states, encoder_hidden_states) + hidden_states = hidden_states + layer(hidden_states) return hidden_states @pytest.mark.parametrize( "test_model_cls", [ - TestAttentionModel, TestMultiLayerAttentionModel, ], ) @pytest.mark.parametrize("ulysses_degree", [2]) -@pytest.mark.parametrize("ring_degree", [1]) +@pytest.mark.parametrize("ring_degree", [2]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [16]) -@pytest.mark.parametrize("encoder_seq_len", [16, 13]) # Test both divisible and non-divisible @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_size", [8]) @pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_sync", [True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # [torch.float16, torch.bfloat16] +@pytest.mark.parametrize("use_sync", [False]) @pytest.mark.parametrize("dynamic", [False]) @pytest.mark.parametrize("use_compile", [False]) -def test_ulysses_attention( +@pytest.mark.parametrize("attn_backend", ["sdpa", "flash_attn"]) +def test_sequence_parallel( ulysses_degree: int, ring_degree: int, test_model_cls: type[torch.nn.Module], @@ -198,18 +163,13 @@ def test_ulysses_attention( use_compile: bool, batch_size: int, seq_len: int, - encoder_seq_len: int, num_heads: int, head_size: int, + attn_backend: str, ): """Test Ulysses attention by comparing with and without SP enabled.""" sequence_parallel_size = ulysses_degree * ring_degree - # Determine if we can split encoder_hidden_states in SP - can_split_encoder = (encoder_seq_len % sequence_parallel_size) == 0 - print(f"\nEncoder sequence length: {encoder_seq_len}, SP size: {sequence_parallel_size}") - print(f"Can split encoder in SP: {can_split_encoder}") - # Create temporary files to share results between processes with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: baseline_output_file = f.name @@ -219,8 +179,6 @@ def test_ulysses_attention( model_state_file = f.name with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: input_data_file = f.name - with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: - encoder_input_data_file = f.name try: # Step 1: Run without SP (baseline with ulysses_degree=1, ring_degree=1) @@ -232,7 +190,6 @@ def test_ulysses_attention( test_model_cls, batch_size, seq_len, - encoder_seq_len, num_heads, head_size, dtype, @@ -246,8 +203,8 @@ def test_ulysses_attention( baseline_output_file, model_state_file, input_data_file, - encoder_input_data_file, True, # is_baseline + attn_backend, ), nprocs=1, ) @@ -261,7 +218,6 @@ def test_ulysses_attention( test_model_cls, batch_size, seq_len, - encoder_seq_len, num_heads, head_size, dtype, @@ -275,8 +231,8 @@ def test_ulysses_attention( sp_output_file, model_state_file, input_data_file, - encoder_input_data_file, False, # is_baseline + attn_backend, ), nprocs=sequence_parallel_size, ) @@ -328,13 +284,14 @@ def test_ulysses_attention( print(f"{'=' * 80}\n") # Assert that differences are within acceptable tolerance - # For FP16/BF16, we expect some numerical differences due to different computation order + # For FP16/BF16, we expect some numerical differences due to different computation order under parallelism. + # If we use the same backend (e.g. Flash Attention) for both baseline and SP, differences should be smaller. if dtype == torch.float16: - atol, rtol = 1e-4, 1e-2 + atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention elif dtype == torch.bfloat16: - atol, rtol = 1e-4, 1e-2 + atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention else: - atol, rtol = 1e-5, 1e-3 + atol, rtol = 1e-5, 1e-4 assert max_abs_diff < atol or max_relative_diff < rtol, ( f"Output difference too large: max_abs_diff={max_abs_diff:.6e}, " @@ -346,7 +303,7 @@ def test_ulysses_attention( finally: # Clean up temporary files - for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file, encoder_input_data_file]: + for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file]: if os.path.exists(f): os.remove(f) @@ -357,7 +314,6 @@ def ulysses_attention_on_test_model( test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, - encoder_seq_len: int, num_heads: int, head_size: int, dtype: torch.dtype, @@ -371,23 +327,16 @@ def ulysses_attention_on_test_model( output_file: str, model_state_file: str, input_data_file: str, - encoder_input_data_file: str, is_baseline: bool, + attn_backend: str, ): """Run Ulysses attention test on a test model and save results for comparison.""" # Use fixed seed for reproducibility across baseline and SP runs RANDOM_SEED = 42 current_platform.seed_everything(RANDOM_SEED) - # Determine if we can split encoder_hidden_states based on divisibility - global split_text_embed_in_sp - split_text_embed_in_sp = (encoder_seq_len % sequence_parallel_size) == 0 and not is_baseline - mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}") - print( - f"[{mode_str}] Rank {local_rank}/{world_size} - encoder_seq_len={encoder_seq_len}, split_text_embed_in_sp={split_text_embed_in_sp}" - ) device = torch.device(f"{device_type}:{local_rank}") torch_device.set_device(device) @@ -421,6 +370,7 @@ def ulysses_attention_on_test_model( model="test_model", dtype=dtype, parallel_config=parallel_config, + attention_backend=attn_backend, # Set the attention backend here ) # Initialize model parallel @@ -435,7 +385,7 @@ def ulysses_attention_on_test_model( ) # Set the config so Attention can access it - with set_forward_context(omni_diffusion_config=od_config), set_current_omni_diffusion_config(od_config): + with set_current_omni_diffusion_config(od_config): # Create model hidden_size = num_heads * head_size @@ -474,17 +424,10 @@ def ulysses_attention_on_test_model( dtype=dtype, device="cpu", ) - full_encoder_hidden_states = torch.randn( - (batch_size, encoder_seq_len, hidden_size), - dtype=dtype, - device="cpu", - ) with open(input_data_file, "wb") as f: pickle.dump(full_hidden_states.detach().cpu().float().numpy(), f) - with open(encoder_input_data_file, "wb") as f: - pickle.dump(full_encoder_hidden_states.detach().cpu().float().numpy(), f) - print("[Baseline] Saved model state and input data (including encoder_hidden_states)") + print("[Baseline] Saved model state and input data") # Synchronize to ensure baseline has saved data before SP loads it if world_size > 1: @@ -500,13 +443,7 @@ def ulysses_attention_on_test_model( full_hidden_states_np = pickle.load(f) full_hidden_states = torch.from_numpy(full_hidden_states_np).to(device).to(dtype) - with open(encoder_input_data_file, "rb") as f: - full_encoder_hidden_states_np = pickle.load(f) - full_encoder_hidden_states = torch.from_numpy(full_encoder_hidden_states_np).to(device).to(dtype) - - print( - f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}, encoder shape {full_encoder_hidden_states.shape}" - ) + print(f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}") # Split input sequence according to sequence parallel BEFORE model forward # Each rank gets a contiguous chunk of the sequence dimension @@ -515,28 +452,10 @@ def ulysses_attention_on_test_model( end_idx = start_idx + local_seq_len hidden_states = full_hidden_states[:, start_idx:end_idx, :].contiguous() - # Handle encoder_hidden_states splitting based on split_text_embed_in_sp - if get_sequence_parallel_world_size() > 1 and split_text_embed_in_sp: - # Split encoder_hidden_states in the same way as hidden_states - local_encoder_seq_len = encoder_seq_len // sequence_parallel_size - encoder_start_idx = local_rank * local_encoder_seq_len - encoder_end_idx = encoder_start_idx + local_encoder_seq_len - encoder_hidden_states = full_encoder_hidden_states[:, encoder_start_idx:encoder_end_idx, :].contiguous() - print( - f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " - f"indices=[{start_idx}:{end_idx}], hidden_states shape={hidden_states.shape}, " - f"encoder_hidden_states (split) shape={encoder_hidden_states.shape}, " - f"encoder_indices=[{encoder_start_idx}:{encoder_end_idx}]" - ) - else: - # No splitting for encoder_hidden_states, use full sequence - encoder_hidden_states = full_encoder_hidden_states - print( - f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " - f"indices=[{start_idx}:{end_idx}], hidden_states shape={hidden_states.shape}, " - f"encoder_hidden_states (full) shape={encoder_hidden_states.shape}, " - f"split_text_embed_in_sp={split_text_embed_in_sp}" - ) + print( + f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " + f"indices=[{start_idx}:{end_idx}], local_shape={hidden_states.shape}" + ) if dynamic: torch._dynamo.mark_dynamic(hidden_states, 0) @@ -548,13 +467,14 @@ def ulysses_attention_on_test_model( # Run forward pass with local sequence chunk print(f"[Rank {local_rank}] Running forward pass...") - output = model(hidden_states, encoder_hidden_states) + output = model(hidden_states) print(f"[Rank {local_rank}] Forward pass completed, output shape: {output.shape}") + + # Verify output shape assert output.shape == (batch_size, local_seq_len, hidden_size), ( f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}" ) - output = output.contiguous() # Gather outputs from all ranks AFTER computation if world_size > 1: print(f"[Rank {local_rank}] Gathering outputs from all {world_size} ranks...") @@ -565,7 +485,7 @@ def ulysses_attention_on_test_model( # Concatenate along sequence dimension to reconstruct full sequence full_output = torch.cat(gathered_outputs, dim=1) print(f"[Rank 0] Gathered and concatenated outputs: {full_output.shape}") - + # Verify the full output shape matches expected assert full_output.shape == (batch_size, seq_len, hidden_size), ( f"Gathered output shape mismatch: expected {(batch_size, seq_len, hidden_size)}, " f"got {full_output.shape}" @@ -586,10 +506,9 @@ def ulysses_attention_on_test_model( mode_str = "baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" print( f"\n[{mode_str}] ✓ Saved output with shape {full_output.shape}:\n" - f" - batch_size={batch_size}, seq_len={seq_len}, encoder_seq_len={encoder_seq_len}\n" + f" - batch_size={batch_size}, seq_len={seq_len}\n" f" - num_heads={num_heads}, head_size={head_size}\n" f" - dtype={dtype}, causal={causal}, use_sync={use_sync}\n" - f" - split_text_embed_in_sp={split_text_embed_in_sp}\n" ) destroy_distributed_env() diff --git a/tests/diffusion/distributed/test_comm.py b/tests/diffusion/distributed/test_comm.py index 7bd6386796e..4a84516969c 100644 --- a/tests/diffusion/distributed/test_comm.py +++ b/tests/diffusion/distributed/test_comm.py @@ -7,7 +7,7 @@ import pytest import torch -from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D, SeqAllToAll5D +from vllm_omni.diffusion.distributed.comm import RingComm, SeqAllToAll4D, SeqAllToAll5D from vllm_omni.diffusion.distributed.parallel_state import ( destroy_distributed_env, get_sp_group, @@ -290,3 +290,110 @@ def _test_5d_identity_worker( # Cleanup distributed environment destroy_distributed_env() + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [128]) +def test_ring_p2p( + world_size: int, + dtype: torch.dtype, + batch_size: int, + num_heads: int, + head_size: int, +): + """Test Ring P2P communication (send_recv).""" + torch.multiprocessing.spawn( + _test_ring_p2p_worker, + args=(world_size, dtype, batch_size, num_heads, head_size), + nprocs=world_size, + ) + + +def _test_ring_p2p_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + num_heads: int, + head_size: int, +): + """Worker for Ring P2P test.""" + import sys + + # Set device + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + + # Set env vars + # Use a different port to avoid conflict with other tests if run in parallel + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29501", + } + ) + + # Init distributed + try: + init_distributed_environment() + # Ring degree = world_size to test ring group + initialize_model_parallel(ring_degree=world_size) + sp_group = get_sp_group() + + print(f"[Rank {local_rank}] Initialized. Ring group size: {sp_group.ring_group.size()}") + sys.stdout.flush() + + # Create RingComm + comm = RingComm(sp_group.ring_group) + + # Create tensor: rank-specific data + # (batch, num_heads, head_size) + # Fill with rank value + 1 to avoid 0 and make verification easy + input_tensor = torch.full( + (batch_size, num_heads, head_size), fill_value=float(local_rank + 1), dtype=dtype, device=device + ) + + print(f"[Rank {local_rank}] Input sum: {input_tensor.sum().item()}") + sys.stdout.flush() + + # Send input, receive from prev + # RingComm.send_recv sends to next, receives from prev + t0 = __import__("time").time() + recv_tensor = comm.send_recv(input_tensor) + comm.commit() + comm.wait() + t1 = __import__("time").time() + + print(f"[Rank {local_rank}] Communication done in {t1 - t0:.4f}s") + + # Verify + # Expected value: from (rank - 1) % world_size + prev_rank = (local_rank - 1 + world_size) % world_size + expected_value = float(prev_rank + 1) + + recv_sum = recv_tensor.sum().item() + print(f"[Rank {local_rank}] Received sum: {recv_sum}, Expected value: {expected_value}") + sys.stdout.flush() + + expected_tensor = torch.full_like(recv_tensor, fill_value=expected_value) + + # Use a slightly loose tolerance for bfloat16 + torch.testing.assert_close( + recv_tensor, expected_tensor, rtol=1e-3, atol=1e-3, msg=f"[Rank {local_rank}] Data mismatch!" + ) + print(f"[Rank {local_rank}] Verification PASSED") + + except Exception as e: + print(f"[Rank {local_rank}] FAILED with error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + destroy_distributed_env() diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py index 9101f98f2ff..ab2b858142d 100644 --- a/tests/e2e/offline_inference/test_sequence_parallel.py +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -11,11 +11,13 @@ import os import sys +import time from pathlib import Path import numpy as np import pytest import torch +import torch.distributed as dist from PIL import Image # ruff: noqa: E402 @@ -51,18 +53,61 @@ def _diff_metrics(a: Image.Image, b: Image.Image) -> tuple[float, float]: return abs_diff.mean().item(), abs_diff.max().item() +def _get_images(output): + """Extract images from output, handling both dict and SimpleNamespace types. + + The output structure varies depending on serialization path: + - Direct memory: SimpleNamespace with .images attribute + - SHM serialization: dict with "images" key (dataclass converted via asdict) + - Wrapped output: SimpleNamespace(output=...) which needs unwrapping + """ + # Check if output has direct images attribute (diffusion mode) + if hasattr(output, "images") and output.images: + return output.images + + # Check request_output for pipeline mode + if output.request_output is None: + return None + + if isinstance(output.request_output, list) and len(output.request_output) == 0: + return None + + item = output.request_output[0] + + # Handle wrapped SimpleNamespace (e.g. from omni_stage.py) + # Some items are wrapped as SimpleNamespace(request_id=..., output=...) + while hasattr(item, "output") and not hasattr(item, "images"): + item = item.output + if item is None: + return None + + # Handle both dict (from SHM serialization) and object (direct) types + if isinstance(item, dict): + return item.get("images") + return getattr(item, "images", None) + + @pytest.mark.parametrize("model_name", models) -@pytest.mark.parametrize("ulysses_degree", [2]) -@pytest.mark.parametrize("ring_degree", [1]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: int, dtype: torch.dtype): +@pytest.mark.parametrize("ulysses_degree", [1, 2]) +@pytest.mark.parametrize("ring_degree", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # Only test bfloat16 to reduce CI time +@pytest.mark.parametrize("attn_backend", ["sdpa"]) +def test_sequence_parallel( + model_name: str, + ulysses_degree: int, + ring_degree: int, + dtype: torch.dtype, + attn_backend: str, +): """Compare baseline (ulysses_degree=1) vs SP (ulysses_degree>1) outputs.""" - if ulysses_degree <= 1: - pytest.skip("This test compares ulysses_degree=1 vs ulysses_degree>1; provide ulysses_degree>1.") + if ulysses_degree <= 1 and ring_degree <= 1: + pytest.skip( + "This test compares ulysses_degree * ring_degree = 1 vs ulysses_degree * ring_degree > 1; provide ulysses_degree or ring_degree>1." + ) # Skip if not enough GPUs available for SP run - if device_count() < ulysses_degree: - pytest.skip(f"Test requires {ulysses_degree} GPUs but only {device_count()} available") + if device_count() < ulysses_degree * ring_degree: + pytest.skip(f"Test requires {ulysses_degree * ring_degree} GPUs but only {device_count()} available") # Use minimal settings for fast testing height = 256 @@ -76,6 +121,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in model=model_name, parallel_config=baseline_parallel_config, dtype=dtype, + attention_backend=attn_backend, ) try: outputs = baseline.generate( @@ -90,6 +136,11 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in baseline_images = list(outputs)[0].request_output[0].images finally: baseline.close() + if dist.is_initialized(): + dist.destroy_process_group() + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]: + os.environ.pop(key, None) + time.sleep(2) # Wait for resources to release assert baseline_images is not None assert len(baseline_images) == 1 @@ -102,6 +153,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in model=model_name, parallel_config=sp_parallel_config, dtype=dtype, + attention_backend=attn_backend, ) try: outputs = sp.generate( @@ -116,6 +168,11 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in sp_images = list(outputs)[0].request_output[0].images finally: sp.close() + if dist.is_initialized(): + dist.destroy_process_group() + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]: + os.environ.pop(key, None) + time.sleep(2) assert sp_images is not None assert len(sp_images) == 1 @@ -134,7 +191,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in max_threshold = 1e-1 print( - "Image diff stats (baseline ulysses_degree=1 vs SP): " + "Image diff stats (baseline ulysses_degree*ring_degree=1 vs SP): " f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; " f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; " f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, dtype={dtype}" diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index ab7162400ec..562999d7e58 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -19,7 +19,7 @@ "vllm_omni/entrypoints/omni_llm.py", "tests/e2e/offline_inference/utils.py", "vllm_omni/diffusion/distributed/group_coordinator.py", - "tests/diffusion/attention/test_ulysses_sequence_parallel.py", + "tests/diffusion/attention/test_sequence_parallel.py", } PICKLE_RE = re.compile( diff --git a/vllm_omni/diffusion/attention/backends/ring/__init__.py b/vllm_omni/diffusion/attention/backends/ring/__init__.py new file mode 100644 index 00000000000..77a31704088 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/__init__.py @@ -0,0 +1 @@ +# Ring attention backend components diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_globals.py b/vllm_omni/diffusion/attention/backends/ring/ring_globals.py new file mode 100644 index 00000000000..7ec0927857f --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_globals.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + + +# test if flash_attn is available +try: + import flash_attn # noqa: F401 + from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward # noqa: F401 + + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + +try: + from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward # noqa: F401 + from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper # noqa: F401 + from flash_attn_interface import flash_attn_func as flash3_attn_func # noqa: F401 + + HAS_FLASH_ATTN_HOPPER = True +except ImportError: + HAS_FLASH_ATTN_HOPPER = False + +try: + from flashinfer.prefill import single_prefill_with_kv_cache # noqa: F401 + + HAS_FLASHINFER = True +except ImportError: + HAS_FLASHINFER = False + +try: + import aiter # noqa: F401 + from aiter import flash_attn_func as flash_attn_func_aiter # noqa: F401 + + HAS_AITER = True +except ImportError: + HAS_AITER = False + +try: + import sageattention # noqa: F401 + + HAS_SAGE_ATTENTION = True +except ImportError: + HAS_SAGE_ATTENTION = False + +try: + import spas_sage_attn # noqa: F401 + + HAS_SPARSE_SAGE_ATTENTION = True +except ImportError: + HAS_SPARSE_SAGE_ATTENTION = False + +try: + import torch_npu # noqa: F401 + + HAS_NPU = True +except ImportError: + HAS_NPU = False diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py b/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py new file mode 100644 index 00000000000..2444dd894f8 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + +import math + +import torch + +from .ring_globals import HAS_AITER, HAS_FLASH_ATTN, HAS_FLASH_ATTN_HOPPER, HAS_FLASHINFER, HAS_NPU + +_scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention +_scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention + +try: + import torch_musa # noqa: F401 + + _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_attention_flash_musa + _scaled_dot_product_efficient_attention = None +except ModuleNotFoundError: + pass + +if HAS_AITER: + from aiter import flash_attn_func as flash_attn_func_aiter + +if HAS_FLASH_ATTN: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward + +if HAS_FLASH_ATTN_HOPPER: + from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward + from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper + from flash_attn_interface import flash_attn_func as flash3_attn_func +else: + flash_attn_forward_hopper = None + flash_attn_func_hopper_backward = None + flash3_attn_func = None + +if HAS_FLASHINFER: + from flashinfer.prefill import single_prefill_with_kv_cache + + _LOG2_E = math.log2(math.e) + +if HAS_NPU: + import torch_npu + + +def pytorch_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, + op_type="efficient", +): + assert op_type in ["flash", "efficient"], f"Invalid op_type: {op_type}" + """ + q shape (bs, seqlen, nhead, hs) + k shape (bs, seqlen, nhead, hs) + v shape (bs, seqlen, nhead, hs) + """ + # Fallback logic: Flash Attention does not support float32. + # If op_type is 'flash' but dtype is float32, force 'efficient'. + if op_type == "flash" and q.dtype == torch.float32: + op_type = "efficient" + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if op_type == "flash": + out, lse = _scaled_dot_product_flash_attention( + q, + k, + v, + dropout_p=dropout_p, + is_causal=causal, + scale=softmax_scale, + )[:2] + elif op_type == "efficient": + out, lse = _scaled_dot_product_efficient_attention( + q, + k, + v, + attn_bias=None, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=causal, + scale=softmax_scale, + )[:2] + else: + raise ValueError(f"Invalid op_type: {op_type}") + + out = out.transpose(1, 2) + lse = lse.to(q.dtype) + + return out, lse + + +def flash_attn_forward( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, +): + assert HAS_FLASH_ATTN, "FlashAttention is not available" + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if flash_attn.__version__ < "2.6.3": + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + else: + block_out, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + return block_out, block_lse + + +def flash_attn3_func_forward( + q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax +): + assert HAS_FLASH_ATTN_HOPPER + # current signature of flash_attn_forward_hopper: + # (q, k, v, softmax_scale, causal, window_size, descale_q=None, descale_k=None, descale_v=None, gqa_parallel=False) + + out, softmax_lse, *unused = flash_attn_forward_hopper( + q=q, + k=k, + v=v, + k_new=None, + v_new=None, + qv=None, + out=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + cu_seqlens_k_new=None, + seqused_q=None, + seqused_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + page_table=None, + kv_batch_idx=None, + leftpad_k=None, + rotary_cos=None, + rotary_sin=None, + seqlens_rotary=None, + q_descale=None, + k_descale=None, + v_descale=None, + softmax_scale=softmax_scale, + causal=False, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, + pack_gqa=None, + sm_margin=0, + ) + + return out, softmax_lse + + +def flash_attn_forward_aiter( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, +): + assert HAS_AITER, "Aiter is not available" + block_out, block_lse = flash_attn_func_aiter( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_lse=True, + ) + + return block_out, block_lse + + +def flashinfer_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + causal: bool = False, + window_size: tuple[int, int] = (-1, -1), + softcap: float | None = None, + alibi_slopes: torch.Tensor | None = None, + return_softmax: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert HAS_FLASHINFER, "FlashInfer is not available" + if q.ndim == 4: + if q.shape[0] > 1: + raise ValueError("batch size > 1 is not supported") + out, lse = single_prefill_with_kv_cache( + q[0], + k[0], + v[0], + sm_scale=softmax_scale, + causal=causal, + logits_soft_cap=softcap, + window_left=window_size[0], + return_lse=True, + ) + lse = lse.transpose(0, 1) + out, lse = out.unsqueeze(0), lse.unsqueeze(0) + elif q.ndim == 3: + out, lse = single_prefill_with_kv_cache( + q, + k, + v, + sm_scale=softmax_scale, + causal=causal, + logits_soft_cap=softcap, + window_left=window_size[0], + return_lse=True, + ) + lse = lse.transpose(0, 1) + else: + raise ValueError(f"Invalid input shape: {q.shape}") + lse = lse / _LOG2_E + return out, lse + + +def npu_attn_forward(q, k, v, softmax_scale=None, layout="BSND"): + assert HAS_NPU, "torch_npu is not available" + softmax_scale = q.shape[-1] ** (-0.5) + block_out, block_lse = torch_npu.npu_fused_infer_attention_score( + q, + k, + v, + num_heads=q.shape[-2], + input_layout=layout, + scale=softmax_scale, + softmax_lse_flag=True, + pre_tokens=65535, + next_tokens=65535, + ) + return block_out, block_lse.squeeze(dim=-1) diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_selector.py b/vllm_omni/diffusion/attention/backends/ring/ring_selector.py new file mode 100644 index 00000000000..56e4a2e6b2f --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_selector.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + +from collections.abc import Callable +from enum import Enum +from functools import partial + +import torch + +from .ring_globals import ( + HAS_NPU, + HAS_SAGE_ATTENTION, + HAS_SPARSE_SAGE_ATTENTION, +) +from .ring_kernels import ( + flash_attn3_func_forward, + flash_attn_forward, + flash_attn_forward_aiter, + flashinfer_attn_forward, + pytorch_attn_forward, +) + +if HAS_SAGE_ATTENTION: + import sageattention + +if HAS_SPARSE_SAGE_ATTENTION: + from spas_sage_attn.autotune import SparseAttentionMeansim + +if HAS_NPU: + from torch_npu import npu_fused_infer_attention_score + + +class AttnType(Enum): + AITER = "aiter" + FA = "fa" + FA3 = "fa3" + FLASHINFER = "flashinfer" + TORCH = "torch" + SAGE_AUTO = "sage_auto" + SAGE_FP16 = "sage_fp16" + SAGE_FP16_TRITON = "sage_fp16_triton" + SAGE_FP8 = "sage_fp8" + SAGE_FP8_SM90 = "sage_fp8_sm90" + SPARSE_SAGE = "sparse_sage" + NPU = "npu" + + @classmethod + def from_string(cls, s: str): + for member in cls: + if member.value == s: + return member + raise ValueError(f"'{s}' is not a valid {cls.__name__}") + + +def select_flash_attn_impl( + impl_type: AttnType, + stage: str = "fwd-only", + attn_processor: torch.nn.Module | None = None, +) -> Callable[..., tuple[torch.Tensor, torch.Tensor | None]]: + """Select attention implementation for forward pass (inference only). + + Args: + impl_type: The attention implementation type. + stage: Must be "fwd-only" (backward not supported for inference). + attn_processor: Optional custom attention processor. + + Returns: + Callable[..., tuple[torch.Tensor, torch.Tensor | None]]: The attention + forward function for the specified implementation. + """ + if stage != "fwd-only": + raise ValueError(f"Only 'fwd-only' stage is supported for inference. Got: {stage}") + + if impl_type == AttnType.AITER: + return flash_attn_forward_aiter + + elif impl_type == AttnType.FA: + return flash_attn_forward + + elif impl_type == AttnType.FA3: + return flash_attn3_func_forward + + elif impl_type == AttnType.FLASHINFER: + return flashinfer_attn_forward + + elif impl_type == AttnType.TORCH: + return pytorch_attn_forward + + elif impl_type == AttnType.SAGE_AUTO: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn, + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP16: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp16_cuda, + pv_accum_dtype="fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP16_TRITON: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp16_triton, + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP8: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp8_cuda, + pv_accum_dtype="fp32+fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP8_SM90: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp8_cuda_sm90, + pv_accum_dtype="fp32+fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SPARSE_SAGE: + if not HAS_SPARSE_SAGE_ATTENTION: + raise ImportError("SparseSageAttention is not available!") + if not isinstance(attn_processor, SparseAttentionMeansim): + raise ImportError("SparseSageAttention is only available with a SparseAttentionProcessor class passed in") + + def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs): + return ( + attn_processor( + q, + k, + v, + is_causal=causal, + scale=softmax_scale, + tensor_layout="NHD", + ), + None, + ) + + return fn + + elif impl_type == AttnType.NPU: + if not HAS_NPU: + raise ImportError("torch_npu is not available!") + return npu_fused_infer_attention_score + + elif attn_processor is not None: + return attn_processor + + else: + raise ValueError(f"Unknown flash attention implementation: {impl_type}") diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_utils.py b/vllm_omni/diffusion/attention/backends/ring/ring_utils.py new file mode 100644 index 00000000000..c256f62cbd9 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_utils.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + + +import torch +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "flatten_varlen_lse", "unflatten_varlen_lse"] + + +# Remove torch.jit.script for debugging and flexible shape handling +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.to(torch.float32) + + B, S, H, D = out.shape + + # --- Shape Correction Logic for block_lse --- + # Goal: block_lse should be (B, S, H, 1) to match out (B, S, H, D) + + # Debug info + # print(f"DEBUG _update: out={out.shape}, block_lse={block_lse.shape}") + + # Case 0: If block_lse is already 4D, check if it matches + if block_lse.dim() == 4: + if block_lse.shape[1] == S and block_lse.shape[2] == H: + pass # Good + elif block_lse.shape[1] == H and block_lse.shape[2] == S: + block_lse = block_lse.transpose(1, 2) + elif block_lse.shape[1] == H and block_lse.shape[2] >= S: # Padding case + block_lse = block_lse[:, :, :S, :].transpose(1, 2) + # If shape is (B, H, S, 1) but expected (B, S, H, 1) because out is (B, S, H, D) + elif block_lse.shape[1] == H and block_lse.shape[2] == S and block_lse.shape[3] == 1: + block_lse = block_lse.transpose(1, 2) + + # Case 1: block_lse is 3D (B, H, S) or (B, S, H) or (B, ?, ?) + elif block_lse.dim() == 3: + # Check for (B, H, S) - Standard SDPA/FA output + if block_lse.shape[1] == H and block_lse.shape[2] == S: + block_lse = block_lse.transpose(1, 2).unsqueeze(-1) + + # Check for (B, S, H) + elif block_lse.shape[1] == S and block_lse.shape[2] == H: + block_lse = block_lse.unsqueeze(-1) + + # Check for Padding: (B, H, S_pad) where S_pad >= S + elif block_lse.shape[1] == H and block_lse.shape[2] >= S: + # print(f"DEBUG: Trimming padding from lse. {block_lse.shape} -> S={S}") + block_lse = block_lse[:, :, :S].transpose(1, 2).unsqueeze(-1) + + # Check for weird case: (B, S, H_pad) ? Unlikely for LSE but possible + elif block_lse.shape[1] == S and block_lse.shape[2] >= H: + block_lse = block_lse[:, :, :H].unsqueeze(-1) + + # Check for flipped weird case: (B, S_pad, H) + elif block_lse.shape[1] >= S and block_lse.shape[2] == H: + block_lse = block_lse[:, :S, :].unsqueeze(-1) + + # --- Shape Correction for lse (internal state) --- + # Ensure lse matches block_lse's corrected shape (B, S, H, 1) + if lse.shape != block_lse.shape: + # If lse was initialized with wrong shape, try to fix it + if lse.dim() == 4 and lse.shape[1] == block_lse.shape[2] and lse.shape[2] == block_lse.shape[1]: + lse = lse.transpose(1, 2) + elif lse.shape[1] >= S: # slice if lse was initialized with padding + lse = lse[:, :S, :, :] + + # Final check + if lse.shape != block_lse.shape: + # Force broadcast if possible? + pass + + try: + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + except RuntimeError as e: + print(f"ERROR in _update_out_and_lse: {e}") + print(f"out: {out.shape}, lse: {lse.shape}") + print(f"block_out: {block_out.shape}, block_lse: {block_lse.shape}") + # raise e + raise e + + return out, lse + + +def update_out_and_lse( + out: torch.Tensor | None, + lse: torch.Tensor | None, + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + + out = block_out.to(torch.float32) + + # Initialize LSE with robust logic (same as _update) + B, D1, D2, D3 = out.shape + + S_guess = D1 + H_guess = D2 + + if block_lse.dim() == 3: + if block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess: + lse = block_lse.transpose(1, 2).unsqueeze(-1) + elif block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess: + lse = block_lse.unsqueeze(-1) + elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding + lse = block_lse[:, :, :S_guess].transpose(1, 2).unsqueeze(-1) + elif block_lse.shape[1] == S_guess and block_lse.shape[2] >= H_guess: # Padding/Weird + lse = block_lse[:, :, :H_guess].unsqueeze(-1) + elif block_lse.shape[1] >= S_guess and block_lse.shape[2] == H_guess: + lse = block_lse[:, :S_guess, :].unsqueeze(-1) + + # Reverse case: What if out is (B, H, S, D) so S=D2, H=D1? + elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S) + # Then out is (B, H, S, D). We should transpose out! + out = out.transpose(1, 2) + lse = block_lse[:, :, :D2].transpose(1, 2).unsqueeze(-1) # (B, S, H, 1) + + else: + # Fallback + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + else: + # Case 0: If block_lse is already 4D, check if it matches + if block_lse.dim() == 4: + if block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess: + lse = block_lse + elif block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess: + lse = block_lse.transpose(1, 2) + elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding case + lse = block_lse[:, :, :S_guess, :].transpose(1, 2) + elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S) + # Then out is (B, H, S, D). We should transpose out! + out = out.transpose(1, 2) + lse = block_lse[:, :, :D2].transpose(1, 2) # (B, S, H, 1) + else: + lse = block_lse + else: + lse = block_lse + + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() diff --git a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py new file mode 100644 index 00000000000..e163ef1ace1 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + + +import torch + +from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType, select_flash_attn_impl +from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse +from vllm_omni.diffusion.distributed.comm import RingComm + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, + attn_processor=None, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + # Check and adjust q, k, v to be contiguous + if not q.is_contiguous(): + q = q.contiguous() + if not k.is_contiguous(): + k = k.contiguous() + if not v.is_contiguous(): + v = v.contiguous() + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor + next_v: torch.Tensor + next_k = comm.send_recv(k) + next_v = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + step_k = k + step_v = v + if step == 0 and joint_tensor_key is not None: + if joint_strategy == "front": + step_k = torch.cat([joint_tensor_key, step_k], dim=1) + step_v = torch.cat([joint_tensor_value, step_v], dim=1) + else: + step_k = torch.cat([step_k, joint_tensor_key], dim=1) + step_v = torch.cat([step_v, joint_tensor_value], dim=1) + + fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) + block_out, block_lse = fn( + q, + step_k, + step_v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal and step == 0, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + # Ensure block_out is contiguous if needed, though usually it is from FA + + if attn_type == AttnType.SPARSE_SAGE: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + if attn_type != AttnType.SPARSE_SAGE: + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +class RingFlashAttnFunc(torch.autograd.Function): + """Ring Flash Attention autograd function (inference only, no backward).""" + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + group, + attn_type, + attn_processor, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=False, + attn_type=attn_type, + attn_processor=attn_processor, + joint_tensor_key=joint_tensor_key, + joint_tensor_value=joint_tensor_value, + joint_strategy=joint_strategy, + ) + return out if not return_softmax else (out, softmax_lse, None) + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + None, # attn_processor + None, # joint_tensor_key + None, # joint_tensor_value + "front", # joint_strategy + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + None, # attn_processor + None, # joint_tensor_key + None, # joint_tensor_value + "front", # joint_strategy + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, + attn_processor=None, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, None]: + """Ring Attention forward pass using Flash Attention backend. + + Implements Ring Attention with sequence parallelism using a ring-based P2P + communication pattern. The sequence dimension is sharded across devices, and + Key/Value blocks are circulated through the ring to accumulate attention results. + + Args: + q (torch.Tensor): Query tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + k (torch.Tensor): Key tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + v (torch.Tensor): Value tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + dropout_p (float): Dropout probability. Defaults to 0.0. + softmax_scale (float | None): Scaling factor for softmax. + If None, computed as head_dim^(-0.5). + causal (bool): Whether to apply causal masking. Defaults to False. + window_size (tuple[int, int]): Sliding window size for attention. + (-1, -1) means no windowing. + softcap (float): Soft capping value for attention logits. Defaults to 0.0. + alibi_slopes (torch.Tensor | None): ALiBi slopes for positional bias. + Not supported. + deterministic (bool): Whether to use deterministic algorithms. + Defaults to False. + return_attn_probs (bool): If True, returns (out, softmax_lse, None). + Defaults to False. + group (ProcessGroup | None): Process group for ring communication. + Defaults to None. + attn_type (AttnType): Flash Attention implementation type + (AttnType.FA, AttnType.FA3, etc.). + attn_processor (Callable | None): Custom attention processor for sparse + attention. Defaults to None. + joint_tensor_key (torch.Tensor | None): Additional key tensor for joint + attention (e.g., text + image). Concatenated only at step=0. + Defaults to None. + joint_tensor_value (torch.Tensor | None): Additional value tensor for + joint attention (e.g., text + image). Concatenated only at step=0. + Defaults to None. + joint_strategy (str): Concatenation strategy ("front" or "back"). + Defaults to "front". + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, None]]: + - If return_attn_probs is False: Output tensor (batch, seq_len, num_heads, head_dim). + - If return_attn_probs is True: A tuple (out, softmax_lse, None). + """ + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + attn_processor, + joint_tensor_key, + joint_tensor_value, + joint_strategy, + ) diff --git a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py new file mode 100644 index 00000000000..43ee35f7098 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + +# adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py +# Copyright 2024 The HuggingFace Inc. team and Jiarui Fang. + + +import torch + +from vllm_omni.diffusion.attention.backends.ring.ring_kernels import pytorch_attn_forward +from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse +from vllm_omni.diffusion.distributed.comm import RingComm + + +def ring_pytorch_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + op_type="efficient", + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +): + return RingAttentionFunc.apply( + group, + q, + k, + v, + softmax_scale, + causal, + op_type, + joint_tensor_key, + joint_tensor_value, + joint_strategy, + ) + + +class RingAttentionFunc(torch.autograd.Function): + """Ring Attention autograd function using PyTorch SDPA (inference only, no backward).""" + + @staticmethod + def forward( + ctx, + group, + q, + k, + v, + sm_scale, + is_causal, + op_type, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", + ): + comm = RingComm(group) + # Ensure tensors are contiguous for P2P communication + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + out, lse = None, None + next_k, next_v = None, None + + if sm_scale is None: + sm_scale = q.shape[-1] ** -0.5 + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k = comm.send_recv(k) + next_v = comm.send_recv(v) + comm.commit() + + if not is_causal or step <= comm.rank: + step_k = k + step_v = v + if step == 0 and joint_tensor_key is not None: + if joint_strategy == "front": + step_k = torch.cat([joint_tensor_key, step_k], dim=1) + step_v = torch.cat([joint_tensor_value, step_v], dim=1) + else: + step_k = torch.cat([step_k, joint_tensor_key], dim=1) + step_v = torch.cat([step_v, joint_tensor_value], dim=1) + + block_out, block_lse = pytorch_attn_forward( + q, + step_k, + step_v, + softmax_scale=sm_scale, + causal=is_causal and step == 0, + op_type=op_type, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + + return out diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index a565e6aefde..25a12bbd476 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -6,12 +6,20 @@ # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py + import torch import torch.nn as nn +from vllm.logger import init_logger from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy +from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention from vllm_omni.diffusion.attention.selector import get_attn_backend +from vllm_omni.diffusion.data import get_current_omni_diffusion_config +from vllm_omni.diffusion.distributed.parallel_state import get_sp_group +from vllm_omni.utils.platform_utils import is_npu + +logger = init_logger(__name__) class Attention(nn.Module): @@ -43,9 +51,28 @@ def __init__( self.scatter_idx = scatter_idx self.gather_idx = gather_idx self.use_sync = use_sync - # Parallel attention (communication / resharding) is a pluggable strategy. - # This keeps the attention kernel backend selection orthogonal. - self.parallel = build_parallel_attention_strategy( + self.causal = causal + + self.use_ring = False + self.ring_pg = None + self.ring_runner = None + + try: + config = get_current_omni_diffusion_config() + if config.parallel_config.ring_degree > 1: + self.use_ring = True + try: + sp_group = get_sp_group() + self.ring_pg = sp_group.ring_group + self.ring_runner = RingParallelAttention(sp_group) + except Exception: + self.use_ring = False + self.ring_runner = None + except Exception: + self.use_ring = False + self.ring_runner = None + + self.parallel_strategy = build_parallel_attention_strategy( scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=use_sync, @@ -58,13 +85,59 @@ def forward( value: torch.Tensor, attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: - # Parallel strategy may reshard/communicate QKV before the attention kernel. - query, key, value, attn_metadata, ctx = self.parallel.pre_attention(query, key, value, attn_metadata) + # 1. Prepare inputs (Communication / Resharding) + # For Ulysses: AllToAll Q/K/V; Slicing joint_q/k/v + # For Ring: Concat joint_q + query, key, value, attn_metadata, ctx = self.parallel_strategy.pre_attention(query, key, value, attn_metadata) + + # 2. Kernel Execution (Computation) + if self.use_ring: + out = self._run_ring_attention(query, key, value, attn_metadata) + else: + out = self._run_local_attention(query, key, value, attn_metadata) + + # 3. Post-processing (Reverse Communication) + # For Ulysses: AllToAll Output, and AllGather Joint Output + out = self.parallel_strategy.post_attention(out, ctx) + + return out + + def _run_local_attention(self, query, key, value, attn_metadata): + # Check backend preference from config + try: + config = get_current_omni_diffusion_config() + backend_pref = config.attention_backend + except Exception: + backend_pref = None + + if backend_pref == "flash_attn" and query.dtype == torch.float32: + logger.warning( + "Flash Attention does not support float32. Overriding user config " + f"attention_backend='{backend_pref}' to 'sdpa' for dtype={query.dtype}." + ) + backend_pref = "sdpa" + + if is_npu(): + return self.attention( + query, + key, + value, + num_heads=query.shape[-2], + input_layout="BSND", + scale=self.softmax_scale, + softmax_lse_flag=True, + pre_tokens=65535, + next_tokens=65535, + )[0] + + # Fallback to standard attention + return self.attention.forward(query, key, value, attn_metadata) - # shape: (batch_size, seq_len, num_heads, head_size) - attn_output = self.attention.forward(query, key, value, attn_metadata) - if isinstance(attn_output, tuple): - attn_output = attn_output[0] + def _run_ring_attention(self, query, key, value, attn_metadata): + # Delegate to RingParallelAttention strategy if available + if self.ring_runner is not None: + return self.ring_runner.run_attention( + query, key, value, attn_metadata, softmax_scale=self.softmax_scale, causal=self.causal + ) - # Parallel strategy may need to reverse resharding after the kernel. - return self.parallel.post_attention(attn_output, ctx) + raise RuntimeError("Ring attention is enabled but strategy is not RingParallelAttention") diff --git a/vllm_omni/diffusion/attention/parallel/factory.py b/vllm_omni/diffusion/attention/parallel/factory.py index ee0113fc9f0..19e2c9ae158 100644 --- a/vllm_omni/diffusion/attention/parallel/factory.py +++ b/vllm_omni/diffusion/attention/parallel/factory.py @@ -4,9 +4,10 @@ from __future__ import annotations from vllm_omni.diffusion.attention.parallel.base import NoParallelAttention, ParallelAttentionStrategy +from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention from vllm_omni.diffusion.attention.parallel.ulysses import UlyssesParallelAttention from vllm_omni.diffusion.data import get_current_omni_diffusion_config -from vllm_omni.diffusion.distributed.parallel_state import get_sp_group +from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size, get_sp_group def build_parallel_attention_strategy( @@ -28,13 +29,19 @@ def build_parallel_attention_strategy( except Exception: return NoParallelAttention() - # Current implementation supports Ulysses sequence-parallel attention. - # We intentionally do NOT infer ring attention from ring_degree here yet. - if getattr(p, "ulysses_degree", 1) > 1: - try: - sp_group = get_sp_group() - except Exception: + ulysses_degree = getattr(p, "ulysses_degree", 1) + ring_degree = getattr(p, "ring_degree", 1) + + try: + sp_group = get_sp_group() + # Ensure SP group is initialized and world size > 1 + if get_sequence_parallel_world_size() <= 1: return NoParallelAttention() + except Exception: + return NoParallelAttention() + + # Ulysses (or Hybrid Ulysses+Ring) + if ulysses_degree > 1: return UlyssesParallelAttention( sp_group=sp_group, scatter_idx=scatter_idx, @@ -42,4 +49,10 @@ def build_parallel_attention_strategy( use_sync=use_sync, ) + # Pure Ring Attention + if ring_degree > 1: + return RingParallelAttention( + sp_group=sp_group, + ) + return NoParallelAttention() diff --git a/vllm_omni/diffusion/attention/parallel/ring.py b/vllm_omni/diffusion/attention/parallel/ring.py new file mode 100644 index 00000000000..784e28bd519 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/ring.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from vllm.logger import init_logger + +# import torch.distributed as dist # Not used directly here, but good practice if needed +from vllm_omni.diffusion.attention.backends.ring.ring_globals import HAS_FLASH_ATTN +from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType +from vllm_omni.diffusion.attention.parallel.base import ( + ParallelAttentionContext, + # ParallelAttentionStrategy, # Not used in type hint below currently +) + +# from vllm_omni.diffusion.attention.backends.ring_selector import AttnType # Already imported above +from vllm_omni.diffusion.data import get_current_omni_diffusion_config +from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator + +if TYPE_CHECKING: + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + + +@dataclass(frozen=True, slots=True) +class _RingCtx(ParallelAttentionContext): + """Per-forward context for Ring sequence-parallel attention.""" + + # Ring attention typically doesn't need complex context for post-processing + # as the output is already correctly sharded along sequence dimension. + pass + + +class RingParallelAttention: + """Ring sequence-parallel strategy. + + This strategy prepares inputs for Ring Attention. + Key responsibilities: + - Concatenate joint_query (Text) to query (Image) if present. + - Keep joint_key/value separate in metadata for the Ring kernel to handle as static prefix. + """ + + def __init__( + self, + sp_group: SequenceParallelGroupCoordinator, + attn_backend_pref: str | None = None, + ) -> None: + self._sp_group = sp_group + self.attn_backend_pref = attn_backend_pref + + @property + def enabled(self) -> bool: + return True + + @property + def name(self) -> str: + return "ring" + + def pre_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + ): + joint_tensor_query = None + joint_strategy = "front" + + if attn_metadata is not None: + joint_tensor_query = attn_metadata.joint_query + joint_strategy = attn_metadata.joint_strategy + + if joint_tensor_query is not None: + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError(f"joint_strategy: {joint_strategy} not supported.") + + if joint_strategy == "front": + query = torch.cat([joint_tensor_query, query], dim=1) + else: + query = torch.cat([query, joint_tensor_query], dim=1) + + # Note: We do NOT concatenate joint_key/value here. + # They are preserved in attn_metadata and will be passed + # explicitly to ring_flash_attn_func. + + ctx = _RingCtx(name=self.name) + return query, key, value, attn_metadata, ctx + + def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: + # Ring attention output is already sharded correctly along sequence dimension. + return attn_output + + def run_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + softmax_scale: float | None = None, + causal: bool = False, + ) -> torch.Tensor: + """Run the actual Ring Attention kernel.""" + if softmax_scale is None: + softmax_scale = query.shape[-1] ** -0.5 + + backend_pref = self.attn_backend_pref + if backend_pref is None: + try: + config = get_current_omni_diffusion_config() + # config might not have attention_backend attribute if not updated + backend_pref = getattr(config, "attention_backend", None) + except Exception: + backend_pref = None + + # Fallback for FP32 or if Flash Attention is not available + if query.dtype == torch.float32 or not HAS_FLASH_ATTN: + if not HAS_FLASH_ATTN and backend_pref != "sdpa": + logger = init_logger(__name__) + logger.warning("Flash Attention is not available! Force enabling SDPA.") + backend_pref = "sdpa" + + # Extract joint tensors + joint_key, joint_value = None, None + joint_strategy = "front" + if attn_metadata is not None: + joint_key = attn_metadata.joint_key + joint_value = attn_metadata.joint_value + if attn_metadata.joint_strategy is not None: + joint_strategy = attn_metadata.joint_strategy + + if backend_pref == "sdpa" or backend_pref == "torch": + from vllm_omni.diffusion.attention.backends.ring_pytorch_attn import ring_pytorch_attn_func + + return ring_pytorch_attn_func( + query, + key, + value, + softmax_scale=softmax_scale, + causal=causal, + group=self._sp_group.ring_group, + op_type="efficient", + joint_tensor_key=joint_key, + joint_tensor_value=joint_value, + joint_strategy=joint_strategy, + ) + + from vllm_omni.diffusion.attention.backends.ring_flash_attn import ring_flash_attn_func + + return ring_flash_attn_func( + query, + key, + value, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + # return_attn_probs=False, # Removed as it might not be supported in signature + group=self._sp_group.ring_group, + attn_type=AttnType.FA, + joint_tensor_key=joint_key, + joint_tensor_value=joint_value, + joint_strategy=joint_strategy, + ) diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py index acf4684867b..1a368d5d788 100644 --- a/vllm_omni/diffusion/attention/parallel/ulysses.py +++ b/vllm_omni/diffusion/attention/parallel/ulysses.py @@ -22,6 +22,8 @@ class _UlyssesCtx(ParallelAttentionContext): scatter_idx: int gather_idx: int use_sync: bool + joint_len: int = 0 + joint_strategy: str = "front" class UlyssesParallelAttention: @@ -29,8 +31,8 @@ class UlyssesParallelAttention: This preserves the semantics previously implemented in `Attention._forward_ulysses`: - - If `AttentionMetadata.joint_*` is provided, joint_query is concatenated to - query *before* all-to-all; joint_key/value are concatenated *after* all-to-all. + - If `AttentionMetadata.joint_*` is provided, joint_query/key/value are + concatenated *after* all-to-all. - joint_key/value are assumed to be replicated across SP ranks and are sliced by ulysses head rank before concatenation. """ @@ -64,7 +66,8 @@ def pre_attention( attn_metadata: AttentionMetadata | None, ): joint_tensor_query = joint_tensor_key = joint_tensor_value = None - joint_strategy = None + joint_strategy = "front" + joint_len = 0 if attn_metadata is not None: joint_tensor_query = attn_metadata.joint_query @@ -80,10 +83,22 @@ def pre_attention( f"joint_strategy: {joint_strategy} not supported." f" supported joint strategy: {supported_joint_strategy}" ) - if joint_strategy == "rear": - query = torch.cat([query, joint_tensor_query], dim=1) - else: - query = torch.cat([joint_tensor_query, query], dim=1) + + # Slice joint_query for this Ulysses rank + # joint_query is (B, S, H, D). We split H (dim 2). + ulysses_world_size = self._sp_group.ulysses_world_size + ulysses_rank = self._sp_group.ulysses_rank + attn_heads_per_ulysses_rank = joint_tensor_query.shape[-2] // ulysses_world_size + + # Note: We use the same heads for Q/K/V + joint_tensor_query = joint_tensor_query[ + ..., + attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1), + :, + ] + + joint_len = joint_tensor_query.shape[1] + is_joint = True elif joint_tensor_query is None and joint_tensor_key is None and joint_tensor_value is None: pass @@ -92,27 +107,46 @@ def pre_attention( if is_joint: # Slice joint key/value heads for this ulysses rank. - ulysses_world_size = self._sp_group.ulysses_world_size - ulysses_rank = self._sp_group.ulysses_rank - attn_heads_per_ulysses_rank = joint_tensor_key.shape[-2] // ulysses_world_size + # Using same slicing logic as query + attn_heads_per_ulysses_rank_kv = joint_tensor_key.shape[-2] // ulysses_world_size + joint_tensor_key = joint_tensor_key[ ..., - attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1), + attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1), :, ] joint_tensor_value = joint_tensor_value[ ..., - attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1), + attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1), :, ] + # Update metadata with sliced tensors so Ring attention can use them if needed + if attn_metadata is not None: + attn_metadata.joint_key = joint_tensor_key + attn_metadata.joint_value = joint_tensor_value + # (bs, seq_len/P, head_cnt, head_size) -> (bs, seq_len, head_cnt/P, head_size) query = SeqAllToAll4D.apply(self._ulysses_pg, query, self._scatter_idx, self._gather_idx, self._use_sync) key = SeqAllToAll4D.apply(self._ulysses_pg, key, self._scatter_idx, self._gather_idx, self._use_sync) value = SeqAllToAll4D.apply(self._ulysses_pg, value, self._scatter_idx, self._gather_idx, self._use_sync) if is_joint: - # Concatenate joint key/value after all-to-all (matches previous implementation). + # Concatenate joint query AFTER AllToAll + # Image query is now (B, S, H/P, D). Joint query is (B, S_txt, H/P, D). + # This is dimensionally consistent. + if joint_strategy == "rear": + query = torch.cat([query, joint_tensor_query], dim=1) + else: + query = torch.cat([joint_tensor_query, query], dim=1) + + # Check if Ring Attention is also active (Hybrid mode) + # If Ring is active, we should NOT concatenate joint_key/value to k/v here. + # Instead, they should remain in attn_metadata and be passed to the Ring kernel. + use_ring = self._sp_group.ring_world_size > 1 + + if is_joint and not use_ring: + # Concatenate joint key/value after all-to-all ONLY for pure Ulysses (Local Attention). if joint_strategy == "front": key = torch.cat([joint_tensor_key, key], dim=1) value = torch.cat([joint_tensor_value, value], dim=1) @@ -126,10 +160,47 @@ def pre_attention( scatter_idx=self._scatter_idx, gather_idx=self._gather_idx, use_sync=self._use_sync, + joint_len=joint_len, + joint_strategy=joint_strategy, ) return query, key, value, attn_metadata, ctx def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}" - # Reverse: (bs, seq_len, head_cnt/P, head_size) -> (bs, seq_len/P, head_cnt, head_size) + + # If we have joint tensors (Text), they were Head-Sliced. + # The main sequence (Image) was Sequence-Sliced. + # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front'). + + if ctx.joint_len > 0: + joint_len = ctx.joint_len + + if ctx.joint_strategy == "front": + output_joint = attn_output[:, :joint_len] + output_img = attn_output[:, joint_len:] + else: + output_img = attn_output[:, :-joint_len] + output_joint = attn_output[:, -joint_len:] + + # 1. Process Image part: Standard Ulysses Reverse (AllToAll) + # (bs, seq_len, head_cnt/P, head_size) -> (bs, seq_len/P, head_cnt, head_size) + # SeqAllToAll4D handles: Scatter gather_idx, Gather scatter_idx. + # Forward: Scatter 2 (H), Gather 1 (S). + # Reverse: Scatter 1 (S), Gather 2 (H). + output_img = SeqAllToAll4D.apply(ctx.ulysses_pg, output_img, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync) + + # 2. Process Joint part: AllGather on Heads + # Input: (B, JointLen, H/P, D). Output: (B, JointLen, H, D). + # AllGather along dim 2. + gathered_joint = [torch.zeros_like(output_joint) for _ in range(dist.get_world_size(ctx.ulysses_pg))] + dist.all_gather(gathered_joint, output_joint, group=ctx.ulysses_pg) + output_joint = torch.cat(gathered_joint, dim=2) + + # 3. Recombine + if ctx.joint_strategy == "front": + return torch.cat([output_joint, output_img], dim=1) + else: + return torch.cat([output_img, output_joint], dim=1) + + # Standard Ulysses Reverse return SeqAllToAll4D.apply(ctx.ulysses_pg, attn_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index bef96a46416..6f54e9bd19b 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -245,7 +245,7 @@ class OmniDiffusionConfig: tf_model_config: TransformerConfig = field(default_factory=TransformerConfig) # Attention - # attention_backend: str = None + attention_backend: str | None = None # Running mode # mode: ExecutionMode = ExecutionMode.INFERENCE diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py index b5f7aa32a4f..ada03f4a1ee 100644 --- a/vllm_omni/diffusion/distributed/comm.py +++ b/vllm_omni/diffusion/distributed/comm.py @@ -8,6 +8,8 @@ import torch.distributed as dist from torch import Tensor +__all__ = ["all_to_all_4D", "all_to_all_5D", "SeqAllToAll4D", "SeqAllToAll5D", "RingComm"] + def all_to_all_4D( input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False @@ -219,3 +221,54 @@ def forward( ctx.use_sync = use_sync return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + +class RingComm: + """Ring communication utility for Ring Attention P2P communication.""" + + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: torch.Tensor | None = None) -> torch.Tensor: + # Ensure to_send is contiguous for P2P + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + if recv_tensor is None: + # Create a contiguous buffer for receiving + res = torch.empty_like(to_send, memory_format=torch.contiguous_format) + # print(f"send_recv: empty_like {to_send.shape}") + else: + res = recv_tensor + if not res.is_contiguous(): + res = res.contiguous() + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 9235b1a76e0..4c585953b1d 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -146,6 +146,14 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Ulysses Sequence Parallelism degree for diffusion models. " "Equivalent to setting DiffusionParallelConfig.ulysses_degree.", ) + serve_parser.add_argument( + "--ring", + dest="ring_degree", + type=int, + default=None, + help="Ring Sequence Parallelism degree for diffusion models. " + "Equivalent to setting DiffusionParallelConfig.ring_degree.", + ) # Cache optimization parameters serve_parser.add_argument(