diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 15167105135..2813cac47e0 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1598,6 +1598,7 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st stage_overrides_json = kwargs.pop("stage_overrides", None) kwargs.pop("_cli_explicit_keys", None) explicit_stage_configs = kwargs.pop("stage_configs", None) + if explicit_stage_configs is not None: logger.warning( "`stage_configs` is not part of the public API. " @@ -1631,7 +1632,8 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st stage_overrides=stage_overrides, ) - # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. + # Inject diffusion knobs (parallel_config, LoRA, quantization) from kwargs + # into resolved diffusion stages when not already set by YAML/model config. for cfg in stage_configs: try: if not hasattr(cfg, "engine_args") or cfg.engine_args is None: @@ -1644,6 +1646,42 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st continue if not hasattr(cfg, "engine_args") or cfg.engine_args is None: cfg.engine_args = OmegaConf.create({}) + + if kwargs.get("parallel_config") is None: + parallel_cli_fields = { + "ulysses_degree": 1, + "ring_degree": 1, + "ulysses_mode": "strict", + "sequence_parallel_size": None, + "tensor_parallel_size": 1, + "enable_expert_parallel": False, + "cfg_parallel_size": 1, + "vae_patch_parallel_size": 1, + "use_hsdp": False, + "hsdp_shard_size": -1, + "hsdp_replicate_size": 1, + } + if not hasattr(cfg.engine_args, "parallel_config") or cfg.engine_args.parallel_config is None: + values = {k: kwargs.get(k, d) for k, d in parallel_cli_fields.items()} + if values["sequence_parallel_size"] is None: + values["sequence_parallel_size"] = values["ulysses_degree"] * values["ring_degree"] + cfg.engine_args.parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + **values, + ) + else: + # YAML/model config already set parallel_config; only override + # fields that the user explicitly passed via kwargs. + pc = cfg.engine_args.parallel_config + for key in parallel_cli_fields: + if key in kwargs: + setattr(pc, key, kwargs[key]) + if "sequence_parallel_size" not in kwargs and ( + "ulysses_degree" in kwargs or "ring_degree" in kwargs + ): + pc.sequence_parallel_size = pc.ulysses_degree * pc.ring_degree + if kwargs.get("lora_path") is not None: if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: cfg.engine_args.lora_path = kwargs["lora_path"] @@ -1708,7 +1746,7 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st ): cfg.engine_args.kv_cache_skip_layers = kv_cache_skip_layers except Exception as e: - logger.warning("Failed to inject LoRA config for stage: %s", e) + logger.warning("Failed to inject diffusion engine_args for stage: %s", e) return config_path, stage_configs