Skip to content
Closed
42 changes: 40 additions & 2 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,7 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
stage_overrides_json = kwargs.pop("stage_overrides", None)
kwargs.pop("_cli_explicit_keys", None)
explicit_stage_configs = kwargs.pop("stage_configs", None)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This performs a network round-trip (`get_hf_file_to_dict`) on every call, just to guard the parallel-config bridging. `_resolve_stage_configs` already resolves the model type downstream — can you reuse that result, or at least cache it?

if explicit_stage_configs is not None:
logger.warning(
"`stage_configs` is not part of the public API. "
Expand Down Expand Up @@ -1631,7 +1632,8 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
stage_overrides=stage_overrides,
)

# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
# Inject diffusion knobs (parallel_config, LoRA, quantization) from kwargs
# into resolved diffusion stages when not already set by YAML/model config.
for cfg in stage_configs:
try:
if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
Expand All @@ -1644,6 +1646,42 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
continue
if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
cfg.engine_args = OmegaConf.create({})

if kwargs.get("parallel_config") is None:
parallel_cli_fields = {
"ulysses_degree": 1,
"ring_degree": 1,
"ulysses_mode": "strict",
"sequence_parallel_size": None,
"tensor_parallel_size": 1,
"enable_expert_parallel": False,
"cfg_parallel_size": 1,
"vae_patch_parallel_size": 1,
"use_hsdp": False,
"hsdp_shard_size": -1,
"hsdp_replicate_size": 1,
}
if not hasattr(cfg.engine_args, "parallel_config") or cfg.engine_args.parallel_config is None:
values = {k: kwargs.get(k, d) for k, d in parallel_cli_fields.items()}
if values["sequence_parallel_size"] is None:
values["sequence_parallel_size"] = values["ulysses_degree"] * values["ring_degree"]
cfg.engine_args.parallel_config = DiffusionParallelConfig(
pipeline_parallel_size=1,
data_parallel_size=1,
**values,
)
else:
# YAML/model config already set parallel_config; only override
# fields that the user explicitly passed via kwargs.
pc = cfg.engine_args.parallel_config
for key in parallel_cli_fields:
if key in kwargs:
setattr(pc, key, kwargs[key])
if "sequence_parallel_size" not in kwargs and (
"ulysses_degree" in kwargs or "ring_degree" in kwargs
):
pc.sequence_parallel_size = pc.ulysses_degree * pc.ring_degree

if kwargs.get("lora_path") is not None:
if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None:
cfg.engine_args.lora_path = kwargs["lora_path"]
Expand Down Expand Up @@ -1708,7 +1746,7 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
):
cfg.engine_args.kv_cache_skip_layers = kv_cache_skip_layers
except Exception as e:
logger.warning("Failed to inject LoRA config for stage: %s", e)
logger.warning("Failed to inject diffusion engine_args for stage: %s", e)

return config_path, stage_configs

Expand Down
Loading