From 3f08925bb7c54af8ff57bec67b47900ac8df8ea5 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 21:19:20 +0800 Subject: [PATCH 1/5] Fix diffusion parallel override normalization in stage config Co-authored-by: zzhuoxin1508 <234137171+zzhuoxin1508@users.noreply.github.com> Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_config_factory.py | 141 +++++++++++++++++++++++++++++++ vllm_omni/config/stage_config.py | 35 +++++++- 2 files changed, 172 insertions(+), 4 deletions(-) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 7b620ea6e80..df39d0bcc98 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -165,6 +165,147 @@ def test_to_omegaconf_omits_none_deploy_overrides_for_engine_args(self): for name in deploy_override_field_names() - {"devices"}: assert name not in engine_args + def test_to_omegaconf_diffusion_parallel_overrides_replace_nested_values(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + yaml_engine_args={ + "parallel_config": { + "pipeline_parallel_size": 1, + "data_parallel_size": 1, + "tensor_parallel_size": 4, + "enable_expert_parallel": False, + "ulysses_degree": 1, + "ring_degree": 1, + "ulysses_mode": "strict", + "sequence_parallel_size": 1, + "cfg_parallel_size": 1, + "vae_patch_parallel_size": 1, + "use_hsdp": False, + "hsdp_shard_size": -1, + "hsdp_replicate_size": 1, + } + }, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + "enable_expert_parallel": True, + "ulysses_degree": 2, + "ring_degree": 4, + "ulysses_mode": "advanced_uaa", + "sequence_parallel_size": 8, + "cfg_parallel_size": 2, + "vae_patch_parallel_size": 2, + "use_hsdp": True, + "hsdp_shard_size": 8, + "hsdp_replicate_size": 2, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2 + assert omega_config.engine_args.parallel_config.data_parallel_size == 3 + assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8 + assert omega_config.engine_args.parallel_config.enable_expert_parallel is True + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa" + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8 + assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2 + assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2 + assert omega_config.engine_args.parallel_config.use_hsdp is True + assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8 + assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2 + assert "pipeline_parallel_size" not in omega_config.engine_args + assert "data_parallel_size" not in omega_config.engine_args + assert "tensor_parallel_size" not in omega_config.engine_args + assert "enable_expert_parallel" not in omega_config.engine_args + assert "ulysses_degree" not in omega_config.engine_args + assert "ring_degree" not in omega_config.engine_args + assert "ulysses_mode" not in omega_config.engine_args + assert "sequence_parallel_size" not in omega_config.engine_args + assert "cfg_parallel_size" not in omega_config.engine_args + assert "vae_patch_parallel_size" not in omega_config.engine_args + assert "use_hsdp" not in omega_config.engine_args + assert "hsdp_shard_size" not in omega_config.engine_args + assert "hsdp_replicate_size" not in omega_config.engine_args + + def test_to_omegaconf_diffusion_parallel_overrides_create_parallel_config(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + "enable_expert_parallel": True, + "ulysses_degree": 2, + "ring_degree": 4, + "ulysses_mode": "advanced_uaa", + "sequence_parallel_size": 8, + "cfg_parallel_size": 2, + "vae_patch_parallel_size": 2, + "use_hsdp": True, + "hsdp_shard_size": 8, + "hsdp_replicate_size": 2, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2 + assert omega_config.engine_args.parallel_config.data_parallel_size == 3 + assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8 + assert omega_config.engine_args.parallel_config.enable_expert_parallel is True + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa" + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8 + assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2 + assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2 + assert omega_config.engine_args.parallel_config.use_hsdp is True + assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8 + assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2 + assert "pipeline_parallel_size" not in omega_config.engine_args + assert "data_parallel_size" not in omega_config.engine_args + assert "tensor_parallel_size" not in omega_config.engine_args + assert "enable_expert_parallel" not in omega_config.engine_args + assert "ulysses_degree" not in omega_config.engine_args + assert "ring_degree" not in omega_config.engine_args + assert "ulysses_mode" not in omega_config.engine_args + assert "sequence_parallel_size" not in omega_config.engine_args + assert "cfg_parallel_size" not in omega_config.engine_args + assert "vae_patch_parallel_size" not in omega_config.engine_args + assert "use_hsdp" not in omega_config.engine_args + assert "hsdp_shard_size" not in omega_config.engine_args + assert "hsdp_replicate_size" not in omega_config.engine_args + + def test_to_omegaconf_llm_parallel_overrides_remain_top_level(self): + config = StageConfig( + stage_id=0, + model_stage="thinker", + stage_type=StageType.LLM, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.pipeline_parallel_size == 2 + assert omega_config.engine_args.data_parallel_size == 3 + assert omega_config.engine_args.tensor_parallel_size == 8 + assert "pipeline_parallel_size" in omega_config.engine_args + assert "data_parallel_size" in omega_config.engine_args + assert "tensor_parallel_size" in omega_config.engine_args + assert "parallel_config" not in omega_config.engine_args + class TestModelPipeline: """Tests for ModelPipeline class.""" diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 17c70302312..960b3af03dd 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -121,6 +121,29 @@ def strip_parent_engine_args( return result, sorted(overridden) +def _apply_diffusion_parallel_runtime_overrides( + engine_args: dict[str, Any], + runtime_overrides: dict[str, Any], +) -> None: + """Move diffusion parallel overrides into nested ``parallel_config``.""" + from vllm_omni.diffusion.data import DiffusionParallelConfig + + parallel_fields = frozenset(f.name for f in fields(DiffusionParallelConfig)) + parallel_config = engine_args.get("parallel_config") + parallel_config_dict = to_dict(parallel_config) if parallel_config is not None else None + + for key in list(runtime_overrides.keys()): + value = runtime_overrides.get(key) + if value is None or key not in parallel_fields: + continue + if parallel_config_dict is None: + parallel_config_dict = {} + parallel_config_dict[key] = runtime_overrides.pop(key) + + if parallel_config_dict is not None: + engine_args["parallel_config"] = parallel_config_dict + + class StageType(str, Enum): """Type of processing stage in the Omni pipeline.""" @@ -912,6 +935,7 @@ def to_omegaconf(self) -> Any: """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly.""" # Start with YAML engine_args defaults engine_args: dict[str, Any] = dict(self.yaml_engine_args) + runtime_overrides = dict(self.runtime_overrides) # Overlay topology-level fields engine_args["model_stage"] = self.model_stage @@ -922,20 +946,23 @@ def to_omegaconf(self) -> Any: if self.hf_config_name: engine_args["hf_config_name"] = self.hf_config_name + if StageType(self.stage_type) == StageType.DIFFUSION: + _apply_diffusion_parallel_runtime_overrides(engine_args, runtime_overrides) + # CLI overrides take precedence over YAML defaults - for key, value in self.runtime_overrides.items(): + for key, value in runtime_overrides.items(): if value is not None and key not in ("devices", "max_batch_size"): engine_args[key] = value # Build runtime config from YAML defaults + CLI overrides runtime: dict[str, Any] = dict(self.yaml_runtime) runtime.setdefault("process", True) - if self.runtime_overrides.get("devices") is not None: - runtime["devices"] = self.runtime_overrides["devices"] + if runtime_overrides.get("devices") is not None: + runtime["devices"] = runtime_overrides["devices"] # Legacy compat: migrate runtime.max_batch_size → engine_args.max_num_seqs legacy_mbs = runtime.pop("max_batch_size", None) - cli_mbs = self.runtime_overrides.get("max_batch_size") + cli_mbs = runtime_overrides.get("max_batch_size") if legacy_mbs is not None or cli_mbs is not None: warnings.warn( "runtime.max_batch_size is deprecated and will be removed in a " From 767ab8bcbeb3f731956c457ada4c749c1b42a169 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 21:49:33 +0800 Subject: [PATCH 2/5] Fix diffusion deploy override nullification Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_arg_utils.py | 9 ++++++--- tests/test_config_factory.py | 34 ++++++++++++++++++++++++-------- vllm_omni/config/stage_config.py | 12 +++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py index 2fd5cf302e0..8fc445a604d 100644 --- a/tests/test_arg_utils.py +++ b/tests/test_arg_utils.py @@ -393,16 +393,19 @@ def test_nullify_stage_engine_defaults_resets_inherited_defaults(): def test_non_override_flags_keep_real_defaults_after_nullify(): import argparse + from vllm_omni.config.stage_config import deploy_override_field_names from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults parser = argparse.ArgumentParser() - parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.") + parser.add_argument("--batch-timeout", type=int, default=10, help="Batch timeout.") parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.") nullify_stage_engine_defaults(parser) - hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size") + assert "batch_timeout" not in deploy_override_field_names() + + batch_timeout = next(a for a in parser._actions if a.dest == "batch_timeout") max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs") - assert hsdp.default == -1 + assert batch_timeout.default == 10 assert max_num_seqs.default is None diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index df39d0bcc98..84da981852c 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -971,17 +971,16 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): expected_fields = { "async_chunk", + + # StageDeployConfig: stage placement and runtime fields. + "devices", + + # StageDeployConfig: vLLM EngineArgs fields. "async_scheduling", "compilation_config", "config_format", - "data_parallel_size", - "devices", "disable_hybrid_kv_cache_manager", - "distributed_executor_backend", - "dtype", - "enable_chunked_prefill", "enable_flashinfer_autotune", - "enable_prefix_caching", "enforce_eager", "gpu_memory_utilization", "load_format", @@ -989,13 +988,32 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "max_num_batched_tokens", "max_num_seqs", "mm_processor_cache_gb", - "pipeline_parallel_size", "profiler_config", - "quantization", "skip_mm_profiling", "subtalker_sampling_params", "tensor_parallel_size", "tokenizer_mode", + + # StageDeployConfig: diffusion parallel_config deploy override fields. + "cfg_parallel_size", + "enable_expert_parallel", + "hsdp_replicate_size", + "hsdp_shard_size", + "ring_degree", + "sequence_parallel_size", + "ulysses_degree", + "ulysses_mode", + "use_hsdp", + "vae_patch_parallel_size", + + # DeployConfig: pipeline-wide engine settings. + "data_parallel_size", + "distributed_executor_backend", + "dtype", + "enable_chunked_prefill", + "enable_prefix_caching", + "pipeline_parallel_size", + "quantization", "trust_remote_code", } diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 960b3af03dd..9882846f307 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -452,6 +452,18 @@ class StageDeployConfig: disable_hybrid_kv_cache_manager: bool | None = None mm_processor_cache_gb: float | None = None + # Diffusion parallel_config deploy override fields. + enable_expert_parallel: bool | None = None + ulysses_degree: int | None = None + ulysses_mode: str | None = None + ring_degree: int | None = None + sequence_parallel_size: int | None = None + cfg_parallel_size: int | None = None + vae_patch_parallel_size: int | None = None + use_hsdp: bool | None = None + hsdp_shard_size: int | None = None + hsdp_replicate_size: int | None = None + # Compilation, profiling, tokenizer/config parsing, and model loading. compilation_config: dict[str, Any] | None = None profiler_config: dict[str, Any] | None = None From f43e85f22efbd943ec87efe3601adb03d067f5ff Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 21:53:04 +0800 Subject: [PATCH 3/5] Fix deploy override field test formatting Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_config_factory.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 84da981852c..3ea37dc9ab4 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -971,10 +971,8 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): expected_fields = { "async_chunk", - # StageDeployConfig: stage placement and runtime fields. "devices", - # StageDeployConfig: vLLM EngineArgs fields. "async_scheduling", "compilation_config", @@ -993,7 +991,6 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "subtalker_sampling_params", "tensor_parallel_size", "tokenizer_mode", - # StageDeployConfig: diffusion parallel_config deploy override fields. "cfg_parallel_size", "enable_expert_parallel", @@ -1005,7 +1002,6 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "ulysses_mode", "use_hsdp", "vae_patch_parallel_size", - # DeployConfig: pipeline-wide engine settings. "data_parallel_size", "distributed_executor_backend", From 34ff3a6337b5717697b49a0a85837b41ef9237b4 Mon Sep 17 00:00:00 2001 From: dengyunyang <584797741@qq.com> Date: Sun, 10 May 2026 22:08:24 +0800 Subject: [PATCH 4/5] [Feature] hunyuanimage support flash attn (#2981) Signed-off-by: dengyunyang <584797741@qq.com> --- .../attention/test_piecewise_attn.py | 126 ++++++++++++++++++ .../diffusion/attention/backends/abstract.py | 4 + .../attention/backends/flash_attn.py | 26 ++++ .../backends/utils/piecewise_attn.py | 98 ++++++++++++++ .../hunyuan_image3_transformer.py | 13 ++ .../hunyuan_image3/pipeline_hunyuan_image3.py | 20 ++- .../omni_connectors/kv_transfer_manager.py | 5 + 7 files changed, 286 insertions(+), 6 deletions(-) create mode 100644 tests/diffusion/attention/test_piecewise_attn.py create mode 100644 vllm_omni/diffusion/attention/backends/utils/piecewise_attn.py diff --git a/tests/diffusion/attention/test_piecewise_attn.py b/tests/diffusion/attention/test_piecewise_attn.py new file mode 100644 index 00000000000..6560876234d --- /dev/null +++ b/tests/diffusion/attention/test_piecewise_attn.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""End-to-end test for ``piecewise_attn`` (CPU). + +Verify that running attention in segments (causal outside full-attn spans, +bidirectional inside full-attn spans) matches running a single full SDPA call +with the equivalent 2D attention mask. + +Covers: + * batch size = 1 and batch size > 1 (homogeneous CFG-like batch) + * query length == key length (full prefill) + * query length < key length (decode-like tail slice) + * various full-attn-span layouts (none / start / middle / end / multi) +""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn.functional as F + +from vllm_omni.diffusion.attention.backends.utils.piecewise_attn import ( + piecewise_attn, +) + +DEVICE = torch.device("cpu") + + +def _sdpa_attn_func(q, k, v, causal, softmax_scale): + q_ = q.transpose(1, 2) + k_ = k.transpose(1, 2) + v_ = v.transpose(1, 2) + attn_mask = None + if causal: + Sq, Sk = q_.shape[-2], k_.shape[-2] + i = torch.arange(Sq, device=q.device).unsqueeze(1) + j = torch.arange(Sk, device=q.device).unsqueeze(0) + attn_mask = j <= (i + (Sk - Sq)) + out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_mask, scale=softmax_scale) + return out.transpose(1, 2).contiguous() + + +def _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale): + """Build a full 2D mask with global spans and compute reference output.""" + Sk = key.shape[1] + mask = torch.tril(torch.ones(Sk, Sk, dtype=torch.bool, device=key.device)) + for a, e in global_spans: + mask[a:e, :e] = True + mask_q = mask[q_start:q_end, :] + q_ = query.transpose(1, 2) + k_ = key.transpose(1, 2) + v_ = value.transpose(1, 2) + out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=mask_q, scale=softmax_scale) + return out.transpose(1, 2).contiguous() + + +SPAN_CASES = [ + pytest.param([], id="no-spans"), + pytest.param([(0, 10)], id="span-at-start"), + pytest.param([(10, 30), (54, 64)], id="multi-spans"), +] + +Q_RANGE_CASES = [ + pytest.param((0, 64), id="q_eq_k"), # Sq == Sk (prefill) + pytest.param((53, 64), id="q_lt_k"), # Sq < Sk (decode-like) +] + +BATCH_CASES = [ + pytest.param(1, id="B1"), + pytest.param(2, id="B2"), +] + + +@pytest.mark.parametrize("global_spans", SPAN_CASES) +@pytest.mark.parametrize("q_range", Q_RANGE_CASES) +@pytest.mark.parametrize("batch_size", BATCH_CASES) +def test_piecewise_matches_full(global_spans, q_range, batch_size): + torch.manual_seed(0) + H, D, Sk = 2, 16, 64 + q_start, q_end = q_range + Sq = q_end - q_start + + key = torch.randn(batch_size, Sk, H, D, device=DEVICE) + value = torch.randn(batch_size, Sk, H, D, device=DEVICE) + query = torch.randn(batch_size, Sq, H, D, device=DEVICE) + + full_attn_spans = [list(global_spans) for _ in range(batch_size)] + softmax_scale = 1.0 / (D**0.5) + + got = piecewise_attn( + query, + key, + value, + full_attn_spans=full_attn_spans, + softmax_scale=softmax_scale, + attn_func=_sdpa_attn_func, + ) + expected = _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale) + torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5) + + +def test_piecewise_span_fully_before_qstart(): + """Spans entirely before query region produce pure causal attention.""" + torch.manual_seed(0) + B, H, D, Sk = 1, 2, 16, 30 + q_start, q_end = 15, 30 + Sq = q_end - q_start + + key = torch.randn(B, Sk, H, D, device=DEVICE) + value = torch.randn(B, Sk, H, D, device=DEVICE) + query = torch.randn(B, Sq, H, D, device=DEVICE) + + global_spans = [(5, 10)] + full_attn_spans = [list(global_spans) for _ in range(B)] + softmax_scale = 1.0 / (D**0.5) + + got = piecewise_attn( + query, + key, + value, + full_attn_spans=full_attn_spans, + softmax_scale=softmax_scale, + attn_func=_sdpa_attn_func, + ) + expected = _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale) + torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5) diff --git a/vllm_omni/diffusion/attention/backends/abstract.py b/vllm_omni/diffusion/attention/backends/abstract.py index 6dd785e6678..f702dc65028 100644 --- a/vllm_omni/diffusion/attention/backends/abstract.py +++ b/vllm_omni/diffusion/attention/backends/abstract.py @@ -68,6 +68,10 @@ class AttentionMetadata: # Opaque backend-specific per-forward parameters (e.g. block masks, KV indices). # Backends MUST silently ignore unknown keys. + # Piecewise attention metadata (mixed causal/full masks). + # full_attn_spans: per-sample [start, end) spans in global coordinates using full attention. + full_attn_spans: list[list[tuple[int, int]]] | None = None + T = TypeVar("T", bound=AttentionMetadata) diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py index 3413984460f..a612546942a 100644 --- a/vllm_omni/diffusion/attention/backends/flash_attn.py +++ b/vllm_omni/diffusion/attention/backends/flash_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import partial import torch from vllm.logger import init_logger @@ -10,6 +11,9 @@ AttentionImpl, AttentionMetadata, ) +from vllm_omni.diffusion.attention.backends.utils.piecewise_attn import ( + piecewise_attn, +) logger = init_logger(__name__) @@ -59,6 +63,10 @@ def _unwrap_flash_output(out: torch.Tensor | tuple[torch.Tensor, ...]) -> torch. # FA3 may return (out, lse), FA2 returns out return out[0] if isinstance(out, tuple) else out + @staticmethod + def _flash_wrapper(q, k, v, *, attn_func, **kwargs): + return FlashAttentionImpl._unwrap_flash_output(attn_func(q, k, v, **kwargs)) + def _forward_varlen_masked( self, query: torch.Tensor, @@ -156,6 +164,24 @@ def forward_cuda( ) attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None + full_attn_spans = attn_metadata.full_attn_spans if attn_metadata is not None else None + + # Try piecewise attention + if full_attn_spans is not None: + logger.debug("Using piecewise Flash Attention for mixed causal/full mask") + attn_func = partial( + FlashAttentionImpl._flash_wrapper, + attn_func=flash_attn_func, + ) + + return piecewise_attn( + query, + key, + value, + full_attn_spans, + self.softmax_scale, + attn_func, + ) if attention_mask is not None and torch.any(~attention_mask): return self._forward_varlen_masked( diff --git a/vllm_omni/diffusion/attention/backends/utils/piecewise_attn.py b/vllm_omni/diffusion/attention/backends/utils/piecewise_attn.py new file mode 100644 index 00000000000..fbbd3c8e005 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/utils/piecewise_attn.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Piecewise attention for mixed causal / full (bidirectional) masks. + +Dispatches each segment as a separate attention call whose causal flag +follows FlashAttention's bottom-right convention (``K[:e]`` is attended by +``Q[s:e]``, with causal alignment anchored at the bottom-right corner). + +Per segment: + - causal segment ``[s, e)``: ``attn(Q[:, s:e], K[:, :e], V[:, :e], causal=True)`` + - full-attn span ``[a, e)``: ``attn(Q[:, a:e], K[:, :e], V[:, :e], causal=False)`` +""" + +from __future__ import annotations + +from typing import Literal, NamedTuple + + +class Segment(NamedTuple): + start: int + end: int + mode: Literal["causal", "full"] + + +def build_segments(full_attn_spans, query_offset, query_len): + """ + full_attn_spans: list of (start, end) half-open spans in global coordinates + query_offset: starting position of query in the global sequence + query_len: length of the query + + return: + List[Segment] in global coordinates, clipped to [query_offset, query_offset + query_len) + """ + q_start = query_offset + q_end = query_offset + query_len + + segs: list[Segment] = [] + cur = q_start + + for a, e in full_attn_spans: + # clip span to query range + a_clipped = max(a, q_start) + e_clipped = min(e, q_end) + if a_clipped >= e_clipped: + continue + + if cur < a_clipped: + segs.append(Segment(cur, a_clipped, "causal")) + segs.append(Segment(a_clipped, e_clipped, "full")) + cur = e_clipped + + if cur < q_end: + segs.append(Segment(cur, q_end, "causal")) + + return segs + + +def _check_homogeneous( + full_attn_spans: list[list[tuple[int, int]]], +) -> None: + """Assert all samples share identical spans.""" + if len(full_attn_spans) > 1: + ref = full_attn_spans[0] + for i, s in enumerate(full_attn_spans[1:], 1): + if s != ref: + raise ValueError( + f"piecewise_attn requires homogeneous batch: sample 0 spans {ref} != sample {i} spans {s}" + ) + + +def piecewise_attn( + query, # (B, Sq, H, D) + key, + value, + full_attn_spans: list[list[tuple[int, int]]], + softmax_scale: float, + attn_func, +): + B, Sq, H, D = query.shape + _check_homogeneous(full_attn_spans) + + query_offset = key.shape[1] - Sq + spans = full_attn_spans[0] + out = query.new_zeros(B, Sq, H, D) + + for s, e, mode in build_segments(spans, query_offset, Sq): + q_s = s - query_offset + q_e = e - query_offset + out_seg = attn_func( + query[:, q_s:q_e], + key[:, :e], + value[:, :e], + causal=(mode == "causal"), + softmax_scale=softmax_scale, + ) + out[:, q_s:q_e] = out_seg + return out diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py index 06cc6cfdeb3..0125b319a5d 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py @@ -1054,9 +1054,12 @@ def __call__( attention_mask = attention_mask.contiguous() + full_attn_spans = kwargs.get("full_attn_spans", None) + if self.sp_size <= 1: attn_metadata = AttentionMetadata( attn_mask=attention_mask, + full_attn_spans=full_attn_spans, ) else: attn_metadata = AttentionMetadata( @@ -1065,6 +1068,7 @@ def __call__( joint_value=joint_text_value, joint_strategy="front", attn_mask=attention_mask, + full_attn_spans=full_attn_spans, ) attn_output = self.attn(query, key, value, attn_metadata) attn_output = attn_output.reshape(bs * q_len, head_num_per_rank, head_dim) @@ -2246,6 +2250,7 @@ def forward( num_image_tokens: int | None = None, gen_timestep_scatter_index: torch.Tensor | None = None, uncond_cfg_prefill: bool = False, + full_attn_spans: list[list[tuple[int, int]]] | None = None, ) -> tuple | BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2354,6 +2359,7 @@ def forward( shard_image_size=shard_image_size, shard_padding_size=shard_padding_size, uncond_cfg_prefill=uncond_cfg_prefill, + full_attn_spans=full_attn_spans, ) hidden_states = layer_outputs[0] @@ -2594,6 +2600,10 @@ def _split_model_kwargs_for_cfg_parallel(model_kwargs: dict[str, Any], batch_siz if key in model_kwargs and model_kwargs[key] is not None: model_kwargs[key] = model_kwargs[key][s] + # List[List[...]] per-sample metadata indexed along the CFG batch dim + if isinstance(model_kwargs.get("full_attn_spans"), list): + model_kwargs["full_attn_spans"] = model_kwargs["full_attn_spans"][s.start : s.stop] + # custom_pos_emb: tuple of (cos, sin) if "custom_pos_emb" in model_kwargs and model_kwargs["custom_pos_emb"] is not None: cos, sin = model_kwargs["custom_pos_emb"] @@ -2724,6 +2734,9 @@ def _build_negative_cfg_prefill_inputs( query_lens=[prefill_query_len], seq_lens=[prefill_seq_len], num_image_tokens=0, + full_attn_spans=model_kwargs["full_attn_spans"][batch_slice] + if model_kwargs.get("full_attn_spans") + else None, ) # ========================================================== diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 3a4d0d2c64d..7a8be07456d 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -import os from collections.abc import Iterable from typing import Any @@ -340,11 +339,6 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None: ) ] quant_config = od_config.quantization_config - os.environ["DIFFUSION_ATTENTION_BACKEND"] = "TORCH_SDPA" - logger.info( - "Setting attention backend to TORCH_SDPA. " - "HunyuanImage3Pipeline only supports TORCH_SDPA to handle mixed causal and full attention." - ) self.model = HunyuanImage3Model(self.hf_config, quant_config=quant_config) self.transformer = self.model self.vae = AutoencoderKLConv3D.from_config(self.hf_config.vae) @@ -962,10 +956,18 @@ def _prepare_attention_mask_for_generation( tokenizer_output.joint_image_slices[i] + tokenizer_output.gen_image_slices[i] for i in range(bsz) ] attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) + full_attn_spans: list[list[tuple[int, int]]] = [[] for _ in range(bsz)] for i in range(bsz): for j, image_slice in enumerate(batch_image_slices[i]): attention_mask[i, image_slice, image_slice] = True + start = image_slice.start if image_slice.start is not None else 0 + stop = image_slice.stop if image_slice.stop is not None else seq_len + assert start < stop, f"Invalid image slice: {image_slice}" + full_attn_spans[i].append((int(start), int(stop))) + if full_attn_spans[i]: + full_attn_spans[i].sort(key=lambda x: x[0]) attention_mask = attention_mask.unsqueeze(1) + model_kwargs["full_attn_spans"] = full_attn_spans return attention_mask def prepare_inputs_for_generation( @@ -1009,6 +1011,7 @@ def prepare_inputs_for_generation( "query_lens": kwargs.get("query_lens"), "seq_lens": kwargs.get("seq_lens"), "num_image_tokens": kwargs.get("num_image_tokens"), + "full_attn_spans": kwargs.get("full_attn_spans"), } ) return model_inputs @@ -1027,6 +1030,8 @@ def _update_model_kwargs_for_generation( "custom_pos_emb": model_kwargs["custom_pos_emb"], "num_image_tokens": model_kwargs["num_image_tokens"], } + if "full_attn_spans" in model_kwargs: + updated_model_kwargs["full_attn_spans"] = model_kwargs["full_attn_spans"] # update past_key_values keeping its naming used in model code for possible_cache_name in ALL_CACHE_NAMES: @@ -1057,6 +1062,7 @@ def _update_model_kwargs_for_generation( torch.arange(bsz), model_kwargs["gen_timestep_scatter_index"][:, -1] ].unsqueeze(-1) updated_model_kwargs["position_ids"] = torch.cat([timestep_position_ids, position_ids], dim=1) + # attention mask mask_list = [] for attention_mask_i, position_ids_i in zip( @@ -1156,6 +1162,7 @@ def forward_call( seq_lens: list[int] | None = None, num_image_tokens: int | None = None, uncond_cfg_prefill: bool = False, + full_attn_spans: list[list[tuple[int, int]]] | None = None, ) -> tuple | CausalMMOutputWithPast: return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Sanity Check of Inputs @@ -1251,6 +1258,7 @@ def forward_call( num_image_tokens=num_image_tokens, gen_timestep_scatter_index=gen_timestep_scatter_index, uncond_cfg_prefill=uncond_cfg_prefill, + full_attn_spans=full_attn_spans, ) hidden_states = outputs[0] diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index ad008c3971f..d673ad1667c 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -995,6 +995,11 @@ def receive_kv_cache_for_request( logger.info(f"Skip receiving KV cache for {request_id} (need_recv_cache=False)") return None, 0 + # Skip during warmup dummy run — no sender is available. + if request_id == "dummy_req_id": + logger.info("Skip receiving KV cache for dummy warmup request") + return None, 0 + timeout = self.config.recv_timeout start_time = time.time() poll_interval = 0.01 From 5685577e0825b79aafd8830a089bdb9be4f88ac4 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 23:50:36 +0800 Subject: [PATCH 5/5] Handle dict diffusion parallel_config overrides Signed-off-by: xiaohajiayou <923390377@qq.com> --- vllm_omni/config/stage_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 9882846f307..6f3ac97df61 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -130,7 +130,7 @@ def _apply_diffusion_parallel_runtime_overrides( parallel_fields = frozenset(f.name for f in fields(DiffusionParallelConfig)) parallel_config = engine_args.get("parallel_config") - parallel_config_dict = to_dict(parallel_config) if parallel_config is not None else None + parallel_config_dict = dict(parallel_config) if parallel_config is not None else None for key in list(runtime_overrides.keys()): value = runtime_overrides.get(key)