diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 7280d7d2e23..9a1324305cf 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -13,6 +13,7 @@ import numpy as np import soundfile as sf from PIL import Image +import vllm from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -237,6 +238,7 @@ def get_multi_audios_query() -> QueryResult: def main(args): model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + print(f"="*20,"\n",f"vllm version: {vllm.__version__}","\n","="*20) # Get paths from args video_path = getattr(args, "video_path", None) @@ -302,8 +304,8 @@ def main(args): sampling_params_list = [ thinker_sampling_params, - talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni - code2wav_sampling_params, + # talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni + # code2wav_sampling_params, ] if args.txt_prompts is None: diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index e074689c9e2..2e53a7af2e1 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -6,15 +6,10 @@ import vllm.envs as envs from pydantic import ConfigDict from pydantic.dataclasses import dataclass -from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig, config from vllm.config.model import ( _RUNNER_CONVERTS, - _RUNNER_TASKS, - ConvertOption, - ConvertType, - RunnerOption, - TaskOption, _get_and_verify_dtype, get_served_model_name, ) @@ -31,11 +26,8 @@ from vllm.transformers_utils.gguf_utils import ( maybe_patch_hf_config_from_gguf, ) -from vllm.transformers_utils.utils import ( - is_gguf, - maybe_model_redirect, -) - +from vllm.transformers_utils.utils import maybe_model_redirect +from 
vllm.transformers_utils.gguf_utils import is_gguf import vllm_omni.model_executor.models as me_models logger = init_logger(__name__) @@ -116,7 +108,9 @@ def __post_init__( video_pruning_rate: float | None, ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) - self.served_model_name = get_served_model_name(self.model, self.served_model_name) + self.served_model_name = get_served_model_name( + self.model, self.served_model_name + ) self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. if self.tokenizer is None: @@ -146,14 +140,6 @@ def __post_init__( self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if (backend := envs.VLLM_ATTENTION_BACKEND) and backend == "FLASHINFER" and find_spec("flashinfer") is None: - raise ValueError( - "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " - "module was not found. See " - "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 - "for instructions on how to install it." 
- ) - if self.override_attention_dtype is not None and not current_platform.is_rocm(): warnings.warn( "override-attention-dtype is set but not using ROCm platform", @@ -181,115 +167,24 @@ def __post_init__( if dict_overrides: self._apply_dict_overrides(hf_config, dict_overrides) self.hf_text_config = self.draw_hf_text_config() - self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) + self.attention_chunk_size = getattr( + self.hf_text_config, "attention_chunk_size", None + ) self.encoder_config = self._get_encoder_config() - # Try to load image processor config, but allow it to fail for stages that don't need it - try: - self.hf_image_processor_config = get_hf_image_processor_config( - self.model, hf_token=self.hf_token, revision=self.revision - ) - except (OSError, ValueError, IndexError) as e: - # Some stages (e.g., code2wav, talker) don't need image processor - # Log warning but allow initialization to continue - logger.warning( - f"Failed to load image processor config for model '{self.model}': {e}. " - "This is expected for stages that don't require image processing." 
- ) - self.hf_image_processor_config = None + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision + ) + self.model_arch_config = self.get_model_arch_config() architectures = self.architectures registry = self.registry is_generative_model = registry.is_text_generation_model(architectures, self) is_pooling_model = registry.is_pooling_model(architectures, self) - def _task_to_convert(task: TaskOption) -> ConvertType: - if task == "embedding" or task == "embed": - return "embed" - if task == "classify": - return "classify" - if task == "reward": - return "reward" - if task == "score": - new_task = self._get_default_pooling_task(architectures) - return "classify" if new_task == "classify" else "embed" - - return "none" - - if self.task is not None: - runner: RunnerOption = "auto" - convert: ConvertOption = "auto" - msg_prefix = ( - "The 'task' option has been deprecated and will be removed in v0.13.0 or v1.0, whichever comes first." - ) - msg_hint = "Please remove this option." - - is_generative_task = self.task in _RUNNER_TASKS["generate"] - is_pooling_task = self.task in _RUNNER_TASKS["pooling"] - - if is_generative_model and is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = ( - "Please replace this option with `--runner " - "generate` to continue using this model " - "as a generative model." - ) - elif is_pooling_task: - runner = "pooling" - convert = "auto" - msg_hint = ( - "Please replace this option with `--runner " - "pooling` to continue using this model " - "as a pooling model." 
- ) - else: # task == "auto" - pass - elif is_generative_model or is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = "Please remove this option" - elif is_pooling_task: - runner = "pooling" - convert = _task_to_convert(self.task) - msg_hint = ( - "Please replace this option with `--convert " - f"{convert}` to continue using this model " - "as a pooling model." - ) - else: # task == "auto" - pass - else: - # Neither generative nor pooling model - try to convert if possible - if is_pooling_task: - runner = "pooling" - convert = _task_to_convert(self.task) - msg_hint = ( - "Please replace this option with `--runner pooling " - f"--convert {convert}` to continue using this model " - "as a pooling model." - ) - else: - debug_info = { - "architectures": architectures, - "is_generative_model": is_generative_model, - "is_pooling_model": is_pooling_model, - } - raise AssertionError( - "The model should be a generative or " - "pooling model when task is set to " - f"{self.task!r}. 
Found: {debug_info}" - ) - - self.runner = runner - self.convert = convert - - msg = f"{msg_prefix} {msg_hint}" - warnings.warn(msg, DeprecationWarning, stacklevel=2) - self.runner_type = self._get_runner_type(architectures, self.runner) - self.convert_type = self._get_convert_type(architectures, self.runner_type, self.convert) + self.convert_type = self._get_convert_type( + architectures, self.runner_type, self.convert + ) if self.runner_type == "generate" and not is_generative_model: generate_converts = _RUNNER_CONVERTS["generate"] @@ -325,9 +220,12 @@ def _task_to_convert(task: TaskOption) -> ConvertType: if getattr(self.pooler_config, k) is None: setattr(self.pooler_config, k, v) - default_pooling_type = self._model_info.default_pooling_type - if self.pooler_config.pooling_type is None: - self.pooler_config.pooling_type = default_pooling_type + default_seq_pooling_type = self._model_info.default_seq_pooling_type + if self.pooler_config.seq_pooling_type is None: + self.pooler_config.seq_pooling_type = default_seq_pooling_type + default_tok_pooling_type = self._model_info.default_tok_pooling_type + if self.pooler_config.tok_pooling_type is None: + self.pooler_config.tok_pooling_type = default_tok_pooling_type self.dtype: torch.dtype = _get_and_verify_dtype( self.model, @@ -339,9 +237,17 @@ def _task_to_convert(task: TaskOption) -> ConvertType: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + + if self.is_encoder_decoder: + self.mm_processor_cache_gb = 0 + logger.info("Encoder-decoder model detected, disabling mm processor cache.") + # Init multimodal config if needed if self._model_info.supports_multimodal: - if mm_encoder_tp_mode == "data" and not self._model_info.supports_multimodal_encoder_tp_data: + if ( + mm_encoder_tp_mode == "data" + and not self._model_info.supports_multimodal_encoder_tp_data + ): logger.warning_once( "This model does not support `--mm-encoder-tp-mode data`. 
" "Falling back to `--mm-encoder-tp-mode weights`." @@ -363,7 +269,9 @@ def _task_to_convert(task: TaskOption) -> ConvertType: video_pruning_rate=video_pruning_rate, ) - mm_config_kwargs = {k: v for k, v in mm_config_kwargs.items() if v is not None} + mm_config_kwargs = { + k: v for k, v in mm_config_kwargs.items() if v is not None + } self.multimodal_config = MultiModalConfig(**mm_config_kwargs) @@ -382,7 +290,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType: # Avoid running try_verify_and_update_config multiple times self.config_updated = False - + self._try_verify_and_update_model_config() self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 7918e16878e..dc3f56ac2db 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -73,6 +73,11 @@ def update_from_output( pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + cudagraph_stats = model_runner_output.cudagraph_stats + + perf_stats: PerfStats | None = None + if self.perf_metrics and self.perf_metrics.is_enabled(): + perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None @@ -131,11 +136,14 @@ def update_from_output( spec_decoding_stats, num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted, + num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens, + request_id=req_id, ) stopped = False new_logprobs = None new_token_ids = generated_token_ids + pooler_output = pooler_outputs[req_index] if pooler_outputs else None kv_transfer_params = None status_before_stop = request.status @@ -143,14 +151,34 @@ def update_from_output( if new_token_ids: 
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids) - # Stop checking for pooler models. - pooler_output = None - if pooler_outputs: - pooler_output = pooler_outputs[req_index] + if pooler_output: + # Note: As we occupied the pooler output, for multimodal outputs, we do not intermediate stop checking for pooler output if request.output_token_ids: - stopped = check_stop(request, self.max_model_len, pooler_output) - + stopped = check_stop(request, self.max_model_len) + routed_experts = None if stopped: + if self.vllm_config.model_config.enable_return_routed_experts: + kv_blocks = self.kv_cache_manager.get_blocks(request.request_id) + block_ids = kv_blocks.get_block_ids()[0] + num_tokens = request.num_tokens - 1 + + # compute slot mapping + block_ids_array = np.array(block_ids, dtype=np.int32) + num_blocks = len(block_ids) + block_size = self.block_size + + # generate block offsets + block_offsets = np.arange(0, block_size) + + # compute slot mapping: slot = block_id * block_size + offset + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_array.reshape((num_blocks, 1)) * block_size + ).flatten()[:num_tokens] + + routed_experts = self.routed_experts_reader.get_routed_experts( + indices=slot_mapping + ) kv_transfer_params = self._free_request(request) if status_before_stop == RequestStatus.RUNNING: stopped_running_reqs.add(request) @@ -165,7 +193,13 @@ def update_from_output( struct_output_request = request.structured_output_request assert struct_output_request is not None assert struct_output_request.grammar is not None - struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + if not ok: + logger.warning( + "Unexpected: grammar rejected tokens %s for request %s.", + new_token_ids, + req_id, + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] @@ -200,7 
+234,21 @@ def update_from_output( if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) - + + if failed_kv_load_req_ids and not self.recompute_kv_load_failures: + requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] + self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) + for request in requests: + outputs[request.client_index].append( + EngineCoreOutput( + request_id=request.request_id, + new_token_ids=[], + finish_reason=request.get_finished_reason(), + events=request.take_events(), + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + ) + ) # KV Connector: update state for finished KV Transfers. if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) diff --git a/vllm_omni/diffusion/forward_context.py b/vllm_omni/diffusion/forward_context.py index 3c33f8105eb..d898bceaf62 100644 --- a/vllm_omni/diffusion/forward_context.py +++ b/vllm_omni/diffusion/forward_context.py @@ -86,5 +86,14 @@ def set_forward_context( attn_metadata=attn_metadata, split_text_embed_in_sp=split_text_embed_in_sp, ) + # vLLM CustomOp dispatch (e.g. QKVParallelLinear) requires a global + # vLLM config set via set_current_vllm_config(). with override_forward_context(forward_context): - yield + if vllm_config is None: + yield + else: + # Local import to avoid importing vllm.config.vllm at module import time. 
+ from vllm.config.vllm import set_current_vllm_config + + with set_current_vllm_config(vllm_config): + yield diff --git a/vllm_omni/diffusion/models/glm_image/__init__.py b/vllm_omni/diffusion/models/glm_image/__init__.py new file mode 100644 index 00000000000..ac7a98fa743 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GLM Image diffusion model components.""" + +from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageKVCache, + GlmImageTransformer2DModel, +) +from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import ( + GlmImagePipeline, + get_glm_image_post_process_func, + # get_glm_image_pre_process_func, +) + +__all__ = [ + "GlmImageKVCache", + "GlmImagePipeline", + "GlmImageTransformer2DModel", + "get_glm_image_post_process_func", + # "get_glm_image_pre_process_func", +] diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py new file mode 100644 index 00000000000..09f7b17e133 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -0,0 +1,793 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from enum import Enum +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from 
class GlmImageImageProjector(nn.Module):
    """Project non-overlapping latent patches to the transformer hidden size.

    Args:
        in_channels: Number of channels of the latent input.
        hidden_size: Output feature size of the linear projection.
        patch_size: Side length of the square patches cut from the latent.
    """

    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Patchify a ``[B, C, H, W]`` latent and project every patch.

        Returns:
            Tensor of shape ``[B, (H // p) * (W // p), hidden_size]``.
        """
        batch_size, channel, height, width = hidden_states.shape
        post_patch_height = height // self.patch_size
        post_patch_width = width // self.patch_size

        # [B, C, H, W] -> [B, C, H', p, W', p]
        hidden_states = hidden_states.reshape(
            batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
        )
        # -> [B, H', W', C, p, p] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p]
        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
        return self.proj(hidden_states)


