Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ def main(args: argparse.Namespace) -> None:
omni = Omni(
model=args.model,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
mode="image-to-text",
)

prompt = "<|startoftext|>You are an assistant that understands images and outputs text.<img>" + args.prompt

prompt_dict = {
"prompt": args.prompt,
"prompt": prompt,
"modalities": ["text"],
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def main():
"parallel_config": parallel_config,
"enforce_eager": args.enforce_eager,
"enable_cpu_offload": args.enable_cpu_offload,
"mode": "text-to-image",
"enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
**lora_args,
**quant_kwargs,
Expand Down
97 changes: 97 additions & 0 deletions vllm_omni/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,100 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None
return stage_args


def filter_stages(
config_path: str | None,
stage_configs: list,
kwargs: dict | None,
) -> list:
"""Filter stage configs by mode when YAML defines a `modes` section.

The YAML can define, e.g.:

modes:
- mode: text-to-image
stages: [1]
- mode: image-to-text
stages: [0]

When users pass `mode="image-to-text"` into Omni(**kwargs), only the stages
listed for that mode are returned. If no mode is provided, defaults to
"text-to-image". If no modes are defined or filtering fails, returns the
original stage_configs unchanged.

Args:
config_path: Path to the YAML config (used to read `modes`).
stage_configs: Loaded list of stage configs.
kwargs: Engine/caller kwargs; may contain "mode".

Returns:
Filtered list of stage configs (or original list if filtering not applied).
"""
if not stage_configs or config_path is None:
return stage_configs

try:
cfg = load_yaml_config(config_path)
yaml_modes = getattr(cfg, "modes", None)
if yaml_modes is None:
return stage_configs

mode_to_stage_ids: dict[str, list[int]] = {}
if yaml_modes is not None:
for entry in yaml_modes:
mode_name = None
stages = None
if hasattr(entry, "mode") or hasattr(entry, "stages"):
mode_name = getattr(entry, "mode", None)
stages = getattr(entry, "stages", None)
elif isinstance(entry, dict):
mode_name = entry.get("mode")
stages = entry.get("stages")

if mode_name is None or stages is None:
continue

if isinstance(stages, int):
stage_list = [stages]
else:
stage_list = list(stages)

mode_to_stage_ids[str(mode_name)] = [int(sid) for sid in stage_list]

# No modes section or empty mapping: use all stages and return early.
active_mode: str | None = None
if isinstance(kwargs, dict):
active_mode = kwargs.get("mode")

if active_mode is None:
active_mode = "text-to-image"

if active_mode not in mode_to_stage_ids:
logger.warning(
"Requested mode '%s' not found in config '%s'; available modes: %s. Using all stages.",
active_mode,
config_path,
sorted(mode_to_stage_ids.keys()),
)
return stage_configs

allowed_ids = set(mode_to_stage_ids[active_mode])
filtered_stage_configs = [sc for sc in stage_configs if getattr(sc, "stage_id", None) in allowed_ids]
if not filtered_stage_configs:
logger.warning(
"Mode '%s' in config '%s' resolved to stage ids %s, but none matched loaded stage_args. "
"Falling back to all stages.",
active_mode,
config_path,
sorted(allowed_ids),
)
return stage_configs

return filtered_stage_configs
except Exception as e:
logger.warning("Failed to apply mode-based stage filtering: %s", e)
return stage_configs


def load_and_resolve_stage_configs(
model: str,
stage_configs_path: str | None,
Expand Down Expand Up @@ -324,6 +418,9 @@ def load_and_resolve_stage_configs(
config_path = stage_configs_path
stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=kwargs)

stage_configs = filter_stages(config_path, stage_configs, kwargs)
logger.debug(f"stage_configs: {stage_configs}")

return config_path, stage_configs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from einops import rearrange
from torch import Tensor, nn

from vllm_omni.diffusion.distributed.utils import get_local_device


class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
Expand Down Expand Up @@ -506,7 +508,7 @@ def __init__(

self.use_compile = False

self.empty_cache = torch.empty(0, device="cuda")
self.empty_cache = torch.empty(0, device=get_local_device())

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
Expand Down
39 changes: 39 additions & 0 deletions vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
# Stage 0: AR Model (vLLM implementation)

# The following config has been verified on 8x L40S-48G GPU.
modes:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this mapping here? @Semmer2 @lishunyang12 @nussejzz PTAL

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we load both stages and let the workload decide which stage to use,
device memory utilization gets doubled.
This PR suggests a simple fix: use modes to decide whether users want to run text-to-image or image-to-text.

I am also considering a more aggressive fix: sharing the same weights across the different stages. If that makes sense, I can open an RFC so we can discuss it.

- mode: text-to-image
stages: [1]
- mode: image-to-text
stages: [0]
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type for AR stages
Expand Down Expand Up @@ -37,6 +42,40 @@ stage_args:
seed: 42
detokenize: True
repetition_penalty: 1.1
- stage_id: 1
stage_type: diffusion
runtime:
process: true
devices: "0,1,2,3,4,5,6,7"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Bounty-hunter, that is because the initial config uses all 8 cards.
If you want to use 4 cards, you need to update this manually.

max_batch_size: 1
engine_args:
model_stage: diffusion
gpu_memory_utilization: 0.9
enforce_eager: true
engine_output_type: image
distributed_executor_backend: "mp"
enable_prefix_caching: false
max_num_batched_tokens: 32768
vae_use_slicing: false
vae_use_tiling: false
cache_backend: null
cache_config: null
enable_cache_dit_summary: false
parallel_config:
pipeline_parallel_size: 1
data_parallel_size: 1
tensor_parallel_size: 8
enable_expert_parallel: false
sequence_parallel_size: 1
ulysses_degree: 1
ring_degree: 1
cfg_parallel_size: 1
vae_patch_parallel_size: 1
use_hsdp: false
hsdp_shard_size: -1
hsdp_replicate_size: 1
final_output: true
final_output_type: image

# Top-level runtime config (concise): default windows and stage edges
runtime:
Expand Down
83 changes: 83 additions & 0 deletions vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Stage config for running Hunyuan-Image3.0 with architecture of OmniLLM.
# Stage 0: AR Model (vLLM implementation)

# The following config has been verified on 8x Max 1550 GPU.
modes:
- mode: text-to-image
stages: [1]
- mode: image-to-text
stages: [0]
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
process: true # Run this stage in a separate process
devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage
max_batch_size: 1
engine_args:
model_stage: AR
model_arch: HunyuanImage3ForCausalMM
worker_cls: vllm_omni.platforms.xpu.worker.xpu_ar_worker.XPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.95
enforce_eager: true # Only eager mode is currently supported
trust_remote_code: true
engine_output_type: latent
enable_prefix_caching: false
max_num_batched_tokens: 32784
tensor_parallel_size: 8
pipeline_parallel_size: 1
enable_expert_parallel: true
quantization: "fp8"
is_comprehension: true
final_output: true
final_output_type: text
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 2048
seed: 42
detokenize: True
repetition_penalty: 1.1
- stage_id: 1
stage_type: diffusion
runtime:
process: true
devices: "0,1,2,3,4,5,6,7"
max_batch_size: 1
engine_args:
model_stage: diffusion
gpu_memory_utilization: 0.9
enforce_eager: true
engine_output_type: image
distributed_executor_backend: "mp"
enable_prefix_caching: false
vae_use_slicing: false
vae_use_tiling: false
cache_backend: null
cache_config: null
enable_cache_dit_summary: false
quantization: "fp8"
parallel_config:
pipeline_parallel_size: 1
data_parallel_size: 1
tensor_parallel_size: 8
enable_expert_parallel: true
sequence_parallel_size: 1
ulysses_degree: 1
ring_degree: 1
cfg_parallel_size: 1
vae_patch_parallel_size: 1
use_hsdp: false
hsdp_shard_size: -1
hsdp_replicate_size: 1
final_output: true
final_output_type: image

# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
defaults:
window_size: -1 # Simplified: trigger downstream only after full upstream completion
max_inflight: 1 # Simplified: process serially within each stage
Loading