From 11eaecfbdd3d73d6f0ae1d24d6c7083aab2577dc Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 21:19:20 +0800 Subject: [PATCH 1/7] Fix diffusion parallel override normalization in stage config Co-authored-by: zzhuoxin1508 Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_config_factory.py | 141 +++++++++++++++++++++++++++++++ vllm_omni/config/stage_config.py | 39 +++++++-- 2 files changed, 174 insertions(+), 6 deletions(-) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 9ac3d859c1e..6120515075d 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -173,6 +173,147 @@ def test_to_omegaconf_omits_none_deploy_overrides_for_engine_args(self): for name in deploy_override_field_names() - {"devices"}: assert name not in engine_args + def test_to_omegaconf_diffusion_parallel_overrides_replace_nested_values(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + yaml_engine_args={ + "parallel_config": { + "pipeline_parallel_size": 1, + "data_parallel_size": 1, + "tensor_parallel_size": 4, + "enable_expert_parallel": False, + "ulysses_degree": 1, + "ring_degree": 1, + "ulysses_mode": "strict", + "sequence_parallel_size": 1, + "cfg_parallel_size": 1, + "vae_patch_parallel_size": 1, + "use_hsdp": False, + "hsdp_shard_size": -1, + "hsdp_replicate_size": 1, + } + }, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + "enable_expert_parallel": True, + "ulysses_degree": 2, + "ring_degree": 4, + "ulysses_mode": "advanced_uaa", + "sequence_parallel_size": 8, + "cfg_parallel_size": 2, + "vae_patch_parallel_size": 2, + "use_hsdp": True, + "hsdp_shard_size": 8, + "hsdp_replicate_size": 2, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2 + assert omega_config.engine_args.parallel_config.data_parallel_size == 3 + assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8 + assert omega_config.engine_args.parallel_config.enable_expert_parallel is True + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa" + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8 + assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2 + assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2 + assert omega_config.engine_args.parallel_config.use_hsdp is True + assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8 + assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2 + assert "pipeline_parallel_size" not in omega_config.engine_args + assert "data_parallel_size" not in omega_config.engine_args + assert "tensor_parallel_size" not in omega_config.engine_args + assert "enable_expert_parallel" not in omega_config.engine_args + assert "ulysses_degree" not in omega_config.engine_args + assert "ring_degree" not in omega_config.engine_args + assert "ulysses_mode" not in omega_config.engine_args + assert "sequence_parallel_size" not in omega_config.engine_args + assert "cfg_parallel_size" not in omega_config.engine_args + assert "vae_patch_parallel_size" not in omega_config.engine_args + assert "use_hsdp" not in omega_config.engine_args + assert "hsdp_shard_size" not in omega_config.engine_args + assert "hsdp_replicate_size" not in omega_config.engine_args + + def test_to_omegaconf_diffusion_parallel_overrides_create_parallel_config(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + "enable_expert_parallel": True, + "ulysses_degree": 2, + "ring_degree": 4, + "ulysses_mode": "advanced_uaa", + "sequence_parallel_size": 8, + "cfg_parallel_size": 2, + "vae_patch_parallel_size": 2, + "use_hsdp": True, + "hsdp_shard_size": 8, + "hsdp_replicate_size": 2, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2 + assert omega_config.engine_args.parallel_config.data_parallel_size == 3 + assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8 + assert omega_config.engine_args.parallel_config.enable_expert_parallel is True + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa" + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8 + assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2 + assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2 + assert omega_config.engine_args.parallel_config.use_hsdp is True + assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8 + assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2 + assert "pipeline_parallel_size" not in omega_config.engine_args + assert "data_parallel_size" not in omega_config.engine_args + assert "tensor_parallel_size" not in omega_config.engine_args + assert "enable_expert_parallel" not in omega_config.engine_args + assert "ulysses_degree" not in omega_config.engine_args + assert "ring_degree" not in omega_config.engine_args + assert "ulysses_mode" not in omega_config.engine_args + assert "sequence_parallel_size" not in omega_config.engine_args + assert "cfg_parallel_size" not in omega_config.engine_args + assert "vae_patch_parallel_size" not in omega_config.engine_args + assert "use_hsdp" not in omega_config.engine_args + assert "hsdp_shard_size" not in omega_config.engine_args + assert "hsdp_replicate_size" not in omega_config.engine_args + + def test_to_omegaconf_llm_parallel_overrides_remain_top_level(self): + config = StageConfig( + stage_id=0, + model_stage="thinker", + stage_type=StageType.LLM, + runtime_overrides={ + "pipeline_parallel_size": 2, + "data_parallel_size": 3, + "tensor_parallel_size": 8, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.pipeline_parallel_size == 2 + assert omega_config.engine_args.data_parallel_size == 3 + assert omega_config.engine_args.tensor_parallel_size == 8 + assert "pipeline_parallel_size" in omega_config.engine_args + assert "data_parallel_size" in omega_config.engine_args + assert "tensor_parallel_size" in omega_config.engine_args + assert "parallel_config" not in omega_config.engine_args + class TestModelPipeline: """Tests for ModelPipeline class.""" diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 7c17cf7ceb3..9cc1f325840 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -121,6 +121,29 @@ def strip_parent_engine_args( return result, sorted(overridden) +def _apply_diffusion_parallel_runtime_overrides( + engine_args: dict[str, Any], + runtime_overrides: dict[str, Any], +) -> None: + """Move diffusion parallel overrides into nested ``parallel_config``.""" + from vllm_omni.diffusion.data import DiffusionParallelConfig + + parallel_fields = frozenset(f.name for f in fields(DiffusionParallelConfig)) + parallel_config = engine_args.get("parallel_config") + parallel_config_dict = to_dict(parallel_config) if parallel_config is not None else None + + for key in list(runtime_overrides.keys()): + value = runtime_overrides.get(key) + if value is None or key not in parallel_fields: + continue + if parallel_config_dict is None: + parallel_config_dict = {} + parallel_config_dict[key] = runtime_overrides.pop(key) + + if parallel_config_dict is not None: + engine_args["parallel_config"] = parallel_config_dict + + class StageType(str, Enum): """Type of processing stage in the Omni pipeline.""" @@ -930,6 +953,7 @@ def to_omegaconf(self) -> Any: """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly.""" # Start with YAML engine_args defaults engine_args: dict[str, Any] = dict(self.yaml_engine_args) + runtime_overrides = dict(self.runtime_overrides) # Overlay topology-level fields engine_args["model_stage"] = self.model_stage @@ -940,22 +964,25 @@ def to_omegaconf(self) -> Any: if self.hf_config_name: engine_args["hf_config_name"] = self.hf_config_name + if StageType(self.stage_type) == StageType.DIFFUSION: + _apply_diffusion_parallel_runtime_overrides(engine_args, runtime_overrides) + # CLI overrides take precedence over YAML defaults - for key, value in self.runtime_overrides.items(): + for key, value in runtime_overrides.items(): if value is not None and key not in ("devices", "max_batch_size", "num_replicas"): engine_args[key] = value # Build runtime config from YAML defaults + CLI overrides runtime: dict[str, Any] = dict(self.yaml_runtime) runtime.setdefault("process", True) - if self.runtime_overrides.get("devices") is not None: - runtime["devices"] = self.runtime_overrides["devices"] - if self.runtime_overrides.get("num_replicas") is not None: - runtime["num_replicas"] = self.runtime_overrides["num_replicas"] + if runtime_overrides.get("devices") is not None: + runtime["devices"] = runtime_overrides["devices"] + if runtime_overrides.get("num_replicas") is not None: + runtime["num_replicas"] = runtime_overrides["num_replicas"] # Legacy compat: migrate runtime.max_batch_size → engine_args.max_num_seqs legacy_mbs = runtime.pop("max_batch_size", None) - cli_mbs = self.runtime_overrides.get("max_batch_size") + cli_mbs = runtime_overrides.get("max_batch_size") if legacy_mbs is not None or cli_mbs is not None: warnings.warn( "runtime.max_batch_size is deprecated and will be removed in a " From 24edd6875cbc59e66fe9461526593926b33677e0 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 21:49:33 +0800 Subject: [PATCH 2/7] Fix diffusion deploy override nullification Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_arg_utils.py | 9 ++++++--- tests/test_config_factory.py | 34 ++++++++++++++++++++++++-------- vllm_omni/config/stage_config.py | 12 +++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py index 2fd5cf302e0..8fc445a604d 100644 --- a/tests/test_arg_utils.py +++ b/tests/test_arg_utils.py @@ -393,16 +393,19 @@ def test_nullify_stage_engine_defaults_resets_inherited_defaults(): def test_non_override_flags_keep_real_defaults_after_nullify(): import argparse + from vllm_omni.config.stage_config import deploy_override_field_names from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults parser = argparse.ArgumentParser() - parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.") + parser.add_argument("--batch-timeout", type=int, default=10, help="Batch timeout.") parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.") nullify_stage_engine_defaults(parser) - hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size") + assert "batch_timeout" not in deploy_override_field_names() + + batch_timeout = next(a for a in parser._actions if a.dest == "batch_timeout") max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs") - assert hsdp.default == -1 + assert batch_timeout.default == 10 assert max_num_seqs.default is None diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 6120515075d..9fa3b3d4acf 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -969,17 +969,16 @@ class TestDeployConfigLoading: def test_deploy_override_fields_include_deploy_schema_fields(self): expected_fields = { "async_chunk", + + # StageDeployConfig: stage placement and runtime fields. + "devices", + + # StageDeployConfig: vLLM EngineArgs fields. "async_scheduling", "compilation_config", "config_format", - "data_parallel_size", - "devices", "disable_hybrid_kv_cache_manager", - "distributed_executor_backend", - "dtype", - "enable_chunked_prefill", "enable_flashinfer_autotune", - "enable_prefix_caching", "enforce_eager", "gpu_memory_utilization", "load_format", @@ -987,13 +986,32 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "max_num_batched_tokens", "max_num_seqs", "mm_processor_cache_gb", - "pipeline_parallel_size", "profiler_config", - "quantization", "skip_mm_profiling", "subtalker_sampling_params", "tensor_parallel_size", "tokenizer_mode", + + # StageDeployConfig: diffusion parallel_config deploy override fields. + "cfg_parallel_size", + "enable_expert_parallel", + "hsdp_replicate_size", + "hsdp_shard_size", + "ring_degree", + "sequence_parallel_size", + "ulysses_degree", + "ulysses_mode", + "use_hsdp", + "vae_patch_parallel_size", + + # DeployConfig: pipeline-wide engine settings. + "data_parallel_size", + "distributed_executor_backend", + "dtype", + "enable_chunked_prefill", + "enable_prefix_caching", + "pipeline_parallel_size", + "quantization", "trust_remote_code", } diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 9cc1f325840..e42216ec0f8 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -452,6 +452,18 @@ class StageDeployConfig: disable_hybrid_kv_cache_manager: bool | None = None mm_processor_cache_gb: float | None = None + # Diffusion parallel_config deploy override fields. + enable_expert_parallel: bool | None = None + ulysses_degree: int | None = None + ulysses_mode: str | None = None + ring_degree: int | None = None + sequence_parallel_size: int | None = None + cfg_parallel_size: int | None = None + vae_patch_parallel_size: int | None = None + use_hsdp: bool | None = None + hsdp_shard_size: int | None = None + hsdp_replicate_size: int | None = None + # Compilation, profiling, tokenizer/config parsing, and model loading. compilation_config: dict[str, Any] | None = None profiler_config: dict[str, Any] | None = None From 2f7c486fd41e05688303c8c67beed8a99cbe6021 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Sun, 10 May 2026 23:50:36 +0800 Subject: [PATCH 3/7] Handle dict diffusion parallel_config overrides Signed-off-by: xiaohajiayou <923390377@qq.com> --- vllm_omni/config/stage_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index e42216ec0f8..afcb8fc53f5 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -130,7 +130,7 @@ def _apply_diffusion_parallel_runtime_overrides( parallel_fields = frozenset(f.name for f in fields(DiffusionParallelConfig)) parallel_config = engine_args.get("parallel_config") - parallel_config_dict = to_dict(parallel_config) if parallel_config is not None else None + parallel_config_dict = dict(parallel_config) if parallel_config is not None else None for key in list(runtime_overrides.keys()): value = runtime_overrides.get(key) From edcd0959c6b1c760511b0ad9cad4771e95d00679 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Mon, 11 May 2026 00:44:24 +0800 Subject: [PATCH 4/7] Fix deploy override test formatting Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_config_factory.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 9fa3b3d4acf..0e9b718e452 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -969,10 +969,8 @@ class TestDeployConfigLoading: def test_deploy_override_fields_include_deploy_schema_fields(self): expected_fields = { "async_chunk", - # StageDeployConfig: stage placement and runtime fields. "devices", - # StageDeployConfig: vLLM EngineArgs fields. "async_scheduling", "compilation_config", @@ -991,7 +989,6 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "subtalker_sampling_params", "tensor_parallel_size", "tokenizer_mode", - # StageDeployConfig: diffusion parallel_config deploy override fields. "cfg_parallel_size", "enable_expert_parallel", @@ -1003,7 +1000,6 @@ def test_deploy_override_fields_include_deploy_schema_fields(self): "ulysses_mode", "use_hsdp", "vae_patch_parallel_size", - # DeployConfig: pipeline-wide engine settings. "data_parallel_size", "distributed_executor_backend", From 505eb6252d641b21f35647a346a1fecc38388f63 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Mon, 11 May 2026 01:20:31 +0800 Subject: [PATCH 5/7] Fix nullified HSDP defaults in diffusion stage builder Signed-off-by: xiaohajiayou <923390377@qq.com> --- vllm_omni/engine/async_omni_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 23ef9b85567..d60a526ea34 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1879,9 +1879,9 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: pipeline_parallel_size = normalized_kwargs.get("pipeline_parallel_size") or 1 vae_patch_parallel_size = normalized_kwargs.get("vae_patch_parallel_size") or 1 enable_expert_parallel = normalized_kwargs.get("enable_expert_parallel") or False - use_hsdp = normalized_kwargs.get("use_hsdp", False) - hsdp_shard_size = normalized_kwargs.get("hsdp_shard_size", -1) - hsdp_replicate_size = normalized_kwargs.get("hsdp_replicate_size", 1) + use_hsdp = normalized_kwargs.get("use_hsdp") or False + hsdp_shard_size = normalized_kwargs.get("hsdp_shard_size") or -1 + hsdp_replicate_size = normalized_kwargs.get("hsdp_replicate_size") or 1 if sequence_parallel_size is None: sequence_parallel_size = ulysses_degree * ring_degree From 6cac15c91240ef7d8f202b98dc87ae5604eefdc3 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Mon, 11 May 2026 01:40:42 +0800 Subject: [PATCH 6/7] Fix legacy nullify entrypoint expectations Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/entrypoints/test_omni_entrypoints.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py index adcdc3e9780..3612020d4fd 100644 --- a/tests/entrypoints/test_omni_entrypoints.py +++ b/tests/entrypoints/test_omni_entrypoints.py @@ -184,7 +184,7 @@ def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine: parser = argparse.ArgumentParser() parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) - parser.add_argument("--hsdp-shard-size", type=int, default=-1) + parser.add_argument("--batch-timeout", type=int, default=10) nullify_stage_engine_defaults(parser) args = parser.parse_args([]) args.model = "fake-model" @@ -192,7 +192,7 @@ def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine: Omni(**vars(args)) assert captured["gpu_memory_utilization"] is None - assert captured["hsdp_shard_size"] == -1 + assert captured["batch_timeout"] == 10 assert "_cli_explicit_keys" not in captured @@ -233,7 +233,7 @@ def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine: parser = argparse.ArgumentParser() parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) - parser.add_argument("--hsdp-shard-size", type=int, default=-1) + parser.add_argument("--batch-timeout", type=int, default=10) args = parser.parse_args([]) args.model = "fake-model" @@ -241,7 +241,7 @@ def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine: Omni.from_cli_args(args, parser=parser) assert captured["gpu_memory_utilization"] is None - assert captured["hsdp_shard_size"] == -1 + assert captured["batch_timeout"] == 10 def _make_base(): From d06c7f0dd83be47ad308d03c161000988a17fb38 Mon Sep 17 00:00:00 2001 From: xiaohajiayou <923390377@qq.com> Date: Tue, 26 May 2026 19:37:20 +0800 Subject: [PATCH 7/7] Recompute diffusion SP size after degree overrides Signed-off-by: xiaohajiayou <923390377@qq.com> --- tests/test_config_factory.py | 52 ++++++++++++++++++++++++++++++++ vllm_omni/config/stage_config.py | 9 ++++++ 2 files changed, 61 insertions(+) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 0e9b718e452..c54a39a1e38 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -292,6 +292,58 @@ def test_to_omegaconf_diffusion_parallel_overrides_create_parallel_config(self): assert "hsdp_shard_size" not in omega_config.engine_args assert "hsdp_replicate_size" not in omega_config.engine_args + def test_to_omegaconf_diffusion_parallel_degree_overrides_recompute_sequence_parallel_size(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + yaml_engine_args={ + "parallel_config": { + "sequence_parallel_size": 1, + "ulysses_degree": 1, + "ring_degree": 1, + } + }, + runtime_overrides={ + "ulysses_degree": 2, + "ring_degree": 4, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8 + assert "ulysses_degree" not in omega_config.engine_args + assert "ring_degree" not in omega_config.engine_args + assert "sequence_parallel_size" not in omega_config.engine_args + + def test_to_omegaconf_diffusion_parallel_explicit_sequence_parallel_size_is_preserved(self): + config = StageConfig( + stage_id=1, + model_stage="diffusion", + stage_type=StageType.DIFFUSION, + yaml_engine_args={ + "parallel_config": { + "sequence_parallel_size": 1, + "ulysses_degree": 1, + "ring_degree": 1, + } + }, + runtime_overrides={ + "ulysses_degree": 2, + "ring_degree": 4, + "sequence_parallel_size": 16, + }, + ) + + omega_config = config.to_omegaconf() + + assert omega_config.engine_args.parallel_config.ulysses_degree == 2 + assert omega_config.engine_args.parallel_config.ring_degree == 4 + assert omega_config.engine_args.parallel_config.sequence_parallel_size == 16 + def test_to_omegaconf_llm_parallel_overrides_remain_top_level(self): config = StageConfig( stage_id=0, diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index afcb8fc53f5..4188606d72f 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -131,6 +131,8 @@ def _apply_diffusion_parallel_runtime_overrides( parallel_fields = frozenset(f.name for f in fields(DiffusionParallelConfig)) parallel_config = engine_args.get("parallel_config") parallel_config_dict = dict(parallel_config) if parallel_config is not None else None + degree_overridden = False + sequence_parallel_explicit = runtime_overrides.get("sequence_parallel_size") is not None for key in list(runtime_overrides.keys()): value = runtime_overrides.get(key) @@ -138,8 +140,15 @@ def _apply_diffusion_parallel_runtime_overrides( continue if parallel_config_dict is None: parallel_config_dict = {} + if key in ("ulysses_degree", "ring_degree"): + degree_overridden = True parallel_config_dict[key] = runtime_overrides.pop(key) + if parallel_config_dict is not None and degree_overridden and not sequence_parallel_explicit: + ulysses_degree = parallel_config_dict.get("ulysses_degree") or 1 + ring_degree = parallel_config_dict.get("ring_degree") or 1 + parallel_config_dict["sequence_parallel_size"] = ulysses_degree * ring_degree + if parallel_config_dict is not None: engine_args["parallel_config"] = parallel_config_dict