class GlmImageRotaryPosEmbed(nn.Module):
    """Rotary positional embedding table for a 2D grid of image patches.

    Args:
        dim: Rotary dimension; half of it is used for rows, half for columns.
        patch_size: Side length of the square patches.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.theta = theta

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the ``(cos, sin)`` tables for the patch grid of ``hidden_states``.

        Args:
            hidden_states: 4D tensor ``[B, C, H, W]``; only its spatial size
                and device are used.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``[(H // p) * (W // p), F]``
            where ``F`` is four times the number of per-axis frequencies.
        """
        # Unpacking keeps the original requirement that the input is 4D.
        _, _, height, width = hidden_states.shape
        height, width = height // self.patch_size, width // self.patch_size
        device = hidden_states.device

        # Rows and columns each get dim // 2 rotary channels, so a single
        # inverse-frequency table serves both axes. It is computed on the
        # default device exactly as before (values stay bit-identical) and
        # moved to the target device once instead of twice.
        axis_dim = self.dim // 2
        exponents = torch.arange(0, axis_dim, 2, dtype=torch.float32)[: (axis_dim // 2)] / axis_dim
        inv_freq = (1.0 / (self.theta**exponents)).to(device)

        freqs_h = torch.outer(torch.arange(height, device=device), inv_freq)
        freqs_w = torch.outer(torch.arange(width, device=device), inv_freq)

        # Broadcast row/column frequencies over the full [height, width] grid.
        freqs_h = freqs_h.unsqueeze(1).expand(height, width, -1)
        freqs_w = freqs_w.unsqueeze(0).expand(height, width, -1)

        # [row freqs | column freqs], then duplicated along the last axis.
        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
        freqs = torch.cat([freqs, freqs], dim=-1)
        freqs = freqs.reshape(height * width, -1)
        return (freqs.cos(), freqs.sin())


class GlmImageAdaLayerNormZero(nn.Module):
    """Adaptive LayerNorm producing modulation terms for image and text streams.

    Args:
        embedding_dim: Size of the conditioning (timestep) embedding.
        dim: Hidden size of both streams.
    """

    def __init__(self, embedding_dim: int, dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        # One projection yields 6 modulation vectors per stream (12 in total).
        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Normalise both streams and apply the attention shift/scale.

        Returns:
            A 10-tuple: the modulated image states, ``gate_msa``,
            ``shift_mlp``, ``scale_mlp``, ``gate_mlp``, followed by the same
            five entries for the text (``c_``-prefixed) stream.
        """
        dtype = hidden_states.dtype
        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)

        emb = self.linear(temb)
        # Image and text ("c_") modulation vectors are interleaved in the projection.
        (
            shift_msa,
            c_shift_msa,
            scale_msa,
            c_scale_msa,
            gate_msa,
            c_gate_msa,
            shift_mlp,
            c_shift_mlp,
            scale_mlp,
            c_scale_mlp,
            gate_mlp,
            c_gate_mlp,
        ) = emb.chunk(12, dim=1)

        hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)

        return (
            hidden_states,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        )
c_gate_mlp, + ) + + +class GlmImageAdaLayerNormContinuous(nn.Module): + """Final AdaLN for output projection (no activation before Linear).""" + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ): + super().__init__() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # NO SiLU here + emb = self.linear(conditioning_embedding.to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class KVCacheMode(Enum): + """Mode for KV cache operations. + + - WRITE: Store the K/V tensors from condition images + - READ: Concatenate cached K/V with current K/V + - SKIP: Do not use cache (pass-through) + """ + + WRITE = "write" + READ = "read" + SKIP = "skip" + + +class GlmImageLayerKVCache: + """KV cache for a single attention layer. + + Stores key and value tensors for image editing. The cache accumulates + KV pairs during write mode and provides them during read mode. + + Shape convention (vllm-omni): + key/value: [batch_size, seq_length, num_heads, head_dim] + """ + + def __init__(self): + self.k_cache: torch.Tensor | None = None + self.v_cache: torch.Tensor | None = None + + def store(self, key: torch.Tensor, value: torch.Tensor) -> None: + """Store or accumulate KV tensors. + + If cache is empty, stores the tensors directly. + If cache is not empty, concatenates new tensors along seq_length dim. 
+ + Args: + key: Key tensor of shape [B, S, H, D] + value: Value tensor of shape [B, S, H, D] + """ + if self.k_cache is None: + self.k_cache = key + self.v_cache = value + else: + # Concatenate along sequence dimension (dim=1 for [B, S, H, D]) + self.k_cache = torch.cat([self.k_cache, key], dim=1) + self.v_cache = torch.cat([self.v_cache, value], dim=1) + + def get(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Get cached KV tensors. + + Returns: + Tuple of (k_cache, v_cache), both may be None if cache is empty. + """ + return self.k_cache, self.v_cache + + def clear(self) -> None: + """Clear the cache.""" + self.k_cache = None + self.v_cache = None + + @property + def is_empty(self) -> bool: + """Check if cache is empty.""" + return self.k_cache is None + + def __repr__(self) -> str: + if self.is_empty: + return "GlmImageLayerKVCache(empty)" + return f"GlmImageLayerKVCache(k_shape={self.k_cache.shape}, v_shape={self.v_cache.shape})" + + +class GlmImageKVCache: + """Container for all layers' KV caches. + + Manages KV cache for all transformer layers in GLM-Image model. + Provides a unified interface for setting mode and clearing cache. + + Args: + num_layers: Number of transformer layers in the model. + + Example: + kv_cache = GlmImageKVCache(num_layers=28) + kv_cache.set_mode(KVCacheMode.WRITE) + # ... process condition image ... + kv_cache.set_mode(KVCacheMode.READ) + # ... process target image ... + kv_cache.clear() + """ + + def __init__(self, num_layers: int): + self.num_layers = num_layers + self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)] + self._mode: KVCacheMode | None = None + + def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache: + """Get cache for a specific layer. + + Args: + layer_idx: Index of the layer (0-indexed). + + Returns: + GlmImageLayerKVCache for the specified layer. + + Raises: + IndexError: If layer_idx is out of range. 
+ """ + if layer_idx < 0 or layer_idx >= self.num_layers: + raise IndexError(f"Layer index {layer_idx} out of range [0, {self.num_layers})") + return self.caches[layer_idx] + + def __len__(self) -> int: + """Return number of layers.""" + return self.num_layers + + @property + def mode(self) -> KVCacheMode | None: + """Get current cache mode.""" + return self._mode + + def set_mode(self, mode: KVCacheMode | str | None) -> None: + """Set cache mode for all layers. + + Args: + mode: Cache mode (WRITE, READ, SKIP) or string ("write", "read", "skip"). + Use None to disable cache operations. + + Raises: + ValueError: If mode is an invalid string. + """ + if mode is None: + self._mode = None + elif isinstance(mode, str): + try: + self._mode = KVCacheMode(mode.lower()) + except ValueError: + raise ValueError(f"Invalid mode: '{mode}', must be one of 'write', 'read', 'skip'") + else: + self._mode = mode + + def clear(self) -> None: + """Clear cache for all layers and reset mode.""" + for cache in self.caches: + cache.clear() + self._mode = None + + @property + def is_empty(self) -> bool: + """Check if all layer caches are empty.""" + return all(cache.is_empty for cache in self.caches) + + def __repr__(self) -> str: + mode_str = self._mode.value if self._mode else "None" + return f"GlmImageKVCache(num_layers={self.num_layers}, mode={mode_str}, is_empty={self.is_empty})" + + +class GlmImageAttention(nn.Module): + """ + Joint attention for GLM-Image model using vllm-omni's optimized attention. + + This combines text and image streams for joint attention computation. + Supports KV caching for image editing workflows via external cache. 
class GlmImageAttention(nn.Module):
    """
    Joint attention for GLM-Image model using vllm-omni's optimized attention.

    This combines text and image streams for joint attention computation.
    Supports KV caching for image editing workflows via external cache.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        head_dim: int,
        out_bias: bool = True,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim

        # QKV projection (fused for efficiency)
        self.to_qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            disable_tp=True,
            bias=True,
        )

        # QK normalization (LayerNorm, not RMSNorm for GLM-Image)
        self.norm_q = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps)
        self.norm_k = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps)

        # Output projection
        self.to_out = nn.Sequential(
            nn.Linear(self.inner_dim, dim, bias=out_bias),
            nn.Dropout(0.0),
        )

        # RoPE and attention
        # NOTE(review): `self.rope` is not used by `forward` (RoPE is applied
        # with diffusers' `apply_rotary_emb`); kept so the module layout is
        # unchanged -- confirm whether it can be dropped.
        self.rope = RotaryEmbedding(is_neox_style=False)
        self.attn = Attention(
            num_heads=num_heads,
            head_size=head_dim,
            softmax_scale=1.0 / (head_dim**0.5),
            causal=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        kv_cache: GlmImageLayerKVCache | None = None,
        kv_cache_mode: KVCacheMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for joint attention.

        Args:
            hidden_states: Image hidden states [B, img_seq_len, D]
            encoder_hidden_states: Text hidden states [B, text_seq_len, D]
            image_rotary_emb: Tuple of (cos, sin) for RoPE
            attention_mask: Accepted for interface compatibility; it is not
                used by this implementation.
            kv_cache: Optional layer KV cache for image editing
            kv_cache_mode: Cache mode (WRITE, READ, SKIP)

        Returns:
            Tuple of (image_hidden_states, text_hidden_states)
        """
        dtype = encoder_hidden_states.dtype
        # Unpacking keeps the requirement that the text states are 3D.
        _, text_seq_length, _ = encoder_hidden_states.shape

        # Concatenate text and image: [text, image]
        hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # QKV projection
        qkv, _ = self.to_qkv(hidden_states_combined)
        query, key, value = qkv.chunk(3, dim=-1)

        # Reshape: [B, S, H*D] -> [B, S, H, D]
        query = query.unflatten(-1, (self.num_heads, -1))
        key = key.unflatten(-1, (self.num_heads, -1))
        value = value.unflatten(-1, (self.num_heads, -1))

        # QK normalization
        query = self.norm_q(query).to(dtype=dtype)
        key = self.norm_k(key).to(dtype=dtype)

        # Apply RoPE only to image tokens (everything after the text tokens)
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query_img = apply_rotary_emb(
                query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
            )
            key_img = apply_rotary_emb(
                key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
            )
            query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
            key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)

        # Handle KV cache for image editing
        if kv_cache is not None and kv_cache_mode is not None:
            if kv_cache_mode == KVCacheMode.WRITE:
                kv_cache.store(key, value)
            elif kv_cache_mode == KVCacheMode.READ:
                k_cached, v_cached = kv_cache.get()
                if k_cached is not None:
                    # Cached tokens are prepended; queries are left as they are.
                    key = torch.cat([k_cached, key], dim=1)
                    value = torch.cat([v_cached, value], dim=1)
            # KVCacheMode.SKIP: do nothing

        # Attention computation
        hidden_states_out = self.attn(query, key, value)
        hidden_states_out = hidden_states_out.flatten(2, 3)
        hidden_states_out = hidden_states_out.to(dtype)

        # Output projection
        hidden_states_out = self.to_out(hidden_states_out)

        # Split back to text and image
        encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
        hidden_states_out = hidden_states_out[:, text_seq_length:, :]

        return hidden_states_out, encoder_hidden_states_out


class GlmImageTransformerBlock(nn.Module):
    """Single transformer block for GLM-Image."""

    def __init__(
        self,
        dim: int = 2560,
        num_attention_heads: int = 64,
        attention_head_dim: int = 40,
        time_embed_dim: int = 512,
    ) -> None:
        super().__init__()

        # 1. Attention with AdaLN
        self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
        self.attn1 = GlmImageAttention(
            dim=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
        )

        # 2. Feedforward (one FeedForward module serves both streams)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        kv_cache: GlmImageLayerKVCache | None = None,
        kv_cache_mode: KVCacheMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for transformer block.

        Args:
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states
            temb: Timestep embedding
            image_rotary_emb: RoPE embeddings
            attention_mask: Text attention mask (forwarded to the attention layer)
            attention_kwargs: Accepted for interface compatibility; not used here.
            kv_cache: Layer-specific KV cache for image editing
            kv_cache_mode: Cache mode (WRITE, READ, SKIP)

        Returns:
            Tuple of (image_hidden_states, text_hidden_states)
        """
        # 1. Timestep conditioning via AdaLN
        (
            norm_hidden_states,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            norm_encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.norm1(hidden_states, encoder_hidden_states, temb)

        # 2. Attention, added back through per-stream gates
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            kv_cache_mode=kv_cache_mode,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)

        # 3. Feedforward
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
            1 + c_scale_mlp.unsqueeze(1)
        ) + c_shift_mlp.unsqueeze(1)

        ff_output = self.ff(norm_hidden_states)
        ff_output_context = self.ff(norm_encoder_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)

        return hidden_states, encoder_hidden_states
