diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ed1b6e72cc0..e175385ff0d 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1,7 +1,7 @@ steps: - label: "Diffusion Model Test" - timeout_in_minutes: 20 + timeout_in_minutes: 30 agent_pool: mi325_2 depends_on: amd-build mirror_hardwares: [amdproduction] @@ -11,7 +11,7 @@ steps: - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py - label: "Diffusion Images API LoRA E2E" - timeout_in_minutes: 20 + timeout_in_minutes: 30 agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml index 06bf033cac2..b606648c73c 100644 --- a/.buildkite/test-merge.yml +++ b/.buildkite/test-merge.yml @@ -17,7 +17,7 @@ steps: - "/fsx/hf_cache:/fsx/hf_cache" - label: "Diffusion Model Test" - timeout_in_minutes: 20 + timeout_in_minutes: 30 depends_on: upload-merge-pipeline commands: - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "advanced_model and diffusion" --run-level "advanced_model" @@ -35,7 +35,7 @@ steps: - "/fsx/hf_cache:/fsx/hf_cache" - label: "Diffusion Images API LoRA E2E" - timeout_in_minutes: 20 + timeout_in_minutes: 30 depends_on: upload-merge-pipeline commands: - pytest -s -v tests/e2e/online_serving/test_images_generations_lora.py diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 1d0b7ad0e84..a772e673e21 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -36,7 +36,7 @@ steps: - label: "Diffusion Model Test" depends_on: upload-ready-pipeline commands: - - timeout 20m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model" + - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model" agents: queue: "gpu_1_queue" plugins: diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci index cb80828eb95..83463687385 100644 --- a/docker/Dockerfile.ci +++ b/docker/Dockerfile.ci @@ -11,10 +11,29 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -# Install vllm-omni into the same uv-managed Python environment used by the base image. -# Use bash -c so that $(python3 -c ...) is expanded inside the container. -RUN uv pip install --system --no-cache-dir ".[dev]" +RUN uv pip uninstall --system -y vllm || true +# Install vLLM from precompiled wheel at the selected commit. +# Must use direct URL because the wheel has a PEP 440 local version identifier +# (e.g. +g0a0a1a198) which pip/uv refuse to install from a PEP 503 package index. +ENV VLLM_PRECOMPILED_WHEEL_COMMIT=89138b21cc246ae944c741d5c399c148e2b770ab +RUN VLLM_WHEEL_URL=$(python3 -c "import urllib.request,re; \ + html=urllib.request.urlopen('https://wheels.vllm.ai/${VLLM_PRECOMPILED_WHEEL_COMMIT}/vllm/').read().decode(); \ + m=re.search(r'>(\S+x86_64\.whl)<',html); \ + print('https://wheels.vllm.ai/${VLLM_PRECOMPILED_WHEEL_COMMIT}/'+m.group(1).replace('+','%2B'))") && \ + echo "Installing vLLM from: ${VLLM_WHEEL_URL}" && \ + uv pip install --system --force-reinstall "${VLLM_WHEEL_URL}" + +RUN uv pip install --system ".[dev]" + +RUN uv pip install --system --upgrade \ + "flashinfer-cubin==0.6.6" \ + "nvidia-cublas-cu12==12.9.1.4" \ + "numpy==2.2.6" + +RUN uv pip install --system --upgrade \ + "flashinfer-jit-cache==0.6.6" \ + --index-url https://flashinfer.ai/whl/cu129 RUN ln -sf /usr/bin/python3 /usr/bin/python ENTRYPOINT [] diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md index bc2d3b60cba..a763637035d 100644 --- a/docs/getting_started/installation/npu/npu.inc.md +++ b/docs/getting_started/installation/npu/npu.inc.md @@ -68,18 +68,18 @@ We are keeping [issue #886](https://github.com/vllm-project/vllm-omni/issues/886 You can also build vLLM-Omni from the latest main branch if you want to use the latest features or bug fixes. (But sometimes it will break for a while. You can check [issue #886](https://github.com/vllm-project/vllm-omni/issues/886) for the status of the latest commit of vLLM-Omni main branch on NPU.) ```bash -# Pin vLLM version to 0.17.0 +# Pin vLLM version to 0.18.0 cd /vllm-workspace/vllm git pull origin main git fetch origin --tags -git checkout v0.17.0 +git checkout v0.18.0 VLLM_TARGET_DEVICE=empty pip install -v -e . # Because vllm-ascend has not yet entered continuous development and has not been officially released, we need to pin it to a specific commit. Please note that this commit may change over time. cd /vllm-workspace/vllm-ascend git pull origin main git fetch origin --tags -git checkout v0.17.0 +git checkout 1e05c4908f31737bc4eef865a9f351d030a77c9d pip install -v -e . # Install vLLM-Omni from the latest main branch diff --git a/tests/conftest.py b/tests/conftest.py index 1c8d1bef386..f2d866a5894 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1910,6 +1910,7 @@ def generate_multimodal( def _cleanup_process(self): try: keywords = ["enginecore"] + matched = [] for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]): try: @@ -1922,16 +1923,32 @@ def _cleanup_process(self): if is_process: print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}") + matched.append(proc) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass - try: - proc.terminate() - time.sleep(2) - except Exception: - proc.kill() + for proc in matched: + try: + proc.terminate() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + _, still_alive = psutil.wait_procs(matched, timeout=5) + for proc in still_alive: + try: + proc.kill() except (psutil.NoSuchProcess, psutil.AccessDenied): pass + if still_alive: + _, stubborn = psutil.wait_procs(still_alive, timeout=3) + if stubborn: + print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}") + else: + print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}") + elif matched: + print(f"Terminated vllm pids: {[p.pid for p in matched]}") + except Exception as e: print(f"Error in psutil vllm cleanup: {e}") diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py index aacc016aa9c..b414fe30eeb 100644 --- a/tests/e2e/offline_inference/test_diffusion_lora.py +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -24,6 +24,7 @@ # This test is specific to Z-Image LoRA behavior. Keep it focused on a single # model to reduce runtime and avoid extra downloads. models = ["Tongyi-MAI/Z-Image-Turbo"] +DIFFUSION_INIT_TIMEOUT_S = 600 @pytest.mark.parametrize("model_name", models) @@ -76,7 +77,11 @@ def _write_zimage_lora(adapter_dir: Path) -> str: ) return str(adapter_dir) - m = Omni(model=model_name) + m = Omni( + model=model_name, + stage_init_timeout=DIFFUSION_INIT_TIMEOUT_S, + init_timeout=DIFFUSION_INIT_TIMEOUT_S, + ) try: # high resolution may cause OOM on L4 height = 256 diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py index a1f1d145004..e2079c9096d 100644 --- a/tests/e2e/online_serving/test_images_generations_lora.py +++ b/tests/e2e/online_serving/test_images_generations_lora.py @@ -28,6 +28,7 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" MODEL = "Tongyi-MAI/Z-Image-Turbo" +DIFFUSION_INIT_TIMEOUT_S = 600 PROMPT = "a photo of a cat sitting on a laptop keyboard" @@ -37,7 +38,17 @@ @pytest.fixture(scope="module") def omni_server(): - with OmniServer(MODEL, ["--num-gpus", "1"]) as server: + with OmniServer( + MODEL, + [ + "--num-gpus", + "1", + "--stage-init-timeout", + str(DIFFUSION_INIT_TIMEOUT_S), + "--init-timeout", + str(DIFFUSION_INIT_TIMEOUT_S), + ], + ) as server: yield server diff --git a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml index 2d5690d0034..12871951e6a 100644 --- a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml +++ b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml @@ -45,7 +45,7 @@ stage_args: max_model_len: 16384 max_num_batched_tokens: 16384 max_num_seqs: 1 - gpu_memory_utilization: 0.9 + gpu_memory_utilization: 0.4 skip_mm_profiling: true enforce_eager: true trust_remote_code: true @@ -72,7 +72,7 @@ stage_args: model_arch: Qwen2_5OmniForConditionalGeneration worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.9 #increase the gpu memory utilization to enable the test on H800 + gpu_memory_utilization: 0.5 #increase the gpu memory utilization to enable the test on H800 enforce_eager: true trust_remote_code: true enable_prefix_caching: false diff --git a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py index 84112336e4b..8e04b04966b 100644 --- a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py +++ b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py @@ -123,6 +123,7 @@ def make_mock_model(hidden: int = 8): cfg.video_token_index = VIDEO_TOKEN_ID cfg.audio_token_index = AUDIO_TOKEN_ID model.config = cfg + model._has_oov_mm_tokens = False def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor: # Use .clone() so the tensor is contiguous (expand() creates a strided @@ -137,13 +138,12 @@ def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor: model._embed_text_input_ids = lambda *a, **kw: SupportsMultiModal._embed_text_input_ids(model, *a, **kw) - def fake_super_embed(ids, mm_embs=None, *, is_multimodal=None, handle_oov_mm_token=False): + def fake_super_embed(ids, mm_embs=None, *, is_multimodal=None): return SupportsMultiModal.embed_input_ids( model, ids, mm_embs, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) model.embed_input_ids = lambda *a, **kw: Qwen2_5OmniThinkerForConditionalGeneration.embed_input_ids(model, *a, **kw) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 540fc391c51..3629ade63f5 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -364,6 +364,7 @@ def update_from_output( if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) + self.skipped_waiting.remove_requests(stopped_preempted_reqs) # [Main] Handle failed KV load requests if failed_kv_load_req_ids and not self.recompute_kv_load_failures: diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index a397e608518..64522b909d3 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -247,6 +247,15 @@ def schedule(self) -> SchedulerOutput: ) total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + + # Record the request ids scheduled in this step (v0.14.0 behavior). + self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + + new_block_ids_to_zero = ( + (self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None + ) + scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -258,12 +267,9 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), preempted_req_ids=set(), + new_block_ids_to_zero=new_block_ids_to_zero, ) - # Record the request ids scheduled in this step (v0.14.0 behavior). - self.prev_step_scheduled_req_ids.clear() - self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) - # KVTransfer: package metadata if self.connector is not None: meta = self.connector.build_connector_meta(scheduler_output) @@ -496,6 +502,7 @@ def update_from_output( if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) + self.skipped_waiting.remove_requests(stopped_preempted_reqs) # Handle failed KV load requests if failed_kv_load_req_ids and not self.recompute_kv_load_failures: diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py index 17e6d32ead2..8a35547f9eb 100644 --- a/vllm_omni/diffusion/quantization/base.py +++ b/vllm_omni/diffusion/quantization/base.py @@ -31,14 +31,15 @@ class DiffusionQuantizationConfig(ABC): # The underlying vLLM config instance _vllm_config: "QuantizationConfig | None" = None - def get_name(self) -> str: + @classmethod + def get_name(cls) -> str: """Return the quantization method name (e.g., 'fp8', 'int8'). - By default, delegates to the underlying vLLM config instance. + Delegates to the underlying vLLM config class's get_name(). """ - if self._vllm_config is not None: - return self._vllm_config.get_name() - raise NotImplementedError("Subclass must initialize _vllm_config or override get_name().") + if cls.quant_config_cls is not None: + return cls.quant_config_cls.get_name() + raise NotImplementedError("Subclass must set quant_config_cls or override get_name().") def get_vllm_quant_config(self) -> "QuantizationConfig | None": """Return the underlying vLLM QuantizationConfig for linear layers.""" diff --git a/vllm_omni/distributed/kv_transfer/monkey_patch.py b/vllm_omni/distributed/kv_transfer/monkey_patch.py index 455ad80652c..20ad53258ad 100644 --- a/vllm_omni/distributed/kv_transfer/monkey_patch.py +++ b/vllm_omni/distributed/kv_transfer/monkey_patch.py @@ -8,6 +8,7 @@ from __future__ import annotations +import importlib import logging import sys from dataclasses import dataclass @@ -30,30 +31,23 @@ class PatchedRecvReqMeta: def _import_mooncake_module(): """Import MooncakeConnector module, supporting both vLLM >=0.16 and older.""" - try: - from vllm.distributed.kv_transfer.kv_connector.v1.mooncake import mooncake_connector - - return mooncake_connector - except ImportError: - pass - try: - from vllm.distributed.kv_transfer.kv_connector.v1 import mooncake_connector - - return mooncake_connector - except ImportError: - return None + for mod_path in ( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", + ): + try: + return importlib.import_module(mod_path) + except (ImportError, ModuleNotFoundError): + continue + return None def _create_patched_mooncake_connector(): """Return a subclass of MooncakeConnector with remote_request_id support.""" - try: - from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import ( - MooncakeConnector as _OriginalMooncakeConnector, - ) - except (ImportError, AttributeError): - from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import ( - MooncakeConnector as _OriginalMooncakeConnector, - ) + _mc_mod = _import_mooncake_module() + if _mc_mod is None: + raise ImportError("Cannot import MooncakeConnector from upstream vLLM") + _OriginalMooncakeConnector = _mc_mod.MooncakeConnector class PatchedMooncakeConnector(_OriginalMooncakeConnector): """Fixes request-ID mismatch in PD disaggregation by injecting diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py index a4e8bc0e932..284cc2d31a2 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -38,8 +38,12 @@ def __init__( self, vllm_config: Any, executor_class: type, - metadata: StageMetadata, + log_stats: bool = False, client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + *, + metadata: StageMetadata | None = None, engine_manager: Any = None, coordinator: Any = None, ): @@ -51,17 +55,18 @@ def __init__( and calls super().__init__(). """ # -------- Stage metadata (public fields used at runtime) -------- - self.stage_id = metadata.stage_id - self.stage_type = metadata.stage_type - self.engine_output_type = metadata.engine_output_type - self.is_comprehension = metadata.is_comprehension - self.requires_multimodal_data = metadata.requires_multimodal_data - self.engine_input_source = metadata.engine_input_source - self.final_output = metadata.final_output - self.final_output_type = metadata.final_output_type - self.default_sampling_params = metadata.default_sampling_params - self.custom_process_input_func = metadata.custom_process_input_func - self.model_stage = metadata.model_stage + if metadata is not None: + self.stage_id = metadata.stage_id + self.stage_type = metadata.stage_type + self.engine_output_type = metadata.engine_output_type + self.is_comprehension = metadata.is_comprehension + self.requires_multimodal_data = metadata.requires_multimodal_data + self.engine_input_source = metadata.engine_input_source + self.final_output = metadata.final_output + self.final_output_type = metadata.final_output_type + self.default_sampling_params = metadata.default_sampling_params + self.custom_process_input_func = metadata.custom_process_input_func + self.model_stage = metadata.model_stage self.engine_outputs: Any = None @@ -73,8 +78,10 @@ def __init__( super().__init__( vllm_config, executor_class, - log_stats=False, + log_stats=log_stats, client_addresses=client_addresses, + client_count=client_count, + client_index=client_index, ) if engine_manager is not None: self.resources.engine_manager = engine_manager diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3ba2c4a7efd..18b4f200365 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from vllm.inputs.preprocess import InputPreprocessor from vllm.tokenizers import TokenizerLike + from vllm.v1.engine import PauseMode from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams @@ -66,7 +67,7 @@ class AsyncOmni(EngineClient, OmniBase): ... print(output) """ - def __init__(self, model: str, **kwargs: Any) -> None: + def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: OmniBase.__init__(self, model=model, **kwargs) self._pause_cond: asyncio.Condition = asyncio.Condition() self._paused: bool = False @@ -130,9 +131,13 @@ def model_config(self): async def generate( self, prompt: OmniPromptType, - request_id: str, - sampling_params_list: Sequence[OmniSamplingParams] | None = None, + sampling_params: Any = None, + request_id: str = "", *, + prompt_text: str | None = None, + lora_request: Any = None, + tokenization_kwargs: dict[str, Any] | None = None, + sampling_params_list: Sequence[OmniSamplingParams] | None = None, output_modalities: list[str] | None = None, ) -> AsyncGenerator[OmniRequestOutput, None]: """Generate outputs for the given prompt asynchronously. @@ -226,6 +231,7 @@ async def encode( trace_headers: dict[str, str] | None = None, priority: int = 0, tokenization_kwargs: dict[str, Any] | None = None, + reasoning_ended: bool | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """EngineClient.encode() stub. @@ -396,6 +402,7 @@ async def abort(self, request_id: str | Iterable[str]) -> None: async def pause_generation( self, *, + mode: PauseMode = "abort", wait_for_inflight_requests: bool = False, clear_cache: bool = True, ) -> None: @@ -467,7 +474,7 @@ async def reset_prefix_cache( logger.warning("[AsyncOmni] reset_prefix_cache not yet supported with Orchestrator process") return True - async def sleep(self, level: int = 1) -> None: + async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None: """Sleep all stages. Best-effort: unsupported stages will emit a TODO result. @@ -584,7 +591,7 @@ async def check_health(self) -> None: # ==================== Shutdown ==================== - def shutdown(self) -> None: + def shutdown(self, timeout: float | None = None) -> None: """Shutdown the engine.""" if self.final_output_task is not None: self.final_output_task.cancel() diff --git a/vllm_omni/entrypoints/cli/benchmark/main.py b/vllm_omni/entrypoints/cli/benchmark/main.py index 8880e35c7cf..865064d1e9e 100644 --- a/vllm_omni/entrypoints/cli/benchmark/main.py +++ b/vllm_omni/entrypoints/cli/benchmark/main.py @@ -9,7 +9,7 @@ from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase if typing.TYPE_CHECKING: - from vllm.utils import FlexibleArgumentParser + from vllm.utils.argparse_utils import FlexibleArgumentParser class OmniBenchmarkSubcommand(CLISubcommand): diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 0de412a961f..458996ff465 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -53,6 +53,11 @@ def omni_snapshot_download(model_id: str) -> str: ) except huggingface_hub.errors.RepositoryNotFoundError: logger.warning("Repository not found for '%s'.", model_id) + except PermissionError: + logger.warning( + "Permission denied when downloading '%s'. Assuming the model is already cached locally.", + model_id, + ) return model_id diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c3c250fda7f..52edab11844 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -35,13 +35,6 @@ from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.openai.api_server import build_app as build_openai_app from vllm.entrypoints.openai.api_server import setup_server as setup_openai_server - -# vLLM moved `base` from openai.basic.api_router to serve.instrumentator.basic. -# Keep a fallback for older/newer upstream layouts during rebase windows. -try: - from vllm.entrypoints.serve.instrumentator.basic import base -except ModuleNotFoundError: - from vllm.entrypoints.openai.basic.api_router import base from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -68,10 +61,15 @@ ) from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.serving import ServingClassification -from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding +from vllm.entrypoints.pooling.embed.serving import ServingEmbedding as OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.serve.disagg.serving import ServingTokens + +# vLLM moved `base` from openai.basic.api_router to serve.instrumentator.basic. +# Keep a fallback for older/newer upstream layouts during rebase windows. +from vllm.entrypoints.serve.instrumentator.basic import base +from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.utils import ( load_aware_call, @@ -329,6 +327,7 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + ssl_ciphers=args.ssl_ciphers, h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, h11_max_header_count=args.h11_max_header_count, **uvicorn_kwargs, @@ -572,7 +571,13 @@ async def omni_init_app_state( else vllm_config.model_config ) io_processor_plugin = model_config.io_processor_plugin - engine_client.io_processor = get_io_processor(vllm_config, io_processor_plugin) + renderer = getattr(engine_client, "renderer", None) + if renderer is None: + from vllm.renderers import renderer_from_config + + renderer = renderer_from_config(vllm_config) + engine_client.renderer = renderer + engine_client.io_processor = get_io_processor(vllm_config, renderer, io_processor_plugin) logger.info("Initialized io_processor for AsyncOmni") else: logger.warning("Cannot initialize processors: tokenizer is None. OpenAIServingModels may fail.") @@ -591,6 +596,22 @@ async def omni_init_app_state( ) await state.openai_serving_models.init_static_loras() + state.openai_serving_render = OpenAIServingRender( + model_config=engine_client.model_config, + renderer=engine_client.renderer, + io_processor=engine_client.io_processor, + model_registry=state.openai_serving_models.registry, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + default_chat_template_kwargs=args.default_chat_template_kwargs, + log_error_stack=args.log_error_stack, + ) + state.openai_serving_responses = ( OpenAIServingResponses( engine_client, @@ -606,7 +627,6 @@ async def omni_init_app_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, - log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None @@ -616,6 +636,7 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, args.response_role, + openai_serving_render=state.openai_serving_render, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -630,24 +651,23 @@ async def omni_init_app_state( enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, enable_log_deltas=args.enable_log_deltas, - log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None ) # Warm up chat template processing to avoid first-request latency if state.openai_serving_chat is not None: - await state.openai_serving_chat.warmup() + state.openai_serving_chat.warmup() state.openai_serving_completion = ( OpenAIServingCompletion( engine_client, state.openai_serving_models, + openai_serving_render=state.openai_serving_render, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, - log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None @@ -661,7 +681,6 @@ async def omni_init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if any(task in POOLING_TASKS for task in supported_tasks) else None @@ -674,7 +693,6 @@ async def omni_init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None @@ -687,7 +705,6 @@ async def omni_init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None @@ -700,7 +717,7 @@ async def omni_init_app_state( score_template=resolved_chat_template, log_error_stack=args.log_error_stack, ) - if ("embed" in supported_tasks or "score" in supported_tasks) + if any(t in supported_tasks for t in ("embed", "score", "token_embed")) else None ) state.openai_serving_tokenization = OpenAIServingTokenization( @@ -709,15 +726,14 @@ async def omni_init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) state.openai_serving_transcription = ( OpenAIServingTranscription( engine_client, state.openai_serving_models, request_logger=request_logger, - log_error_stack=args.log_error_stack, enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks @@ -728,7 +744,6 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, request_logger=request_logger, - log_error_stack=args.log_error_stack, enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks @@ -739,6 +754,7 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, args.response_role, + openai_serving_render=state.openai_serving_render, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -758,7 +774,6 @@ async def omni_init_app_state( state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - log_error_stack=args.log_error_stack, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_log_outputs=args.enable_log_outputs, force_no_detokenize=args.tokens_only, diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index d5eee3933c7..49fd0c03a22 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -256,7 +256,9 @@ async def create_chat_completion( ) else: should_include_tools = tool_dicts is not None - conversation, engine_prompts = self._make_request_with_harmony(request, should_include_tools) + conversation, engine_prompts = self.openai_serving_render._make_request_with_harmony( + request, should_include_tools + ) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 283f7206ac2..5a0ae996570 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1433,7 +1433,6 @@ def embed_input_ids( multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, ) -> torch.Tensor: """Embed input IDs with optional multimodal embeddings.""" # Get text embeddings diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py index c1b71e076e4..1424ca7756b 100644 --- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py +++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py @@ -773,7 +773,6 @@ def embed_input_ids( multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: @@ -783,7 +782,6 @@ def embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) def base_local_forward( diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py index 988d8c23673..dbb6c49efae 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -109,7 +109,6 @@ def embed_input_ids( multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: @@ -119,7 +118,6 @@ def embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) def forward( @@ -232,7 +230,7 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 64eaefaa28a..0307034089c 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -102,6 +102,15 @@ def _maybe_apply_prompt_updates( tokenizer = self.info.get_tokenizer() audio_pad_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>") use_audio_in_video = audio_pad_id not in prompt_ids + # for mutilmodality cache + if any(item is None for item in mm_kwargs["video"]): + video_token_id = self.info.get_hf_config().video_token_id + audio_token_id = self.info.get_hf_config().audio_token_id + video_audio_item_num = sum(id in (video_token_id, audio_token_id) for id in prompt_ids) + audio_updates_num = len(mm_prompt_updates.get("audio", [])) + video_updates_num = len(mm_prompt_updates.get("video", [])) + if video_audio_item_num != video_updates_num + audio_updates_num: + use_audio_in_video = True if is_update_applied: mm_placeholders = self._find_mm_placeholders( @@ -586,11 +595,19 @@ def embed_input_ids( multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is None or is_multimodal is None: return super().embed_input_ids(input_ids) + inputs_embeds = self._embed_text_input_ids( + input_ids, + self.get_language_model().embed_input_ids, + is_multimodal=is_multimodal, + ) + + if len(multimodal_embeddings) == 0: + return inputs_embeds + # Check for audio-in-video: interleaved video and audio tokens # in the multimodal region. Only use the interleaved path when # needed; otherwise fall back to the default parent implementation. @@ -608,7 +625,6 @@ def embed_input_ids( input_ids, self.get_language_model().embed_input_ids, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) return merge_interleaved_embeddings( inputs_embeds, @@ -625,7 +641,6 @@ def embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) def forward( diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 9f5c20cf7e4..7b3d0b017c0 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -341,7 +341,7 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: ) # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). dummy_multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py index 7dfb988eb88..4693a231452 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -29,6 +29,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from packaging.version import Version from transformers import PretrainedConfig from transformers import __version__ as TRANSFORMERS_VERSION @@ -139,44 +140,66 @@ def forward( rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw) if isinstance(grid_thw, torch.Tensor): - grid_thw_np = grid_thw.cpu().numpy().astype(np.int32) + grid_thw_tensor = grid_thw.to(self.device) else: - grid_thw_np = np.array(grid_thw, dtype=np.int32) - - cu_seqlens = np.repeat(grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]).cumsum(axis=0, dtype=np.int32) - cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) - - sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(self.attn_backend, cu_seqlens) - if sequence_lengths is not None: - sequence_lengths = torch.from_numpy(sequence_lengths).to(self.device, non_blocking=True) - max_seqlen = torch.tensor( - MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens), - dtype=torch.int32, - device=self.device, + grid_thw_tensor = torch.as_tensor(grid_thw, dtype=torch.int32, device=self.device) + + try: + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], + grid_thw_tensor[:, 0], + ).cumsum( + dim=0, + dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + except RuntimeError: + logger.warning( + "torch.repeat_interleave not executable, switching to vectorized searchsorted implementation." + ) + repeat_counts = grid_thw_tensor[:, 0] + values = grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2] + repeat_cumsum = repeat_counts.cumsum(0) + total_items = repeat_cumsum[-1].item() + indices = torch.searchsorted( + repeat_cumsum, + torch.arange(total_items, device=grid_thw_tensor.device), + right=True, + ) + cu_seqlens = values[indices].cumsum( + dim=0, + dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device) + rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + + grid_thw_np = grid_thw_tensor.cpu().numpy().astype(np.int32) + cu_seqlens_np = np.repeat(grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]).cumsum( + axis=0, dtype=np.int32 ) - cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np]) + sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens( self.attn_backend, - cu_seqlens, - self.hidden_size, - self.tp_size, + cu_seqlens_np, + self.device, ) - cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True) - - hidden_states = hidden_states.unsqueeze(1) hidden_states_list = [] deepstack_visual_indexes = self.deepstack_visual_indexes for layer_num, blk in enumerate(self.blocks): - hidden_states = hidden_states + blk.attn( - blk.norm1(hidden_states), + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, sequence_lengths=sequence_lengths, ) - hidden_states = hidden_states + blk.mlp(blk.norm2(hidden_states)) if deepstack_visual_indexes is not None and layer_num in deepstack_visual_indexes: hidden_states_list.append(hidden_states) @@ -475,6 +498,15 @@ def _maybe_apply_prompt_updates( tokenizer = self.info.get_tokenizer() audio_pad_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>") use_audio_in_video = audio_pad_id not in prompt_ids + # for mutilmodality cache + if any(item is None for item in mm_kwargs["video"]): + video_token_id = self.info.get_hf_config().video_token_id + audio_token_id = self.info.get_hf_config().audio_token_id + video_audio_item_num = sum(id in (video_token_id, audio_token_id) for id in prompt_ids) + audio_updates_num = len(mm_prompt_updates.get("audio", [])) + video_updates_num = len(mm_prompt_updates.get("video", [])) + if video_audio_item_num != video_updates_num + audio_updates_num: + use_audio_in_video = True # normal case with `use_audio_in_video=False` if is_update_applied: @@ -487,19 +519,13 @@ def _maybe_apply_prompt_updates( mm_item_counts, ) else: - if use_audio_in_video: - # When use_audio_in_video=True, audio is extracted from video and embedded - # in the video placeholder tokens. We should: - # 1. Filter out audio from prompt updates (audio has no separate placeholder) - # 2. Apply remaining updates (video, image, etc.) - # 3. Derive audio placeholders from video placeholders + if use_audio_in_video and "audio" in mm_prompt_updates: filtered_updates = {k: v for k, v in mm_prompt_updates.items() if k != "audio"} prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, filtered_updates, ) - # Derive audio placeholders from video placeholders - mm_placeholders = self._derive_audio_from_video_placeholders(mm_placeholders, mm_item_counts) + mm_placeholders = self._derive_audio_from_video_placeholders(mm_placeholders, mm_prompt_updates) else: prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, @@ -618,7 +644,7 @@ def get_replacement_qwen2_vision(item_idx: int, modality: str): def get_replacement_qwen2_use_audio_in_video(item_idx: int): nonlocal audio_in_video_item_idx - audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] + audio_num_features = audio_output_lengths[audio_in_video_item_idx] video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 @@ -664,27 +690,17 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): def _derive_audio_from_video_placeholders( self, placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: """ Helper to derive audio placeholders from video placeholders when use_audio_in_video=True. - - In use_audio_in_video mode, audio is extracted from video and embedded - within the video placeholder tokens. This function creates audio placeholder - info by extracting the audio token positions from video placeholders. - - Args: - placeholders: Current placeholders (should contain "video") - mm_item_counts: Counts of multimodal items from mm_items.get_all_counts() """ if "video" not in placeholders: return placeholders - # Validate audio and video counts match - # In use_audio_in_video mode, audio count comes from mm_items (extracted from video) num_videos = len(placeholders["video"]) - num_audios = mm_item_counts.get("audio", 0) + num_audios = len(mm_prompt_updates.get("audio", [])) if num_audios != num_videos: raise ValueError( f"use_audio_in_video requires equal number of audio and video items, got {num_audios=}, {num_videos=}" @@ -942,7 +958,7 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -966,13 +982,11 @@ def embed_input_ids( multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, ) -> torch.Tensor: inputs_embeds = self._embed_text_input_ids( input_ids, self.language_model.embed_input_ids, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) if multimodal_embeddings is None or len(multimodal_embeddings) == 0: @@ -1064,7 +1078,6 @@ def embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, ) def forward( diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 22badb49ef6..3d9cb86bacd 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -26,6 +26,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker import mamba_utils from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm_ascend.ascend_forward_context import set_ascend_forward_context @@ -237,6 +238,23 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877, + # there should be a corresponding 'postprocess_mamba'. However, it is called inside + # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. + # We simply utilize the implementation in vLLM. + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -339,7 +357,7 @@ def execute_model( skip_compiled=has_encoder_input, ), self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata + scheduler_output, defer_finalize=not clear_kv_metadata ) as kv_connector_output, ): hidden_states = self._model_forward( @@ -606,6 +624,34 @@ def propose_draft_token_ids(sampled_token_ids): hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output ) + # Pre-copy multimodal tensors to CPU once (not per-request) to avoid + # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + mm_cpu: dict[str, object] = {} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list): + if len(v) == 0: + continue + cpu_list = [] + for elem in v: + if isinstance(elem, torch.Tensor): + cpu_list.append(elem.detach().to("cpu").contiguous()) + else: + cpu_list.append(elem) + mm_cpu[k] = cpu_list + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: idx = req_id_to_index_output_copy[rid] @@ -614,28 +660,26 @@ def propose_draft_token_ids(sampled_token_ids): end = start + sched hidden_slice = hidden_states_cpu[start:end] payload: dict[str, object] = {"hidden": hidden_slice} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: + if mm_cpu: mm_payload: dict[str, object] = {} - for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_payload[k] = v.detach().to("cpu")[start:end].contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: - sub_dict[str(sk)] = sv.detach().to("cpu")[start:end].contiguous() - if sub_dict: - mm_payload[k] = sub_dict - elif isinstance(v, list): - element = v[0] - if isinstance(element, torch.Tensor): - element = element.detach().to("cpu").contiguous() - mm_payload[k] = element - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") - if mm_payload: - payload.update(mm_payload) + for k, v in mm_cpu.items(): + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_payload[k] = v[start:end].contiguous() + elif isinstance(v, dict): + mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} + elif isinstance(v, list): + element = v[idx] if idx < len(v) else v[0] + # Clone tensors to avoid cross-request aliasing + if isinstance(element, torch.Tensor): + element = element.clone() + mm_payload[k] = element + elif isinstance(v, torch.Tensor): + # List-derived tensor payloads are request-invariant; clone to + # avoid accidental cross-request aliasing on downstream mutation. + mm_payload[k] = v.clone() + else: + mm_payload[k] = v + payload.update(mm_payload) pooler_output.append(payload) model_runner_output = OmniModelRunnerOutput( diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py index 0296c042d5b..48a9982801d 100644 --- a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -20,6 +20,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, make_empty_encoder_model_runner_output from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker import mamba_utils from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs @@ -195,6 +196,23 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877, + # there should be a corresponding 'postprocess_mamba'. However, it is called inside + # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. + # We simply utilize the implementation in vLLM. + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -300,7 +318,7 @@ def execute_model( skip_compiled=has_encoder_input, ), self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata + scheduler_output, defer_finalize=not clear_kv_metadata ) as kv_connector_output, ): # -------------------------------------- Omni-new ------------------------------------------------- @@ -503,6 +521,7 @@ def _dummy_run( remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # only support eager mode and piecewise graph now assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() @@ -609,11 +628,14 @@ def _dummy_run( # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len # in inference. This will be removed once npu_fused_infer_attention_score # outperforms _npu_paged_attention on all cases. - seq_lens = ( - SEQ_LEN_WITH_MAX_PA_WORKSPACE - if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) - else max_query_len - ) # type: ignore[assignment] + if profile_seq_lens is not None: + seq_lens = profile_seq_lens + else: + seq_lens = ( + SEQ_LEN_WITH_MAX_PA_WORKSPACE + if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) + else max_query_len + ) # type: ignore[assignment] self.seq_lens.np[:num_reqs_padded] = seq_lens self.seq_lens.np[num_reqs_padded:] = 0 self.seq_lens.copy_to_gpu() diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py index 9ff5f720e96..8ef39adfa67 100644 --- a/vllm_omni/platforms/npu/worker/npu_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py @@ -35,9 +35,11 @@ def load_model(self, *args, **kwargs) -> None: enable_sp(self.vllm_config) # TODO move this model specific logic to a separate class # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS. + self.has_talker_mtp = False talker_mtp = getattr(self.model, "talker_mtp", None) if talker_mtp is not None: self.talker_mtp = talker_mtp # type: ignore[assignment] + self.has_talker_mtp = True cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that @@ -73,6 +75,7 @@ def _dummy_run( remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # only support eager mode and piecewise graph now assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() @@ -179,11 +182,14 @@ def _dummy_run( # seq_lens. We use this seq_len only when capturing graph, and still use max_query_len # in inference. This will be removed once npu_fused_infer_attention_score # outperforms _npu_paged_attention on all cases. - seq_lens = ( - SEQ_LEN_WITH_MAX_PA_WORKSPACE - if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) - else max_query_len - ) # type: ignore[assignment] + if profile_seq_lens is not None: + seq_lens = profile_seq_lens + else: + seq_lens = ( + SEQ_LEN_WITH_MAX_PA_WORKSPACE + if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config) + else max_query_len + ) # type: ignore[assignment] self.seq_lens.np[:num_reqs_padded] = seq_lens self.seq_lens.np[num_reqs_padded:] = 0 self.seq_lens.copy_to_gpu() @@ -283,7 +289,7 @@ def dummy_drafter_compute_logits(hidden_states): model_instance=self.model, ): # ---------------------------------------Omni-new---------------------------------------------- - if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + if getattr(self.model, "talker", None) is not None and self.has_talker_mtp: num_tokens_padded_talker_mtp = num_tokens_padded if num_tokens_padded_talker_mtp == self.max_num_tokens: num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0] @@ -362,7 +368,7 @@ def _model_forward( # Omni-specific: wrap output if needed if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"): - model_output = self.model.make_omni_output(model_output, **model_kwargs_extra) + model_output = self.model.make_omni_output(model_output, **model_kwargs, **model_kwargs_extra) # Omni-specific: cache model output for later sample_tokens self._omni_last_model_output = model_output @@ -415,12 +421,14 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # update the inputs_embeds and code_predictor_codes - code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + # code_predictor_codes stays on GPU here; _update_intermediate_buffer + # keeps it device-resident when the key is in gpu_resident_buffer_keys. + # D2H is deferred to sample_tokens where hidden_states.to("cpu") already + # syncs the stream, avoiding a per-step cudaStreamSynchronize. out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes") for idx, req_id in enumerate(decode_req_ids): req_index = self.input_batch.req_ids.index(req_id) start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]} + update_dict = {out_key: code_predictor_codes[idx : idx + 1]} self._merge_additional_information_update(req_id, update_dict) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 0b8257cc690..697c39d242e 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -123,7 +123,7 @@ def execute_model( request_id_resolver=self._resolve_global_request_id, ) - if self.vllm_config.model_config.enable_return_routed_experts: + if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: capturer.clear_buffer() # noqa @@ -141,7 +141,7 @@ def execute_model( # Update persistent batch states. self._update_states(scheduler_output) - if has_ec_transfer() and get_ec_transfer().is_producer: + if has_ec_transfer() and not get_ec_transfer().is_consumer: with self.maybe_get_ec_connector_output( scheduler_output, encoder_cache=self.encoder_cache, @@ -276,9 +276,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay clearing connector metadata - # until after draft model runs in sample_tokens. - clear_kv_metadata = self.speculative_config is None + # When spec decode is enabled, defer connector finalization + # (wait_for_save + clear metadata) until after draft model runs. + defer_kv_connector_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -292,7 +292,8 @@ def execute_model( ), record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata + scheduler_output, + defer_finalize=defer_kv_connector_finalize, ) as kv_connector_output, ): model_output = self._model_forward( @@ -533,11 +534,11 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - # Clear KV connector metadata after draft model runs (if spec decode). - # This was deferred from target model forward to allow draft model - # to also save its KV cache. + # Finalize KV connector (wait_for_save + clear metadata) after + # draft model runs. Deferred from target model forward to allow + # draft model to also save its KV cache. if self.speculative_config is not None: - self.clear_kv_connector_metadata() + self.finalize_kv_connector() with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() @@ -617,7 +618,7 @@ def propose_draft_token_ids(sampled_token_ids): payload.update(mm_payload) pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): - if self.model_config.enable_return_routed_experts: + if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: capturer.save_captured_experts(indices=self.slot_mapping) # noqa diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py index 198d42c8a31..4abe21964b3 100644 --- a/vllm_omni/worker/gpu_ar_worker.py +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -48,17 +48,17 @@ def init_device(self): # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK self.local_rank += dp_local_rank * tp_pp_world_size - assert self.local_rank < torch.cuda.device_count(), ( + assert self.local_rank < torch.accelerator.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + visible_device_count = torch.accelerator.device_count() if torch.cuda.is_available() else 0 assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " f"({visible_device_count})." ) self.device = torch.device(f"cuda:{self.local_rank}") - current_platform.set_device(self.device) + torch.accelerator.set_device_index(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) @@ -79,7 +79,7 @@ def init_device(self): # Now take memory snapshot after NCCL is initialized gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # take current memory snapshot self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 05785d7aa6c..4db683a8b4a 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -87,7 +87,7 @@ def execute_model( if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") - if self.vllm_config.model_config.enable_return_routed_experts: + if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: capturer.clear_buffer() # noqa @@ -108,7 +108,7 @@ def execute_model( if not scheduler_output.total_num_scheduled_tokens: return EMPTY_MODEL_RUNNER_OUTPUT - if has_ec_transfer() and get_ec_transfer().is_producer: + if has_ec_transfer() and not get_ec_transfer().is_consumer: with self.maybe_get_ec_connector_output( scheduler_output, encoder_cache=self.encoder_cache, @@ -265,9 +265,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay clearing connector metadata - # until after draft model runs in sample_tokens. - clear_kv_metadata = self.speculative_config is None + # When spec decode is enabled, defer connector finalization + # (wait_for_save + clear metadata) until after draft model runs. + defer_kv_connector_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -281,7 +281,8 @@ def execute_model( ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata + scheduler_output, + defer_finalize=defer_kv_connector_finalize, ) as kv_connector_output, ): outputs = self._run_generation_model( @@ -351,9 +352,10 @@ def sample_tokens( ) = self.execute_model_state self.execute_model_state = None - # Clear KV connector metadata after draft model runs (if spec decode). + # Finalize KV connector (wait_for_save + clear metadata) after + # draft model runs. Deferred from target model forward. if self.speculative_config is not None: - self.clear_kv_connector_metadata() + self.finalize_kv_connector() pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): @@ -476,6 +478,7 @@ def _dummy_run( remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -500,6 +503,9 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run num_active_loras: Number of distinct active LoRAs to capture for. LoRA is activated when num_active_loras > 0. + profile_seq_lens: If provided, use this value for seq_lens instead + of max_query_len. Used to profile attention workspace that + scales with context length. """ mm_config = self.vllm_config.model_config.multimodal_config if mm_config and mm_config.mm_encoder_only: @@ -621,11 +627,13 @@ def _dummy_run( # If force_attention is True, we always capture attention. # Otherwise, it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - if create_mixed_batch: + if profile_seq_lens is not None: + seq_lens = profile_seq_lens # type: ignore[assignment] + elif create_mixed_batch: # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] else: seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py index 7de9a79227b..267ed61c0a4 100644 --- a/vllm_omni/worker/gpu_generation_worker.py +++ b/vllm_omni/worker/gpu_generation_worker.py @@ -48,17 +48,17 @@ def init_device(self): # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK self.local_rank += dp_local_rank * tp_pp_world_size - assert self.local_rank < torch.cuda.device_count(), ( + assert self.local_rank < torch.accelerator.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + visible_device_count = torch.accelerator.device_count() if torch.cuda.is_available() else 0 assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " f"({visible_device_count})." ) self.device = torch.device(f"cuda:{self.local_rank}") - current_platform.set_device(self.device) + torch.accelerator.set_device_index(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) @@ -79,7 +79,7 @@ def init_device(self): # Now take memory snapshot after NCCL is initialized gc.collect() - torch.cuda.empty_cache() + torch.accelerator.empty_cache() # take current memory snapshot self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index ddc6c119024..3965670027f 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -257,6 +257,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests.pop(req_id, None) self.model_intermediate_buffer.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + if hasattr(self, "late_interaction_runner"): + self.late_interaction_runner.on_requests_finished(scheduler_output.finished_req_ids) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -266,6 +268,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) + # Zero GPU memory for freshly allocated cache blocks to prevent + # stale NaN/data from corrupting attention or SSM computation. + if hasattr(scheduler_output, "new_block_ids_to_zero") and scheduler_output.new_block_ids_to_zero: + self._zero_block_ids(scheduler_output.new_block_ids_to_zero) + # Free the cached encoder outputs. for mm_hash in scheduler_output.free_encoder_mm_hashes: self.encoder_cache.pop(mm_hash, None) @@ -333,6 +340,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: lora_request=new_req_data.lora_request, ) self.requests[req_id] = req_state + if hasattr(self, "late_interaction_runner"): + self.late_interaction_runner.register_request(req_id, pooling_params) # If prompt embeddings are provided, decode and attach to inter_data try: @@ -418,13 +427,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # prev_num_draft_len is used in async scheduling mode with # spec decode. it indicates if need to update num_computed_tokens # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], + # first step: num_computed_tokens = 0, spec_tokens = [], # prev_num_draft_len = 0. # second step: num_computed_tokens = 100(prompt length), # spec_tokens = [a,b], prev_num_draft_len = 0. # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain + # num_computed_tokens in first step and second step doesn't contain # the spec tokens length, but in third step it contains the # spec tokens length. we only need to update num_computed_tokens # when prev_num_draft_len > 0. @@ -551,6 +560,7 @@ def _dummy_run( remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -575,6 +585,9 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run num_active_loras: Number of distinct active LoRAs to capture for. LoRA is activated when num_active_loras > 0. + profile_seq_lens: If provided, use this value for seq_lens instead + of max_query_len. Used to profile attention workspace that + scales with context length. """ mm_config = self.vllm_config.model_config.multimodal_config if mm_config and mm_config.mm_encoder_only: @@ -695,11 +708,13 @@ def _dummy_run( # If force_attention is True, we always capture attention. # Otherwise, it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - if create_mixed_batch: + if profile_seq_lens is not None: + seq_lens = profile_seq_lens # type: ignore[assignment] + elif create_mixed_batch: # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] else: seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens