From 8ce425ffb11064f9cf67b21aff0679324e04f8d2 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 11 Mar 2026 21:24:02 +0000 Subject: [PATCH 1/3] fix AR path for xpu Signed-off-by: Chendi Xue --- .../hunyuan_image3/autoencoder_kl_3d.py | 4 +- .../stage_configs/hunyuan_image_3_moe.yaml | 43 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml diff --git a/vllm_omni/model_executor/models/hunyuan_image3/autoencoder_kl_3d.py b/vllm_omni/model_executor/models/hunyuan_image3/autoencoder_kl_3d.py index e87d9b8b199..f8541a4afb0 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/autoencoder_kl_3d.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/autoencoder_kl_3d.py @@ -21,6 +21,8 @@ from einops import rearrange from torch import Tensor, nn +from vllm_omni.diffusion.distributed.utils import get_local_device + class DiagonalGaussianDistribution: def __init__(self, parameters: torch.Tensor, deterministic: bool = False): @@ -506,7 +508,7 @@ def __init__( self.use_compile = False - self.empty_cache = torch.empty(0, device="cuda") + self.empty_cache = torch.empty(0, device=get_local_device()) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml new file mode 100644 index 00000000000..48735e02fe6 --- /dev/null +++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml @@ -0,0 +1,43 @@ +# Stage config for running Hunyuan-Image3.0 with architecture of OmniLLM. +# Stage 0: AR Model (vLLM implementation) + +# The following config has been verified on 8x Max 1550 GPU. +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true # Run this stage in a separate process + devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage + max_batch_size: 1 + engine_args: + model_stage: AR + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.platforms.xpu.worker.xpu_ar_worker.XPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32784 + tensor_parallel_size: 8 + pipeline_parallel_size: 1 + enable_expert_parallel: True + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage From 0b72f2652b522a4f9b6c56e1677ab41367875404 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 16 Mar 2026 22:28:05 +0000 Subject: [PATCH 2/3] Enable a new config - mode - to decide stage selection Signed-off-by: Chendi Xue --- .../hunyuan_image3/image_to_text.py | 6 +- .../text_to_image/text_to_image.py | 1 + vllm_omni/entrypoints/utils.py | 96 +++++++++++++++++++ .../stage_configs/hunyuan_image_3_moe.yaml | 39 ++++++++ .../stage_configs/hunyuan_image_3_moe.yaml | 41 +++++++- 5 files changed, 180 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py index dbae0431555..402f4756060 100644 --- a/examples/offline_inference/hunyuan_image3/image_to_text.py +++ b/examples/offline_inference/hunyuan_image3/image_to_text.py @@ -52,10 +52,12 @@ def load_image(image_path: str) -> Image.Image: def main(args: argparse.Namespace) -> None: - omni = Omni(model=args.model) + omni = Omni(model=args.model, mode="image-to-text") + + prompt = "<|startoftext|>You are an assistant that understands images and outputs text." + args.prompt prompt_dict = { - "prompt": args.prompt, + "prompt": prompt, "modalities": ["text"], } diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 6f98403b93b..0e62e0f2bb6 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -305,6 +305,7 @@ def main(): "parallel_config": parallel_config, "enforce_eager": args.enforce_eager, "enable_cpu_offload": args.enable_cpu_offload, + "mode": "text-to-image", **lora_args, **quant_kwargs, } diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 0e31bfa7c2f..0c21085f089 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -296,6 +296,100 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None return stage_args +def filter_stages( + config_path: str | None, + stage_configs: list, + kwargs: dict | None, +) -> list: + """Filter stage configs by mode when YAML defines a `modes` section. + + The YAML can define, e.g.: + + modes: + - mode: text-to-image + stages: [1] + - mode: image-to-text + stages: [0] + + When users pass `mode="image-to-text"` into Omni(**kwargs), only the stages + listed for that mode are returned. If no mode is provided, defaults to + "text-to-image". If no modes are defined or filtering fails, returns the + original stage_configs unchanged. + + Args: + config_path: Path to the YAML config (used to read `modes`). + stage_configs: Loaded list of stage configs. + kwargs: Engine/caller kwargs; may contain "mode". + + Returns: + Filtered list of stage configs (or original list if filtering not applied). + """ + if not stage_configs or config_path is None: + return stage_configs + + try: + cfg = load_yaml_config(config_path) + yaml_modes = getattr(cfg, "modes", None) + if yaml_modes is None: + return stage_configs + + mode_to_stage_ids: dict[str, list[int]] = {} + if yaml_modes is not None: + for entry in yaml_modes: + mode_name = None + stages = None + if hasattr(entry, "mode") or hasattr(entry, "stages"): + mode_name = getattr(entry, "mode", None) + stages = getattr(entry, "stages", None) + elif isinstance(entry, dict): + mode_name = entry.get("mode") + stages = entry.get("stages") + + if mode_name is None or stages is None: + continue + + if isinstance(stages, int): + stage_list = [stages] + else: + stage_list = list(stages) + + mode_to_stage_ids[str(mode_name)] = [int(sid) for sid in stage_list] + + # No modes section or empty mapping: use all stages and return early. + active_mode: str | None = None + if isinstance(kwargs, dict): + active_mode = kwargs.get("mode") + + if active_mode is None: + active_mode = "text-to-image" + + if active_mode not in mode_to_stage_ids: + logger.warning( + "Requested mode '%s' not found in config '%s'; available modes: %s. Using all stages.", + active_mode, + config_path, + sorted(mode_to_stage_ids.keys()), + ) + return stage_configs + + allowed_ids = set(mode_to_stage_ids[active_mode]) + filtered_stage_configs = [sc for sc in stage_configs if getattr(sc, "stage_id", None) in allowed_ids] + if not filtered_stage_configs: + logger.warning( + "Mode '%s' in config '%s' resolved to stage ids %s, but none matched loaded stage_args. " + "Falling back to all stages.", + active_mode, + config_path, + sorted(allowed_ids), + ) + return stage_configs + + return filtered_stage_configs + except Exception as e: + logger.warning("Failed to apply mode-based stage filtering: %s", e) + return stage_configs + + def load_and_resolve_stage_configs( model: str, stage_configs_path: str | None, @@ -327,6 +421,8 @@ def load_and_resolve_stage_configs( config_path = stage_configs_path stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=kwargs) + stage_configs = filter_stages(config_path, stage_configs, kwargs) + return config_path, stage_configs diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml index e8a603af447..edce4c5856b 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml @@ -2,6 +2,11 @@ # Stage 0: AR Model (vLLM implementation) # The following config has been verified on 8x L40S-48G GPU. +modes: + - mode: text-to-image + stages: [1] + - mode: image-to-text + stages: [0] stage_args: - stage_id: 0 stage_type: llm # Use llm stage type to launch OmniLLM @@ -37,6 +42,40 @@ stage_args: seed: 42 detokenize: True repetition_penalty: 1.1 + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "0,1,2,3,4,5,6,7" + max_batch_size: 1 + engine_args: + model_stage: diffusion + gpu_memory_utilization: 0.9 + enforce_eager: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + vae_use_slicing: false + vae_use_tiling: false + cache_backend: null + cache_config: null + enable_cache_dit_summary: false + parallel_config: + pipeline_parallel_size: 1 + data_parallel_size: 1 + tensor_parallel_size: 8 + enable_expert_parallel: false + sequence_parallel_size: 1 + ulysses_degree: 1 + ring_degree: 1 + cfg_parallel_size: 1 + vae_patch_parallel_size: 1 + use_hsdp: false + hsdp_shard_size: -1 + hsdp_replicate_size: 1 + final_output: true + final_output_type: image # Top-level runtime config (concise): default windows and stage edges runtime: diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml index 48735e02fe6..80f94a9f3f5 100644 --- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml @@ -2,6 +2,11 @@ # Stage 0: AR Model (vLLM implementation) # The following config has been verified on 8x Max 1550 GPU. +modes: + - mode: text-to-image + stages: [1] + - mode: image-to-text + stages: [0] stage_args: - stage_id: 0 stage_type: llm # Use llm stage type to launch OmniLLM @@ -22,7 +27,7 @@ stage_args: max_num_batched_tokens: 32784 tensor_parallel_size: 8 pipeline_parallel_size: 1 - enable_expert_parallel: True + enable_expert_parallel: true is_comprehension: true final_output: true final_output_type: text @@ -34,6 +39,40 @@ stage_args: seed: 42 detokenize: True repetition_penalty: 1.1 + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "0,1,2,3,4,5,6,7" + max_batch_size: 1 + engine_args: + model_stage: diffusion + gpu_memory_utilization: 0.9 + enforce_eager: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + vae_use_slicing: false + vae_use_tiling: false + cache_backend: null + cache_config: null + enable_cache_dit_summary: false + parallel_config: + pipeline_parallel_size: 1 + data_parallel_size: 1 + tensor_parallel_size: 8 + enable_expert_parallel: false + sequence_parallel_size: 1 + ulysses_degree: 1 + ring_degree: 1 + cfg_parallel_size: 1 + vae_patch_parallel_size: 1 + use_hsdp: false + hsdp_shard_size: -1 + hsdp_replicate_size: 1 + final_output: true + final_output_type: image # Top-level runtime config (concise): default windows and stage edges runtime: From c300c6dab06f7328665a5141410bea8afaaaa6ae Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Mar 2026 18:55:00 +0000 Subject: [PATCH 3/3] update config to work with #1935 Signed-off-by: Chendi Xue --- vllm_omni/entrypoints/utils.py | 1 + .../platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 97dbe993955..40d5ae6e687 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -419,6 +419,7 @@ def load_and_resolve_stage_configs( stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=kwargs) stage_configs = filter_stages(config_path, stage_configs, kwargs) + logger.debug(f"stage_configs: {stage_configs}") return config_path, stage_configs diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml index 80f94a9f3f5..8f969ced5f4 100644 --- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml @@ -28,6 +28,7 @@ stage_args: tensor_parallel_size: 8 pipeline_parallel_size: 1 enable_expert_parallel: true + quantization: "fp8" is_comprehension: true final_output: true final_output_type: text @@ -52,17 +53,17 @@ stage_args: engine_output_type: image distributed_executor_backend: "mp" enable_prefix_caching: false - max_num_batched_tokens: 32768 vae_use_slicing: false vae_use_tiling: false cache_backend: null cache_config: null enable_cache_dit_summary: false + quantization: "fp8" parallel_config: pipeline_parallel_size: 1 data_parallel_size: 1 tensor_parallel_size: 8 - enable_expert_parallel: false + enable_expert_parallel: true sequence_parallel_size: 1 ulysses_degree: 1 ring_degree: 1