The + transformer hyper-parameters (e.g. patch size / channels / heads) + are read from `od_config.tf_model_config`. + """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + ): + super().__init__() + + patch_size = od_config.tf_model_config.patch_size + in_channels = od_config.tf_model_config.in_channels + out_channels = od_config.tf_model_config.out_channels + num_attention_heads = od_config.tf_model_config.num_attention_heads + attention_head_dim = od_config.tf_model_config.attention_head_dim + time_embed_dim = od_config.tf_model_config.time_embed_dim + condition_dim = od_config.tf_model_config.condition_dim + prior_vq_quantizer_codebook_size = od_config.tf_model_config.prior_vq_quantizer_codebook_size + text_embed_dim = od_config.tf_model_config.text_embed_dim + + + + # Get num_layers from config if available + model_config = od_config.tf_model_config + if model_config is not None and hasattr(model_config, "num_layers"): + num_layers = model_config.num_layers + + self.od_config = od_config + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. RoPE + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. 
Patch & Text-timestep embedding + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. Output projection + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + kv_cache: GlmImageKVCache | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + Forward pass of the GLM-Image Transformer. + + Args: + hidden_states: Input latent tensor of shape [B, C, H, W]. + encoder_hidden_states: Text embeddings of shape [B, S, D]. + prior_token_id: Prior VQ token IDs. + prior_token_drop: Mask for dropping prior tokens (CFG). + timestep: Diffusion timestep. + target_size: Target image size for conditioning. 
+ crop_coords: Crop coordinates for conditioning. + attention_kwargs: Additional attention arguments. + return_dict: Whether to return a dataclass. + attention_mask: Optional attention mask for text tokens. + image_rotary_emb: Pre-computed rotary embeddings. + kv_cache: Optional KV cache for image editing. When provided, + the cache's mode determines behavior: + - WRITE: Store KV from condition images + - READ: Use cached KV during generation + - SKIP: No caching (same as None) + + Returns: + Output tensor or Transformer2DModelOutput. + """ + batch_size, num_channels, height, width = hidden_states.shape + + # Get KV cache mode + kv_cache_mode = kv_cache.mode if kv_cache is not None else None + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + # Move to correct device + image_rotary_emb = ( + image_rotary_emb[0].to(hidden_states.device), + image_rotary_emb[1].to(hidden_states.device), + ) + + # 2. Patch & Timestep embeddings + p = self.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) + + # Prior embedding with dropout + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + hidden_states = hidden_states + prior_hidden_states + + # Timestep conditioning + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + + # 3. 
Transformer blocks + for layer_idx, block in enumerate(self.transformer_blocks): + # Get layer-specific KV cache if available + layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None + + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_cache=layer_kv_cache, + kv_cache_mode=kv_cache_mode, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from pretrained checkpoint. + + This method handles the mapping from diffusers weight names to vllm-omni weight names, + especially for fused QKV projections. 
+ """ + stacked_params_mapping = [ + # Fused QKV projection: to_q, to_k, to_v -> to_qkv + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # Also include buffers (for any beta/eps parameters) + for name, buffer in self.named_buffers(): + params_dict[name] = buffer + + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Handle fused QKV projections + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + # Map diffusers name to vllm-omni name + name = name.replace(weight_name, param_name) + + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + break + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Standard weight loading (not fused) + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + def create_kv_cache(self) -> GlmImageKVCache: + """ + Create a KV cache for image editing. + + Returns a new GlmImageKVCache instance sized for this model's + number of transformer layers. Use this for image editing workflows. + + Example: + kv_cache = transformer.create_kv_cache() + kv_cache.set_mode("write") + transformer(condition_image, kv_cache=kv_cache) + kv_cache.set_mode("read") + for t in timesteps: + transformer(noisy_target, kv_cache=kv_cache) + kv_cache.clear() + + Returns: + GlmImageKVCache instance with correct number of layers. 
+ """ + return GlmImageKVCache(num_layers=len(self.transformer_blocks)) + + @property + def num_layers(self) -> int: + """Return number of transformer layers.""" + return len(self.transformer_blocks) + + @property + def dtype(self) -> torch.dtype: + """Return dtype of model parameters.""" + return next(self.parameters()).dtype diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py new file mode 100644 index 00000000000..f582c3b9b69 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -0,0 +1,965 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GlmImagePipeline implementation for vLLM-Omni. + +This pipeline implements GLM-Image text-to-image generation with: +- AR stage: GlmImageForConditionalGeneration generates prior tokens +- DiT stage: GlmImageTransformer2DModel performs diffusion denoising +- VAE: AutoencoderKL decodes latents to images +""" + +from __future__ import annotations + +import inspect +import json +import logging +import os +import re +from collections.abc import Iterable +from math import sqrt + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import ( + ByT5Tokenizer, + GlmImageForConditionalGeneration, + GlmImageProcessor, + T5EncoderModel, +) + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import 
get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageKVCache, + GlmImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): + """Get post-processing function for GLM-Image pipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + block_out_channels = vae_config.get("block_out_channels", [128, 256, 512, 512]) + vae_scale_factor = 2 ** (len(block_out_channels) - 1) + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func(images: PIL.Image.Image): + return images + + return post_process_func + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + """Calculate timestep shift based on image sequence length.""" + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps. + Handles custom timesteps and sigmas schedules. 
+ """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + # Both provided - check if scheduler supports both + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigma schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +) -> torch.Tensor: + """Extract latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(nn.Module): + """ + GLM-Image Pipeline for text-to-image and image-to-image generation. + + This pipeline integrates: + - AR model (GlmImageForConditionalGeneration): Generates prior image tokens + - Text encoder (T5EncoderModel): Encodes glyph/text embeddings + - DiT model (GlmImageTransformer2DModel): Diffusion transformer + - VAE (AutoencoderKL): Encodes/decodes images to/from latent space + + The pipeline flow: + 1. AR generates prior_token_ids from text prompt + 2. T5 encodes glyph text for text rendering + 3. DiT performs iterative denoising conditioned on prior tokens + 4. 
def __init__(
    self,
    *,
    od_config: OmniDiffusionConfig,
    prefix: str = "",
):
    """Load every pipeline component from the model directory.

    Args:
        od_config: Diffusion configuration; `model`, `revision` and
            `parallel_config` are read here.
        prefix: Accepted but not used by this constructor.
    """
    super().__init__()
    self.od_config = od_config
    self.parallel_config = od_config.parallel_config
    self.device = get_local_device()

    # Use the path as-is when it exists locally, otherwise fetch the weights first.
    model = od_config.model
    model_path = (
        model if os.path.exists(model) else download_weights_from_hf_specific(model, od_config.revision, ["*"])
    )

    # After the step above everything is on disk, so all loads are local-only.
    bf16_local = {"local_files_only": True, "torch_dtype": torch.bfloat16}

    # Scheduler
    self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_path, subfolder="scheduler", local_files_only=True
    )

    # AR model (vision_language_encoder) and its processor
    logger.info("Loading GlmImageForConditionalGeneration (AR model)...")
    ar_model = GlmImageForConditionalGeneration.from_pretrained(
        model_path, subfolder="vision_language_encoder", **bf16_local
    )
    self.vision_language_encoder = ar_model.to(self.device)
    self.vision_language_encoder.eval()
    self.processor = GlmImageProcessor.from_pretrained(model_path, subfolder="processor", local_files_only=True)

    # T5 text encoder (glyph embeddings) and its tokenizer
    logger.info("Loading T5EncoderModel (glyph encoder)...")
    glyph_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", **bf16_local)
    self.text_encoder = glyph_encoder.to(self.device)
    self.text_encoder.eval()
    self.tokenizer = ByT5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", local_files_only=True)

    # VAE
    logger.info("Loading AutoencoderKL (VAE)...")
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", **bf16_local)
    self.vae = vae.to(self.device)
    self.vae.eval()

    # DiT: only constructed here; its weights are described by `weights_sources`.
    logger.info("Loading GlmImageTransformer2DModel (DiT)...")
    self.transformer = GlmImageTransformer2DModel(od_config=od_config)
    transformer_source = DiffusersPipelineLoader.ComponentSource(
        model_or_path=od_config.model,
        subfolder="transformer",
        revision=od_config.revision,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
    self.weights_sources = [transformer_source]

    # Scale factors derived from the VAE depth
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = 128

    # Patch size of the transformer (falls back to 2 if the attribute is absent)
    self._patch_size = getattr(self.transformer, "patch_size", 2)
Cannot leave both undefined.") + + # Check prompt type + if prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` must be of type `str` or `list` but is {type(prompt)}") + + # ==================== AR Stage Methods ==================== + + @staticmethod + def _build_image_grid_thw( + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + existing_grid: torch.Tensor | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + """Build image grid tensor for AR model.""" + if existing_grid is None or existing_grid.numel() == 0: + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + device=device, + ) + else: + return torch.cat( + [existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], + dim=0, + ) + + @staticmethod + def _calculate_ar_generation_params( + token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool + ) -> tuple[int, int]: + """Calculate AR generation parameters.""" + large_image_tokens = token_h * token_w + small_image_tokens = prev_token_h * prev_token_w + + if is_text_to_image: + max_new_tokens = small_image_tokens + large_image_tokens + 1 + large_image_start_offset = small_image_tokens + else: + max_new_tokens = large_image_tokens + 1 + large_image_start_offset = 0 + + return max_new_tokens, large_image_start_offset + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + """Extract large image tokens from AR output.""" + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + """Upsample token IDs by 2x 
@torch.inference_mode()
def generate_prior_tokens(
    self,
    prompt: str,
    height: int,
    width: int,
    image: list[PIL.Image.Image] | None = None,
    factor: int = 32,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Generate prior tokens using the AR model.

    Args:
        prompt: Text prompt for generation
        height: Target image height (rounded down to a multiple of `factor`)
        width: Target image width (rounded down to a multiple of `factor`)
        image: Optional condition images for image-to-image
        factor: Token factor (default 32)

    Returns:
        Tuple of (prior_token_ids, prior_token_image_ids).
        `prior_token_image_ids` is None unless condition images were given
        and the processor produced an `image_grid_thw` for them.
    """
    device = self.vision_language_encoder.device
    height = (height // factor) * factor
    width = (width // factor) * factor
    is_text_to_image = image is None or len(image) == 0

    # Forward `factor` so the token grid agrees with the rounding applied above
    # (the helper would otherwise always assume its own default).
    expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
        prompt, height, width, is_text_to_image, factor=factor
    )

    # Build message content: condition images first, then the expanded prompt
    content = []
    if image is not None:
        for img in image:
            content.append({"type": "image", "image": img})
    content.append({"type": "text", "text": expanded_prompt})
    messages = [{"role": "user", "content": content}]

    # Apply chat template
    inputs = self.processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Build image grid; keep the processor's grid for the condition-image pass below
    existing_grid = inputs.get("image_grid_thw")
    inputs["image_grid_thw"] = self._build_image_grid_thw(
        token_h,
        token_w,
        prev_h,
        prev_w,
        existing_grid=existing_grid if not is_text_to_image else None,
        device=device,
    )

    max_new_tokens, large_image_offset = self._calculate_ar_generation_params(
        token_h, token_w, prev_h, prev_w, is_text_to_image
    )
    large_image_tokens = token_h * token_w

    inputs = inputs.to(device)
    input_length = inputs["input_ids"].shape[-1]

    # Process condition images if provided
    prior_token_image_ids = None
    if image is not None and existing_grid is not None:
        prior_token_image_embed = self.vision_language_encoder.get_image_features(
            inputs["pixel_values"], existing_grid
        )
        prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
        prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
            prior_token_image_embed, existing_grid
        )

    # Generate with AR model
    outputs = self.vision_language_encoder.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Extract and upsample tokens
    prior_token_ids_d32 = self._extract_large_image_tokens(
        outputs, input_length, large_image_offset, large_image_tokens
    )
    prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)

    return prior_token_ids, prior_token_image_ids
glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 2048, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Encode prompt into glyph embeddings for text rendering.""" + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = [""] * batch_size + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # ==================== Latent Preparation ==================== + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare random noise latents.""" + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) 
def diffuse(
    self,
    latents: torch.Tensor,
    prior_token_id: torch.Tensor,
    prompt_embeds: torch.Tensor,
    negative_prompt_embeds: torch.Tensor | None,
    timesteps: torch.Tensor,
    target_size: torch.Tensor,
    crop_coords: torch.Tensor,
    guidance_scale: float,
    do_classifier_free_guidance: bool,
    kv_caches: GlmImageKVCache | None = None,
) -> torch.Tensor:
    """
    Denoising loop for diffusion process with CFG-Parallel support.

    Args:
        latents: Initial noise latents
        prior_token_id: Prior tokens generated by AR model
        prompt_embeds: Encoded positive prompt embeddings (glyph embeddings)
        negative_prompt_embeds: Encoded negative prompt embeddings
        timesteps: Denoising timesteps
        target_size: Target image size tensor [[height, width]]
        crop_coords: Crop coordinates tensor
        guidance_scale: CFG scale
        do_classifier_free_guidance: Whether to apply CFG
        kv_caches: Optional KV cache for Image Edit mode

    Returns:
        Denoised latents ready for VAE decode
    """
    # Prepare conditional/unconditional drop flags
    prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool)
    prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool)

    transformer_dtype = self.transformer.dtype

    # Enable CFG-parallel: rank0 computes positive, rank1 computes negative
    cfg_parallel_ready = do_classifier_free_guidance and get_classifier_free_guidance_world_size() > 1

    def predict_noise(latent_model_input, timestep, encoder_hidden_states, prior_token_drop):
        # Single transformer call shared by every branch. The transformer's
        # keyword is `kv_cache` (singular); `kv_caches=` is not a parameter
        # of its forward() and would be rejected.
        return self.transformer(
            hidden_states=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            prior_token_id=prior_token_id,
            prior_token_drop=prior_token_drop,
            timestep=timestep,
            target_size=target_size,
            crop_coords=crop_coords,
            kv_cache=kv_caches,
            return_dict=False,
        )[0].float()

    for t in timesteps:
        latent_model_input = latents.to(transformer_dtype)
        timestep = t.expand(latents.shape[0]) - 1

        if cfg_parallel_ready:
            cfg_group = get_cfg_group()
            cfg_rank = get_classifier_free_guidance_rank()

            if cfg_rank == 0:
                # Rank 0: Compute positive (conditional) prediction
                local_pred = predict_noise(latent_model_input, timestep, prompt_embeds, prior_token_drop_cond)
            else:
                # Rank 1: Compute negative (unconditional) prediction
                local_pred = predict_noise(
                    latent_model_input, timestep, negative_prompt_embeds, prior_token_drop_uncond
                )

            # All-gather predictions from all ranks
            gathered = cfg_group.all_gather(local_pred, separate_tensors=True)

            if cfg_rank == 0:
                # Rank 0: Combine predictions and apply CFG
                noise_pred_cond = gathered[0]
                noise_pred_uncond = gathered[1]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                # Scheduler step
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Broadcast updated latents to all ranks
            # NOTE(review): the return value is ignored, so this relies on
            # `broadcast` filling `latents` in place on non-source ranks - confirm.
            cfg_group.broadcast(latents, src=0)

        else:
            # Sequential CFG (single GPU or no CFG)
            noise_pred_cond = predict_noise(latent_model_input, timestep, prompt_embeds, prior_token_drop_cond)

            if do_classifier_free_guidance:
                noise_pred_uncond = predict_noise(
                    latent_model_input, timestep, negative_prompt_embeds, prior_token_drop_uncond
                )
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = noise_pred_cond

            # Scheduler step
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    return latents


# ==================== Main Forward Pass ====================


def _prepare_condition_image_kv_cache(
    self,
    condition_images: list[torch.Tensor],
    prior_token_image_ids: list[torch.Tensor],
    prompt_embeds: torch.Tensor,
    generator: torch.Generator | None = None,
) -> GlmImageKVCache:
    """
    Prepare KV cache by running condition images through transformer at timestep 0.

    This is used for Image Edit mode: the cache is created in "write" mode
    and each condition image is pushed through the transformer once so that
    its KV states are stored.

    Args:
        condition_images: List of preprocessed condition images
        prior_token_image_ids: Prior token IDs for each condition image from AR model
        prompt_embeds: Prompt embeddings (only its dtype and trailing shape are used)
        generator: Optional random generator

    Returns:
        GlmImageKVCache with cached KV states from condition images
    """
    kv_caches = self.transformer.create_kv_cache()
    kv_caches.set_mode("write")

    # Prepare VAE normalization parameters
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean)
        .view(1, self.vae.config.latent_channels, 1, 1)
        .to(device=self.device, dtype=prompt_embeds.dtype)
    )
    latents_std = (
        torch.tensor(self.vae.config.latents_std)
        .view(1, self.vae.config.latent_channels, 1, 1)
        .to(device=self.device, dtype=prompt_embeds.dtype)
    )

    # Process each condition image through transformer to populate KV cache
    for condition_image, condition_prior_token_id in zip(condition_images, prior_token_image_ids):
        condition_image = condition_image.to(device=self.device, dtype=prompt_embeds.dtype)

        # Encode condition image to latent space
        # Use argmax (mode) for deterministic encoding of condition images
        condition_latent = retrieve_latents(
            self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
        )
        condition_latent = (condition_latent - latents_mean) / latents_std

        # Run forward pass at timestep 0 to cache KV states
        # Empty encoder_hidden_states since we only want to cache image features
        # The transformer's keyword is `kv_cache` (singular); passing
        # `kv_caches=` would be rejected by its forward().
        _ = self.transformer(
            hidden_states=condition_latent,
            encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
            prior_token_id=condition_prior_token_id,
            prior_token_drop=torch.full_like(condition_prior_token_id, False, dtype=torch.bool),
            timestep=torch.zeros((1,), device=self.device),
            target_size=torch.tensor([condition_image.shape[-2:]], device=self.device, dtype=prompt_embeds.dtype),
            crop_coords=torch.zeros((1, 2), device=self.device, dtype=prompt_embeds.dtype),
            kv_cache=kv_caches,
            return_dict=False,
        )

    return kv_caches
num_inference_steps = req.num_inference_steps or 50 + guidance_scale = req.guidance_scale or 1.5 + + # 0. Validate inputs + self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) + + batch_size = 1 + do_classifier_free_guidance = guidance_scale > 1.0 + + # Set seed if provided + generator = None + if req.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.seed) + + # 1. Generate prior tokens with AR model + logger.info("Generating prior tokens with AR model...") + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt, + image=condition_images, + height=height, + width=width, + ) + + # 2. Encode prompt for glyph embeddings + logger.info("Encoding prompt...") + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=1, + prompt_embeds=prompt_embeds, + device=self.device, + dtype=self.transformer.dtype, + ) + + # 3. Prepare KV cache for Image Edit mode + kv_caches = None + if is_image_edit and prior_token_image_ids is not None: + logger.info("Preparing KV cache for Image Edit mode...") + kv_caches = self._prepare_condition_image_kv_cache( + condition_images=preprocessed_images, + prior_token_image_ids=prior_token_image_ids, + prompt_embeds=prompt_embeds, + generator=generator, + ) + # Switch to read mode for denoising + kv_caches.set_mode("read") + + # 4. Prepare latents + latent_channels = self.transformer.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=self.device, + generator=generator, + ) + + # 5. 
Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (self._patch_size**2) + timesteps_array = np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + timesteps_array = timesteps_array.astype(np.int64).astype(np.float32) + sigmas = timesteps_array / self.scheduler.config.num_train_timesteps + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, self.device, timesteps_array.tolist(), sigmas.tolist(), mu=mu + ) + + # 6. Prepare conditioning tensors + target_size = torch.tensor([[height, width]], dtype=prompt_embeds.dtype, device=self.device) + crop_coords = torch.zeros((1, 2), dtype=prompt_embeds.dtype, device=self.device) + + # 7. Denoising loop with CFG-parallel support + logger.info(f"Starting denoising loop with {num_inference_steps} steps...") + latents = self.diffuse( + latents=latents, + prior_token_id=prior_token_id, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + timesteps=timesteps, + target_size=target_size, + crop_coords=crop_coords, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + kv_caches=kv_caches, + ) + + # 8. VAE decode + logger.info("Decoding latents with VAE...") + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + + # 9. 
Post-process + image = self.image_processor.postprocess(image, output_type="pil")[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load transformer weights.""" + # Filter weights for transformer only + transformer_weights = ( + (name.replace("transformer.", ""), weight) for name, weight in weights if name.startswith("transformer.") + ) + return self.transformer.load_weights(transformer_weights) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index c49ba0a3cd9..e566ca66cfa 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -29,6 +29,11 @@ "pipeline_qwen_image_layered", "QwenImageLayeredPipeline", ), + "GlmImagePipeline": ( + "glm_image", + "pipeline_glm_image", + "GlmImagePipeline", + ), "ZImagePipeline": ( "z_image", "pipeline_z_image", @@ -112,6 +117,7 @@ def initialize_model( "QwenImagePipeline": "get_qwen_image_post_process_func", "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func", "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_post_process_func", + "GlmImagePipeline": "get_glm_image_post_process_func", "ZImagePipeline": "get_post_process_func", "OvisImagePipeline": "get_ovis_image_post_process_func", "WanPipeline": "get_wan22_post_process_func", diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 6f22c6165b2..714b8dfcc53 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -37,65 +37,6 @@ def __init__( self.mm_type: str | None = None self.mm_accumulated: Dict[str, Any] | None = None - @classmethod - def from_new_request( - cls, - tokenizer: TokenizerLike, - request: EngineCoreRequest, - prompt: str | None, - parent_req: ParentRequest | None, - request_index: int, - queue: Any | None, - log_stats: bool, - stream_interval: int, - ) -> "OmniRequestState": - if sampling_params := request.sampling_params: - if 
not sampling_params.detokenize: - tokenizer = None - output_kind = sampling_params.output_kind - logprobs_processor = LogprobsProcessor.from_new_request( - tokenizer=tokenizer, - request=request, - ) - detokenizer = IncrementalDetokenizer.from_new_request( - tokenizer=tokenizer, - request=request, - ) - max_tokens_param = sampling_params.max_tokens - top_p = sampling_params.top_p - n = sampling_params.n - temperature = sampling_params.temperature - else: - logprobs_processor = None - detokenizer = None - max_tokens_param = None - top_p = None - n = None - temperature = None - assert request.pooling_params is not None - output_kind = request.pooling_params.output_kind - - return cls( - request_id=request.request_id, - parent_req=parent_req, - request_index=request_index, - lora_name=(request.lora_request.name if request.lora_request is not None else None), - output_kind=output_kind, - prompt=prompt, - prompt_token_ids=request.prompt_token_ids, - prompt_embeds=request.prompt_embeds, - logprobs_processor=logprobs_processor, - detokenizer=detokenizer, - max_tokens_param=max_tokens_param, - top_p=top_p, - n=n, - temperature=temperature, - arrival_time=request.arrival_time, - queue=queue, - log_stats=log_stats, - stream_interval=stream_interval, - ) - def add_multimodal_tensor(self, payload: Any | None, mm_type: str | None) -> None: if payload is None: return diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py index 567af03770f..287f12b9ed7 100644 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -10,7 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from 
vllm.usage.usage_lib import UsageContext @@ -111,7 +111,7 @@ def __init__( tokenizer = None else: # Tokenizer (+ ensure liveness if running in another process). - tokenizer = init_tokenizer_from_config(model_config=vllm_config.model_config) + tokenizer = cached_tokenizer_from_config(model_config=vllm_config.model_config) # InputProcessor (converts Inputs --> EngineCoreRequests). self.input_processor = OmniInputProcessor( diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 05a48feee0e..74fe6a80376 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -2,7 +2,9 @@ import cloudpickle from pydantic import ValidationError - +from tqdm import tqdm +from vllm.outputs import RequestOutput, PoolingRequestOutput +from typing import Callable # External library imports (vLLM) from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field from vllm.entrypoints.llm import LLM @@ -190,3 +192,54 @@ def __del__(self) -> None: # best-effort self.close() except Exception as e: logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) + + def _run_engine( + self, *, use_tqdm: bool | Callable[..., tqdm] = True + ) -> list[RequestOutput | PoolingRequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), + ) + + # Run the engine. 
+ outputs: list[RequestOutput | PoolingRequestOutput] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs + ) + out_spd = total_out_toks / pbar.format_dict["elapsed"] + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s" + ) + pbar.update(n) + else: + pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() + + if use_tqdm: + pbar.close() + # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
+ return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 7675caa638e..c459289af34 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -154,6 +154,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # for CI: Initialize special tokens embeddings early to avoid AttributeError when loading dummy weights self._init_special_tokens_embeddings() + self.requires_raw_input_tokens = True elif self.model_stage == "code2wav": self.thinker = None @@ -168,6 +169,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): architectures=["Qwen3OmniMoeCode2Wav"], ) self.model = self.code2wav + self.requires_raw_input_tokens = True else: raise ValueError( f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'talker', 'code2wav'" diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py index 361a9349b25..7f3320a82eb 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -684,13 +684,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) - attn_backend_override = multimodal_config.mm_encoder_attn_backend if multimodal_config is not None else None self.visual = Qwen3Omni_VisionTransformer( vision_config=thinker_config.vision_config, norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) self.quant_config = quant_config diff --git 
a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 2d2b7ef8e2d..d4e7e195fe8 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -26,11 +26,18 @@ get_pp_group, get_tp_group, has_kv_transfer_group, + ) from vllm.v1.worker.utils import is_residual_scattered_for_sp - +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner +from vllm.v1.outputs import make_empty_encoder_model_runner_output +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices logger = init_logger(__name__) @@ -44,6 +51,7 @@ class ExecuteModelState(NamedTuple): sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None ec_connector_output: Any + cudagraph_stats: Any multimodal_outputs: Any @@ -82,67 +90,148 @@ def execute_model( scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors | None = None, ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: - with record_function_or_nullcontext("Preprocess"): - with self.synchronize_input_prep(): - self._update_states(scheduler_output) - self._decode_and_store_request_payloads(scheduler_output) - - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, self.vllm_config) - if self.cache_config.kv_sharing_fast_prefill: - assert not self.input_batch.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect " - "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs" - ) + if 
self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) - num_reqs = self.input_batch.num_reqs - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions( + scheduler_output.preempted_req_ids + ) - logits_indices, spec_decode_metadata = self._prepare_inputs( + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + # Update persistent batch states. + self._update_states(scheduler_output) + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( scheduler_output, - num_scheduled_tokens_np, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. 
before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" ) - ( - cudagraph_mode, - batch_desc, - ubatch_slices, - num_tokens_across_dp, - ) = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens_np, - max_num_scheduled_tokens=max_num_scheduled_tokens, - use_cascade_attn=False, + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, ) - num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - 
pad_attn = cudagraph_mode == CUDAGraphMode.FULL + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) - ( - attn_metadata, - spec_decode_common_attn_metadata, - ) = self._build_attention_metadata( + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded if pad_attn else None, max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_attn, logits_indices=logits_indices, use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=None, + 
cascade_attn_prefix_lens=cascade_attn_prefix_lens, ) + ) ( input_ids, @@ -152,15 +241,19 @@ def execute_model( model_kwargs, ec_connector_output, ) = self._preprocess( - scheduler_output, - num_tokens_padded, - intermediate_tensors, + scheduler_output, num_tokens_padded, intermediate_tensors ) + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. if self.calculate_kv_scales: cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False + # Run the model. + # Use persistent buffers for CUDA graphs. with ( set_forward_context( attn_metadata, @@ -169,9 +262,9 @@ def execute_model( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -208,30 +301,37 @@ def execute_model( logger.debug("[AR] execute_model: multimodal_outputs is None") if not self.broadcast_pp_output: + # Common case. if not get_pp_group().is_last_rank: + # Return the intermediate tensors. assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output return hidden_states if self.is_pooling_model: - output = self._pool( + # Return the pooling output. 
+ return self._pool( hidden_states, - num_tokens_padded, + num_scheduled_tokens, num_scheduled_tokens_np, + kv_connector_output, ) - output.kv_connector_output = kv_connector_output - return output sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits( sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata ) else: + # Rare case. assert not self.is_pooling_model + sample_hidden_states = hidden_states[logits_indices] if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": not is_residual_scattered_for_sp(self.vllm_config, num_tokens_padded) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_tokens_padded + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, @@ -240,7 +340,6 @@ def execute_model( ) logits = None else: - sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits( sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata ) @@ -264,6 +363,7 @@ def execute_model( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, multimodal_outputs, ) self.kv_connector_output = kv_connector_output @@ -278,14 +378,20 @@ def sample_tokens( self.kv_connector_output = None if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. if not kv_connector_output: return None # type: ignore[return-value] + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output if kv_connector_output.is_empty(): return EMPTY_MODEL_RUNNER_OUTPUT + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = kv_connector_output return output + # Unpack ephemeral state. ( scheduler_output, logits, @@ -295,16 +401,22 @@ def sample_tokens( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, multimodal_outputs, ) = self.execute_model_state self.execute_model_state = None + # Apply structured output bitmasks if present. 
if grammar_output is not None: - apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits) + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self._draft_token_ids = None + self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -320,39 +432,44 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, ) + self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config - use_padded_batch_for_eagle = ( - spec_config is not None and spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch - ) - effective_drafter_max_model_len = self.max_model_len - if effective_drafter_max_model_len is None: - effective_drafter_max_model_len = self.model_config.max_model_len - if ( - spec_config is not None - and spec_config.draft_model_config is not None - and spec_config.draft_model_config.max_model_len is not None - ): - effective_drafter_max_model_len = spec_config.draft_model_config.max_model_len - input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= effective_drafter_max_model_len - ) - if use_padded_batch_for_eagle: - assert self.speculative_config is not None - assert isinstance(self.drafter, EagleProposer) - sampled_token_ids = sampler_output.sampled_token_ids - if input_fits_in_drafter: - propose_draft_token_ids(sampled_token_ids) - elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None - next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - 
self.discard_request_mask.gpu, - ) - self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) + propose_drafts_after_bookkeeping = False + if spec_config is not None: + input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + <= self.effective_drafter_max_model_len + ) + if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + assert isinstance(self.drafter, EagleProposer) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + # Since we couldn't run the drafter, + # just use zeros for the draft tokens. + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + else: + propose_drafts_after_bookkeeping = input_fits_in_drafter with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -372,7 +489,9 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) - if self.speculative_config and not use_padded_batch_for_eagle and input_fits_in_drafter: + if propose_drafts_after_bookkeeping: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
propose_draft_token_ids(valid_sampled_token_ids) with record_function_or_nullcontext("gpu_model_runner: eplb"): @@ -421,6 +540,12 @@ def propose_draft_token_ids(sampled_token_ids): payload.update(mm_payload) pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + if self.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.slot_mapping) # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") output = OmniModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, @@ -431,6 +556,7 @@ def propose_draft_token_ids(sampled_token_ids): kv_connector_output=kv_connector_output, ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, ) if not self.use_async_scheduling: diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py index 9e058addb6e..599dea31f2f 100644 --- a/vllm_omni/worker/gpu_ar_worker.py +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -3,19 +3,21 @@ import torch from vllm.logger import init_logger -from vllm.model_executor import set_random_seed +from vllm.utils.torch_utils import set_random_seed from vllm.platforms import current_platform -from vllm.utils.mem_constants import GiB_bytes -from vllm.utils.mem_utils import MemorySnapshot +from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_worker import Worker as GPUWorker from vllm.v1.worker.gpu_worker import init_worker_distributed_environment from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner - +from vllm.v1.worker.workspace import init_workspace_manager +from vllm.v1.worker.utils import request_memory +from vllm.logger import init_logger logger = init_logger(__name__) + class 
GPUARWorker(GPUWorker): """GPU worker for autoregressive omni model stages. @@ -24,24 +26,24 @@ class GPUARWorker(GPUWorker): """ def init_device(self): - device = self.device_config.device - if isinstance(device, torch.device) and device.type == "cuda": + if self.device_config.device_type == "cuda": # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + parallel_config = self.parallel_config if ( - self.parallel_config.data_parallel_size > 1 - and self.parallel_config.data_parallel_size_local > 0 - and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"] - and self.vllm_config.parallel_config.data_parallel_backend != "ray" - and self.vllm_config.parallel_config.nnodes_within_dp == 1 + parallel_config.distributed_executor_backend + not in ("ray", "external_launcher") + and parallel_config.data_parallel_backend != "ray" + and parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. dp_local_rank = self.parallel_config.data_parallel_rank_local if dp_local_rank is None: - dp_local_rank = self.parallel_config.data_parallel_rank + dp_local_rank = self.parallel_config.data_parallel_index tp_pp_world_size = ( - self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size + self.parallel_config.pipeline_parallel_size + * self.parallel_config.tensor_parallel_size ) # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK @@ -49,7 +51,9 @@ def init_device(self): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. 
" ) - visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + visible_device_count = ( + torch.cuda.device_count() if torch.cuda.is_available() else 0 + ) assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " @@ -80,28 +84,22 @@ def init_device(self): torch.cuda.empty_cache() # take current memory snapshot - self.init_snapshot = MemorySnapshot() - self.requested_memory = self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization - if self.init_snapshot.free_memory < self.requested_memory: - - def gib(bytes_val: float) -> float: - return round(bytes_val / GiB_bytes, 2) - - raise ValueError( - f"Free memory on device " - f"({gib(self.init_snapshot.free_memory)}/" - f"{gib(self.init_snapshot.total_memory)} GiB) on startup " - f"is less than desired GPU memory utilization " - f"({self.cache_config.gpu_memory_utilization}, " - f"{gib(self.requested_memory)} GiB). Decrease GPU memory " - f"utilization or reduce GPU memory used by other processes." - ) + self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) + self.requested_memory = request_memory(init_snapshot, self.cache_config) + logger.debug("worker init memory snapshot: %r", self.init_snapshot) + logger.debug( + "worker requested memory: %sGiB", format_gib(self.requested_memory) + ) else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + # Initialize workspace manager + num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 + init_workspace_manager(self.device, num_ubatches) + # Construct the model runner self.model_runner = GPUARModelRunner(self.vllm_config, self.device) if self.rank == 0: # If usage stat is enabled, collect relevant info. 
- report_usage_stats(self.vllm_config) + report_usage_stats(self.vllm_config) \ No newline at end of file diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 17740d85805..40b26e2ba50 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -5,16 +5,16 @@ """ from __future__ import annotations +from copy import copy import gc import logging - +from typing import Any import numpy as np import torch from vllm.config import CUDAGraphMode -from vllm.multimodal.inputs import MultiModalKwargs from vllm.utils.math_utils import cdiv -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import SchedulerOutput, GrammarOutput from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.gpu_model_runner import ( @@ -25,11 +25,20 @@ get_pp_group, set_forward_context, ) +from vllm.model_executor.models.interfaces import supports_mm_encoder_only from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs - +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner - +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.v1.outputs import make_empty_encoder_model_runner_output +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState logger = logging.getLogger(__name__) @@ -47,57 +56,149 @@ def execute_model( scheduler_output: SchedulerOutput, intermediate_tensors: IntermediateTensors 
| None = None, ) -> OmniModelRunnerOutput | IntermediateTensors: - with record_function_or_nullcontext("Preprocess"): - with self.synchronize_input_prep(): - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - return EMPTY_MODEL_RUNNER_OUTPUT + if self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) - num_reqs = self.input_batch.num_reqs - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") - logits_indices, spec_decode_metadata = self._prepare_inputs( + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions( + scheduler_output.preempted_req_ids + ) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + return EMPTY_MODEL_RUNNER_OUTPUT + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and 
self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, ) + + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + 
max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) - ( - cudagraph_mode, - batch_desc, - ubatch_slices, - num_tokens_across_dp, - ) = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens_np, - max_num_scheduled_tokens=max_num_scheduled_tokens, - use_cascade_attn=False, - ) + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) - num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - pad_attn = cudagraph_mode == CUDAGraphMode.FULL + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) - ( - attn_metadata, - spec_decode_common_attn_metadata, - ) = self._build_attention_metadata( + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded if pad_attn else None, max_query_len=max_num_scheduled_tokens, 
- ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_attn, logits_indices=logits_indices, use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=None, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, ) + ) ( input_ids, @@ -112,10 +213,16 @@ def execute_model( intermediate_tensors, ) + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. if self.calculate_kv_scales: cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False + # Run the model. + # Use persistent buffers for CUDA graphs. with ( set_forward_context( attn_metadata, @@ -124,7 +231,7 @@ def execute_model( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, @@ -139,6 +246,58 @@ def execute_model( ) _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + self.execute_model_state = ExecuteModelState( + scheduler_output, + None, + spec_decode_metadata, + spec_decode_common_attn_metadata, + None, + None, + None, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. 
+ if not kv_connector_output: + return None # type: ignore[return-value] + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) = self.execute_model_state + self.execute_model_state = None + pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): assert multimodal_outputs.shape[0] == 1, ( @@ -169,6 +328,10 @@ def execute_model( pooler_output=pooler_output, kv_connector_output=kv_connector_output, num_nans_in_logits={}, + cudagraph_stats=cudagraph_stats, + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, ) if not self.use_async_scheduling: @@ -176,9 +339,11 @@ def execute_model( return AsyncGPUModelRunnerOutput( model_runner_output=output, - sampled_token_ids=[], + sampled_token_ids=torch.tensor([], device=self.device), invalid_req_indices=[], async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + logprobs_tensors=None, ) def _run_generation_model( @@ -206,7 +371,7 @@ def _run_generation_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs(model_kwargs, device=self.device), + **model_kwargs, sampling_metadata=self.input_batch.sampling_metadata, logits_index=logits_indices, sampler=self.sampler, @@ -263,7 +428,15 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. 
""" - assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() + if supports_mm_encoder_only(self.model): + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. + return torch.tensor([]), torch.tensor([]) + + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -316,23 +489,26 @@ def _dummy_run( num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens, - max_num_scheduled_tokens=max_query_len, - use_cascade_attn=False, - allow_microbatching=allow_microbatching, - force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), - # `force_uniform_decode` is used for cudagraph capture; because for - # capturing mixed prefill-decode batches, we sometimes use - # num_tokens == num_reqs which looks like a uniform decode batch to the - # dispatcher; but we actually want to capture a piecewise cudagraph - force_uniform_decode=uniform_decode, - # `force_has_lora` is used for cudagraph capture; because LoRA is - # activated later in the context manager, but we need to know the - # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile + or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # 
`force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, + ) ) if cudagraph_runtime_mode is None: @@ -344,7 +520,21 @@ def _dummy_run( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) attn_metadata: PerLayerAttnMetadata | None = None @@ -366,11 +556,12 @@ def _dummy_run( self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_reqs=num_reqs_padded, max_query_len=max_query_len, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, for_cudagraph_capture=is_graph_capturing, ) @@ -383,10 +574,10 @@ def _dummy_run( ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: 
- input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), @@ -394,7 +585,7 @@ def _dummy_run( elif self.enable_prompt_embeds: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -410,24 +601,28 @@ def _dummy_run( intermediate_tensors = None else: if self.intermediate_tensors is None: - self.intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device, + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) ) - intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens_padded, None, False + ) - if ubatch_slices is not None: + if ubatch_slices_padded is not None: # Adjust values to reflect a single ubatch. # TODO(sage,lucas): this is cruft that should be addressed in # the padding refactor. 
- num_tokens_padded = ubatch_slices[0].num_tokens + num_tokens_padded = ubatch_slices_padded[0].num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded with ( - self.maybe_randomize_inputs(input_ids), + self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( attn_metadata, self.vllm_config, @@ -435,7 +630,7 @@ def _dummy_run( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), ): outputs = self.model( @@ -453,10 +648,19 @@ def _dummy_run( hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( - cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) + ( + is_graph_capturing + and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + or ( + not is_graph_capturing + and cudagraph_runtime_mode != CUDAGraphMode.NONE + ) + ) and not self.speculative_config.enforce_eager # Note(gnovack) - We need to disable cudagraphs for one of the two # lora cases when cudagraph_specialize_lora is enabled. This is a @@ -471,6 +675,17 @@ def _dummy_run( is_graph_capturing=is_graph_capturing, ) + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. 
+ # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() + # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py index 27111f39408..6a1a3039211 100644 --- a/vllm_omni/worker/gpu_generation_worker.py +++ b/vllm_omni/worker/gpu_generation_worker.py @@ -2,17 +2,17 @@ import os import torch -from vllm.model_executor import set_random_seed +from vllm.utils.torch_utils import set_random_seed from vllm.platforms import current_platform -from vllm.utils.mem_constants import GiB_bytes -from vllm.utils.mem_utils import MemorySnapshot +from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_worker import Worker as GPUWorker from vllm.v1.worker.gpu_worker import init_worker_distributed_environment - +from vllm.v1.worker.workspace import init_workspace_manager +from vllm.v1.worker.utils import request_memory from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner - - +from vllm.logger import init_logger +logger = init_logger(__name__) class GPUGenerationWorker(GPUWorker): """GPU Worker for Generation model (non-autoregressive waveform generation). @@ -21,24 +21,24 @@ class GPUGenerationWorker(GPUWorker): """ def init_device(self): - device = self.device_config.device - if isinstance(device, torch.device) and device.type == "cuda": + if self.device_config.device_type == "cuda": # This env var set by Ray causes exceptions with graph building. 
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + parallel_config = self.parallel_config if ( - self.parallel_config.data_parallel_size > 1 - and self.parallel_config.data_parallel_size_local > 0 - and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"] - and self.vllm_config.parallel_config.data_parallel_backend != "ray" - and self.vllm_config.parallel_config.nnodes_within_dp == 1 + parallel_config.distributed_executor_backend + not in ("ray", "external_launcher") + and parallel_config.data_parallel_backend != "ray" + and parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. dp_local_rank = self.parallel_config.data_parallel_rank_local if dp_local_rank is None: - dp_local_rank = self.parallel_config.data_parallel_rank + dp_local_rank = self.parallel_config.data_parallel_index tp_pp_world_size = ( - self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size + self.parallel_config.pipeline_parallel_size + * self.parallel_config.tensor_parallel_size ) # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK @@ -46,7 +46,9 @@ def init_device(self): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. 
" ) - visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + visible_device_count = ( + torch.cuda.device_count() if torch.cuda.is_available() else 0 + ) assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " @@ -77,26 +79,19 @@ def init_device(self): torch.cuda.empty_cache() # take current memory snapshot - self.init_snapshot = MemorySnapshot() - self.requested_memory = self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization - if self.init_snapshot.free_memory < self.requested_memory: - - def gib(bytes_val: float) -> float: - return round(bytes_val / GiB_bytes, 2) - - raise ValueError( - f"Free memory on device " - f"({gib(self.init_snapshot.free_memory)}/" - f"{gib(self.init_snapshot.total_memory)} GiB) on startup " - f"is less than desired GPU memory utilization " - f"({self.cache_config.gpu_memory_utilization}, " - f"{gib(self.requested_memory)} GiB). Decrease GPU memory " - f"utilization or reduce GPU memory used by other processes." 
- ) + self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) + self.requested_memory = request_memory(init_snapshot, self.cache_config) + logger.debug("worker init memory snapshot: %r", self.init_snapshot) + logger.debug( + "worker requested memory: %sGiB", format_gib(self.requested_memory) + ) else: raise RuntimeError(f"Not support device type: {self.device_config.device}") - # Construct the model runner + # Initialize workspace manager + num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 + init_workspace_manager(self.device, num_ubatches) + self.model_runner = GPUGenerationModelRunner(self.vllm_config, self.device) if self.rank == 0: diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 69729d95429..24d4ffd028e 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.models.interfaces import supports_mrope +from vllm.model_executor.models.interfaces import supports_mrope, supports_mm_encoder_only from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.sampling_params import SamplingType from vllm.utils.import_utils import LazyLoader @@ -16,7 +16,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata - +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm_omni.model_executor.models.output_templates import OmniOutput if TYPE_CHECKING: @@ -129,6 +129,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. 
for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -149,7 +150,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. + unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) # NOTE(woosuk): The persistent batch optimization assumes that # consecutive batches contain mostly the same requests. 
If batches # have low request overlap (e.g., alternating between two distinct @@ -240,22 +248,64 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: except Exception as e: logger.error(f"Error decoding additional information: {e}") pass - + + if sampling_params and sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = ( + self.input_batch.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) - + + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._init_xdrope_positions(req_state) + reqs_to_add.append(self.requests[req_id]) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. + valid_sampled_token_count = self._get_valid_sampled_token_count() + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] - + resumed_from_preemption = req_id in req_data.resumed_req_ids + num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + + if req_state.prev_num_draft_len and self.use_async_scheduling: + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # first step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. 
+ # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step doesn't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) + # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -272,7 +322,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) - + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) + self.input_batch.num_tokens_no_spec[req_index] = end_idx + # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: @@ -280,6 +340,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: + assert req_index is None assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. 
@@ -290,6 +351,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + reqs_to_add.append(req_state) continue @@ -304,24 +372,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) - self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index - self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) - start_index = self.input_batch.num_tokens_no_spec[req_index] - end_token_index = start_index + num_spec_tokens - self.input_batch.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec tokens. - self.input_batch.num_tokens[req_index] += num_spec_tokens + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
for request in reqs_to_add: self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -388,7 +452,15 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. """ - assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() + if supports_mm_encoder_only(self.model): + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. + return torch.tensor([]), torch.tensor([]) + + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -441,23 +513,26 @@ def _dummy_run( num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens, - max_num_scheduled_tokens=max_query_len, - use_cascade_attn=False, - allow_microbatching=allow_microbatching, - force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), - # `force_uniform_decode` is used for cudagraph capture; because for - # capturing mixed prefill-decode batches, we sometimes use - # num_tokens == num_reqs which looks like a uniform decode batch to the - # dispatcher; but we actually want to capture a piecewise cudagraph - force_uniform_decode=uniform_decode, - # `force_has_lora` is used for cudagraph capture; because LoRA is - # activated later in the context manager, but we need to know the - # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + _cudagraph_mode, 
batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile + or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, + ) ) if cudagraph_runtime_mode is None: @@ -469,7 +544,21 @@ def _dummy_run( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) attn_metadata: PerLayerAttnMetadata | None = None @@ -491,11 +580,12 @@ def _dummy_run( self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_reqs=num_reqs_padded, max_query_len=max_query_len, - ubatch_slices=ubatch_slices, + 
ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, for_cudagraph_capture=is_graph_capturing, ) @@ -508,10 +598,10 @@ def _dummy_run( ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), @@ -519,7 +609,7 @@ def _dummy_run( elif self.enable_prompt_embeds: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -535,24 +625,28 @@ def _dummy_run( intermediate_tensors = None else: if self.intermediate_tensors is None: - self.intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device, + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) ) - intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens_padded, None, False + ) - if ubatch_slices is not None: + if ubatch_slices_padded is not None: # Adjust values to reflect a single ubatch. # TODO(sage,lucas): this is cruft that should be addressed in # the padding refactor. 
- num_tokens_padded = ubatch_slices[0].num_tokens + num_tokens_padded = ubatch_slices_padded[0].num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded with ( - self.maybe_randomize_inputs(input_ids), + self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( attn_metadata, self.vllm_config, @@ -560,20 +654,9 @@ def _dummy_run( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), ): - if ( - getattr(self.model, "talker", None) is not None - and hasattr(self.model, "talker_mtp") - and num_tokens_padded == 1 - ): - outputs = self.talker_mtp( - self.talker_mtp_input_ids.gpu[:num_tokens_padded], - self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], - self.last_talker_hidden.gpu[:num_tokens_padded], - self.text_step.gpu[:num_tokens_padded], - ) outputs = self.model( input_ids=input_ids, positions=positions, @@ -589,10 +672,19 @@ def _dummy_run( hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( - cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) + ( + is_graph_capturing + and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + or ( + not is_graph_capturing + and cudagraph_runtime_mode != CUDAGraphMode.NONE + ) + ) and not self.speculative_config.enforce_eager # Note(gnovack) - We need to disable cudagraphs for one of the two # lora cases when cudagraph_specialize_lora is enabled. 
This is a @@ -607,6 +699,17 @@ def _dummy_run( is_graph_capturing=is_graph_capturing, ) + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. + # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() + # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. @@ -618,7 +721,9 @@ def _dummy_run( self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - logit_indices_device = torch.from_numpy(logit_indices).to(self.device, non_blocking=True) + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) return hidden_states, hidden_states[logit_indices_device] def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None: @@ -776,7 +881,7 @@ def _preprocess( num_input_tokens: int, intermediate_tensors: IntermediateTensors | None = None, ): - """Align with v0.12 preprocess and omni's additional information handling.""" + """Align with v0.14.0 preprocess and omni's additional information handling.""" num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens is_first_rank = get_pp_group().is_first_rank is_encoder_decoder = self.model_config.is_encoder_decoder @@ -806,10 +911,9 @@ def _preprocess( # TODO(woosuk): Avoid the copy. Optimize. 
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) - input_ids = self.input_ids.gpu[:num_input_tokens] - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) model_kwargs = { - **self._init_model_kwargs(num_scheduled_tokens), + **self._init_model_kwargs(), **self._extract_mm_kwargs(scheduler_output), } elif self.enable_prompt_embeds and is_first_rank: @@ -833,7 +937,7 @@ def _preprocess( self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() input_ids = self.input_ids.gpu[:num_input_tokens] else: # For text-only models, we use token ids as input. @@ -842,7 +946,7 @@ def _preprocess( # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_input_tokens] @@ -894,12 +998,19 @@ def _preprocess( span_len = int(e) - int(s) # call the custom process function - req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos - ) + try: + req_input_ids, req_embeds, update_dict = self.model.preprocess( + input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + ) + except Exception as e: + logger.error(f"Error in preprocess for request {req_id}: {e}") + import traceback + traceback.print_exc() + raise e + #TODO: This is Model Specific Code, need to be generalized in the future ZTC # run talker mtp decode if hasattr(self.model, "talker_mtp"): - _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( + _cudagraph_mode, batch_desc, _, _, _ = 
self._determine_batch_execution_and_padding( num_tokens=span_len, num_reqs=1, num_scheduled_tokens_np=num_scheduled_tokens_np[req_index],