diff --git a/tests/dfx/conftest.py b/tests/dfx/conftest.py index 29bb33302a7..fa33b96c9dc 100644 --- a/tests/dfx/conftest.py +++ b/tests/dfx/conftest.py @@ -70,17 +70,17 @@ def _build_serve_args(serve_args: Any) -> list[str]: def create_unique_server_params( configs: list[dict[str, Any]], stage_configs_dir: Path, -) -> list[tuple[str, str, str | None, str | None, tuple[str, ...]]]: - """Return one row per unique server configuration (same 5-tuple shape as upstream). +) -> list[tuple[str, str, str | None, str | None, tuple[str, ...], bool]]: + """Return one row per unique server configuration. - ``(test_name, model, deploy_yaml_path, stage_overrides_json, extra_cli_args)``. + ``(test_name, model, deploy_yaml_path, stage_overrides_json, extra_cli_args, use_omni)``. JSON ``server_params.serve_args`` (dict/list) is expanded via ``_build_serve_args`` and **prepended** to ``extra_cli_args`` so perf / stability ``omni_server`` fixtures stay identical to main while still honoring ``serve_args`` in benchmark JSON. """ - unique_params: list[tuple[str, str, str | None, str | None, tuple[str, ...]]] = [] - seen: set[tuple[str, str, str | None, str | None, tuple[str, ...]]] = set() + unique_params: list[tuple[str, str, str | None, str | None, tuple[str, ...], bool]] = [] + seen: set[tuple[str, str, str | None, str | None, tuple[str, ...], bool]] = set() for config in configs: test_name = config["test_name"] server_params = config["server_params"] @@ -104,8 +104,16 @@ def create_unique_server_params( serve_flat = _build_serve_args(server_params.get("serve_args")) raw_extra = tuple(server_params.get("extra_cli_args") or ()) extra_cli_args = tuple(serve_flat) + raw_extra - - server_param = (test_name, model, stage_config_path, stage_overrides_json, extra_cli_args) + use_omni = bool(server_params.get("use_omni", True)) + + server_param = ( + test_name, + model, + stage_config_path, + stage_overrides_json, + extra_cli_args, + use_omni, + ) if server_param not in seen: seen.add(server_param) unique_params.append(server_param) diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py index 9036508cb1c..c6185cb797c 100644 --- a/tests/dfx/perf/scripts/run_benchmark.py +++ b/tests/dfx/perf/scripts/run_benchmark.py @@ -65,7 +65,7 @@ def omni_server(request): Multi-stage initialization can take 10-20+ minutes. """ with _omni_server_lock: - test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param + test_name, model, stage_config_path, stage_overrides, extra_cli_args, use_omni = request.param print(f"Starting OmniServer with test: {test_name}, model: {model}") @@ -78,7 +78,7 @@ def omni_server(request): server_args = ["--stage-overrides", stage_overrides] + server_args if extra_cli_args: server_args = list(extra_cli_args) + server_args - with OmniServer(model, server_args) as server: + with OmniServer(model, server_args, use_omni=use_omni) as server: server.test_name = test_name print("OmniServer started successfully") yield server diff --git a/tests/dfx/stability/conftest.py b/tests/dfx/stability/conftest.py index 30718d4bf5a..744e7dfc388 100644 --- a/tests/dfx/stability/conftest.py +++ b/tests/dfx/stability/conftest.py @@ -35,10 +35,10 @@ def omni_server(request: pytest.FixtureRequest): """Start OmniServer for stability tests, with per-module timeout override.""" timeout_args = getattr(request.module, "STABILITY_SERVER_TIMEOUT_ARGS", DEFAULT_STABILITY_SERVER_TIMEOUT_ARGS) with _omni_server_lock: - # Same 5-tuple and CLI composition as ``tests/dfx/perf/scripts/run_benchmark.py`` on main; + # Same tuple and CLI composition as ``tests/dfx/perf/scripts/run_benchmark.py``; # ``serve_args`` from JSON are folded into ``extra_cli_args`` inside # ``create_unique_server_params``. - test_name, model, deploy_path, stage_overrides, extra_cli_args = request.param + test_name, model, deploy_path, stage_overrides, extra_cli_args, use_omni = request.param print(f"Starting OmniServer with test: {test_name}, model: {model}") server_args = list(timeout_args) @@ -48,7 +48,7 @@ def omni_server(request: pytest.FixtureRequest): server_args = ["--stage-overrides", stage_overrides] + server_args if extra_cli_args: server_args = list(extra_cli_args) + server_args - with OmniServer(model, server_args) as server: + with OmniServer(model, server_args, use_omni=use_omni) as server: server.test_name = test_name print("OmniServer started successfully") yield server diff --git a/vllm_omni/benchmarks/metrics/metrics.py b/vllm_omni/benchmarks/metrics/metrics.py index dbf764698a0..f320fffe9fc 100644 --- a/vllm_omni/benchmarks/metrics/metrics.py +++ b/vllm_omni/benchmarks/metrics/metrics.py @@ -188,7 +188,10 @@ def calculate_metrics( total_input += outputs[i].prompt_len tpot = 0 if output_len > 1: - latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft + try: + latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft + except Exception: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft tpot = latency_minus_ttft / (output_len - 1) tpots.append(tpot) # Note: if output_len <= 1, we regard tpot as 0 for goodput diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index a5579dd4640..3c767f63ab9 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -9,8 +9,8 @@ from vllm.distributed.kv_events import KVEventBatch from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger +from vllm.v1.core.sched.async_scheduler import AsyncScheduler as VLLMScheduler from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.core.sched.utils import remove_all from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.perf import PerfStats diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 7f36e38fd85..34e5b0967a0 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -296,7 +296,19 @@ async def create_chat_completion( if raw_request: raw_request.state.request_metadata = request_metadata - output_modalities = getattr(request, "modalities", self.engine_client.output_modalities) + # NOTE: + # - OpenAI python client flattens extra_body fields into model_extra. + # - Raw HTTP requests may keep them under request.extra_body. + # Keep modalities resolution tolerant so `--extra_body '{"modalities":["text"]}'` + # can reliably drive multi-stage routing. + request_extra_body = getattr(request, "extra_body", None) or request.model_extra or {} + output_modalities = getattr(request, "modalities", None) + if output_modalities is None: + output_modalities = request_extra_body.get("modalities") + if isinstance(output_modalities, str): + output_modalities = [output_modalities] + if output_modalities is not None: + output_modalities = [str(m).lower() for m in output_modalities] request.modalities = ( output_modalities if output_modalities is not None else self.engine_client.output_modalities ) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 8018df343ee..f7da44067e7 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -28,6 +28,17 @@ _HIDDEN_LAYER_KEY = "24" +def _layer_tensor(layers: dict[Any, Any], key: str) -> torch.Tensor | None: + """Fetch layer tensor with tolerant key lookup (str/int).""" + if not isinstance(layers, dict): + return None + key_int = int(key) + val = layers.get(key_int) + if val is None: + val = layers.get(key) + return val if isinstance(val, torch.Tensor) else None + + def _compute_talker_prompt_ids_length(info: OmniPayload, device: torch.device | str = "cuda") -> int: im_start_token_id = 151644 system_token_id = 8948 @@ -300,9 +311,23 @@ def thinker2talker_async_chunk( request_id = request.external_req_id chunk_id = transfer_manager.put_req_chunk[request_id] + if not isinstance(pooling_output, dict): + logger.debug("thinker2talker_async_chunk: skip non-dict pooling_output for req=%s", request_id) + return None + thinker_hs = pooling_output.get("hidden_states", {}) - thinker_layers = thinker_hs.get("layers", {}) - thinker_embed = pooling_output.get("embed", {}) + thinker_layers = thinker_hs.get("layers", {}) if isinstance(thinker_hs, dict) else {} + thinker_embed = pooling_output.get("embed", {}) if isinstance(pooling_output.get("embed", {}), dict) else {} + thinker_emb = _layer_tensor(thinker_layers, _EMBED_LAYER_KEY) + thinker_hid = _layer_tensor(thinker_layers, _HIDDEN_LAYER_KEY) + if thinker_emb is None or thinker_hid is None: + logger.debug( + "thinker2talker_async_chunk: missing thinker layers for req=%s (embed=%s hidden=%s)", + request_id, + thinker_emb is not None, + thinker_hid is not None, + ) + return None if chunk_id == 0: all_token_ids = request.all_token_ids # prefill + decode @@ -312,13 +337,19 @@ def thinker2talker_async_chunk( prompt_token_ids = _ensure_list(prompt_token_ids) payload: OmniPayload = { "embed": { - "prefill": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu(), + "prefill": thinker_emb.detach().cpu(), # Provide thinker-side TTS token embeddings for talker projection - "tts_bos": thinker_embed["tts_bos"].detach().cpu(), - "tts_eos": thinker_embed["tts_eos"].detach().cpu(), - "tts_pad": thinker_embed["tts_pad"].detach().cpu(), + "tts_bos": thinker_embed.get("tts_bos").detach().cpu() + if isinstance(thinker_embed.get("tts_bos"), torch.Tensor) + else None, + "tts_eos": thinker_embed.get("tts_eos").detach().cpu() + if isinstance(thinker_embed.get("tts_eos"), torch.Tensor) + else None, + "tts_pad": thinker_embed.get("tts_pad").detach().cpu() + if isinstance(thinker_embed.get("tts_pad"), torch.Tensor) + else None, }, - "hidden_states": {"output": thinker_layers[int(_HIDDEN_LAYER_KEY)].detach().cpu()}, + "hidden_states": {"output": thinker_hid.detach().cpu()}, "ids": {"all": all_token_ids, "prompt": prompt_token_ids}, "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)}, } @@ -366,12 +397,12 @@ def thinker2talker_async_chunk( if output_token_ids: talker_additional_info["meta"]["override_keys"] = [("embed", "decode"), ("ids", "output")] - talker_additional_info["embed"] = {"decode": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu()} + talker_additional_info["embed"] = {"decode": thinker_emb.detach().cpu()} talker_additional_info["ids"] = {"output": output_token_ids} else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. - talker_additional_info["embed"] = {"prefill": thinker_layers[0].detach().cpu()} - talker_additional_info["hidden_states"] = {"output": thinker_layers[24].detach().cpu()} + talker_additional_info["embed"] = {"prefill": thinker_emb.detach().cpu()} + talker_additional_info["hidden_states"] = {"output": thinker_hid.detach().cpu()} return talker_additional_info @@ -431,11 +462,20 @@ def thinker2talker( thinker_sequences = prompt_token_ids + output_ids thinker_input_ids = prompt_token_ids new_seq_length = len(prompt_token_ids + output_ids) - 1 - thinker_mm: OmniPayload = output.multimodal_output + thinker_mm_raw = getattr(output, "multimodal_output", None) + if not isinstance(thinker_mm_raw, dict): + logger.debug("thinker2talker: skip req=%s due to empty multimodal_output", req_id) + continue + thinker_mm: OmniPayload = thinker_mm_raw mm_hs = thinker_mm.get("hidden_states", {}) - mm_layers = mm_hs.get("layers", {}) - thinker_emb = mm_layers[int(_EMBED_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:] - thinker_hid = mm_layers[int(_HIDDEN_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:] + mm_layers = mm_hs.get("layers", {}) if isinstance(mm_hs, dict) else {} + emb_layer = _layer_tensor(mm_layers, _EMBED_LAYER_KEY) + hid_layer = _layer_tensor(mm_layers, _HIDDEN_LAYER_KEY) + if emb_layer is None or hid_layer is None: + logger.debug("thinker2talker: skip req=%s due to missing hidden-state layers", req_id) + continue + thinker_emb = emb_layer.detach().to(device=device, dtype=torch.float)[-new_seq_length:] + thinker_hid = hid_layer.detach().to(device=device, dtype=torch.float)[-new_seq_length:] prefill_mm: dict[str, Any] | None = None if prefill_stage is not None: @@ -507,7 +547,11 @@ def talker2code2wav_async_chunk( """ Pooling version. """ + if not isinstance(pooling_output, dict): + return None talker_codes = pooling_output.get("codes", {}) + if not isinstance(talker_codes, dict): + return None code_predictor_codes = talker_codes.get("audio") if code_predictor_codes is None: return None @@ -600,7 +644,14 @@ def talker2code2wav( is_streaming_session = bool(getattr(streaming_context, "enabled", False)) if is_streaming_session: seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context) - mm: OmniPayload = output.multimodal_output + mm_raw = getattr(output, "multimodal_output", None) + if not isinstance(mm_raw, dict): + logger.debug("talker2code2wav: skip req=%s due to empty multimodal_output", req_id) + continue + mm: OmniPayload = mm_raw + if "codes" not in mm or not isinstance(mm.get("codes"), dict) or "audio" not in mm["codes"]: + logger.debug("talker2code2wav: skip req=%s due to missing codes.audio", req_id) + continue # Extract codec codes from talker output # Expected shape: [8, seq_len] (8-layer RVQ codes) codec_codes = ( diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 947b3164f3e..008a9b6f50d 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -80,6 +80,7 @@ def __init__(self, *args, **kwargs): self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) # Initialize KV cache manager (preserve vllm_config fallback behavior) self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) + self._downstream_payload_cache: dict[str, bool] = {} def _make_buffer(self, *size, dtype, numpy=True): # Prevent ray from pinning the buffer due to large size @@ -141,8 +142,62 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata): return sampling_metadata return replace(sampling_metadata, output_token_ids=output_token_ids) + def _request_final_stage_id(self, req_id: str) -> int | None: + info = self.model_intermediate_buffer.get(req_id) + if not isinstance(info, dict): + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) + if not isinstance(info, dict): + return None + val = info.get("omni_final_stage_id") + if val is None: + return None + try: + return int(val) + except (TypeError, ValueError): + return None + + def _request_needs_downstream_stage_payload(self, req_id: str) -> bool: + cached = self._downstream_payload_cache.get(req_id) + if cached is not None: + return cached + # Conservative default: keep payload if marker is missing. + final_stage_id = self._request_final_stage_id(req_id) + needs_payload = final_stage_id is None or final_stage_id > 0 + self._downstream_payload_cache[req_id] = needs_payload + return needs_payload + + def _maybe_prune_downstream_payload_cache(self) -> None: + # Keep cache size bounded under long-lived serving workloads. + if len(self._downstream_payload_cache) <= max(4096, len(self.requests) * 2): + return + self._downstream_payload_cache = { + req_id: needs for req_id, needs in self._downstream_payload_cache.items() if req_id in self.requests + } + + def _warmup_single_request_prefill_compile_path(self) -> None: + # Warm up with a single long prefill request to avoid first-token + # compile spikes from wide prefill shapes. + warmup_tokens = min(int(self.max_num_tokens), 3072) + if warmup_tokens <= 1: + return + try: + with torch.inference_mode(): + self._dummy_run( + num_tokens=warmup_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_single_prefill_request=True, + allow_microbatching=False, + skip_eplb=True, + ) + torch.cuda.synchronize() + logger.info("Warmed single-request prefill compile path with %d tokens", warmup_tokens) + except Exception as e: + logger.warning("Skipped single-request prefill warmup due to: %s", e) + def capture_model(self) -> int: result = super().capture_model() + self._warmup_single_request_prefill_compile_path() self._capture_talker_mtp_graphs() return result @@ -833,7 +888,21 @@ def propose_draft_token_ids(sampled_token_ids): kv_connector_output = self.kv_connector_output self.kv_connector_output = None - hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() + self._maybe_prune_downstream_payload_cache() + downstream_req_ids = [rid for rid in req_ids_output_copy if self._request_needs_downstream_stage_payload(rid)] + needs_downstream_payload = len(downstream_req_ids) > 0 + downstream_req_id_set = set(downstream_req_ids) + hidden_states_cpu = None + req_hidden_states_cpu: dict[str, torch.Tensor] | None = None + if needs_downstream_payload: + num_valid_tokens = min( + int(scheduler_output.total_num_scheduled_tokens), + int(hidden_states.shape[0]), + ) + if len(downstream_req_ids) == len(req_ids_output_copy): + hidden_states_cpu = hidden_states[:num_valid_tokens].detach().to("cpu").contiguous() + else: + req_hidden_states_cpu = {} num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) if num_scheduled_tokens_np is None: req_ids = self.input_batch.req_ids @@ -841,78 +910,96 @@ def propose_draft_token_ids(sampled_token_ids): [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], dtype=np.int32, ) + query_start_loc_cpu = self.query_start_loc.cpu + + pooler_output: list[dict[str, object]] | None = None + if needs_downstream_payload: + # Prior to applying the post-processing func, extract + # the prefix cached hidden states and multimodal states. + if self.omni_prefix_cache is not None: + ( + combined_hidden_states, + combined_multimodal_outputs, + ) = self._maybe_get_combined_prefix_cache_tensors( + hidden_states, + multimodal_outputs, + scheduler_output.num_scheduled_tokens, + ) + # Otherwise we don't have the mm CPU data yet, so we still need to build it + if self.omni_prefix_cache is None: + mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs)) - # Prior to applying the post-processing func, extract - # the prefix cached hidden states and multimodal states. - if self.omni_prefix_cache is not None: - ( - combined_hidden_states, - combined_multimodal_outputs, - ) = self._maybe_get_combined_prefix_cache_tensors( + self._process_additional_information_updates( hidden_states, multimodal_outputs, - scheduler_output.num_scheduled_tokens, - ) - # Otherwise we don't have the mm CPU data yet, so we still need to build it - if self.omni_prefix_cache is None: - mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs)) - - self._process_additional_information_updates( - hidden_states, - multimodal_outputs, - num_scheduled_tokens_np, - scheduler_output, - combined_hidden_states, - combined_multimodal_outputs, - ) - - pooler_output: list[dict[str, object]] = [] - for rid in req_ids_output_copy: - idx = req_id_to_index_output_copy[rid] - start = int(self.query_start_loc.cpu[idx]) - sched = int(num_scheduled_tokens_np[idx]) - end = start + sched - # If prefix cache is enabled, we have already split everything - # by request and converted the states to CPU tensors - req_hidden_states = self._resolve_req_hidden_states( - hidden_states_cpu, + num_scheduled_tokens_np, + scheduler_output, combined_hidden_states, - rid, - start, - end, + combined_multimodal_outputs, + req_ids_filter=downstream_req_id_set, ) - payload: dict[str, object] = {"hidden": req_hidden_states} - - mm_payload: dict[str, object] = {} - if combined_multimodal_outputs or mm_cpu: - if combined_multimodal_outputs: - # Prefix cache enabled; all items have already been processed - # and split apart for each request as needed, and all tensors - # have already been detached to the CPU. The only exception is - # lists, which we keep as passthrough data for consistent behavior - # in postprocess. - for mm_key in combined_multimodal_outputs.keys(): - value = combined_multimodal_outputs[mm_key][rid] - if isinstance(value, list): - mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] - else: - mm_payload[mm_key] = value + if req_hidden_states_cpu is not None and combined_hidden_states is None: + for rid in downstream_req_ids: + idx = req_id_to_index_output_copy[rid] + start = int(query_start_loc_cpu[idx]) + sched = int(num_scheduled_tokens_np[idx]) + end = start + sched + req_hidden_states_cpu[rid] = hidden_states[start:end].detach().to("cpu").contiguous() + + pooler_output = [] + for rid in req_ids_output_copy: + if rid not in downstream_req_id_set: + pooler_output.append({}) + continue + idx = req_id_to_index_output_copy[rid] + start = int(query_start_loc_cpu[idx]) + sched = int(num_scheduled_tokens_np[idx]) + end = start + sched + # If prefix cache is enabled, we have already split everything + # by request and converted the states to CPU tensors + if req_hidden_states_cpu is not None and combined_hidden_states is None: + req_hidden_states = req_hidden_states_cpu[rid] else: - # Prefix cache disabled; we still need to process the data - for mm_key, mm_val in mm_cpu.items(): - mm_payload[mm_key] = to_payload_element( - element=mm_val, - idx=idx, - start=start, - end=end, - pass_lists_through=False, - seq_len=seq_len, - ) - payload.update(mm_payload) - # Flatten nested dicts to dotted keys so pooling_output - # stays dict[str, torch.Tensor] for msgspec serialization. - pooler_output.append(flatten_payload(payload)) + req_hidden_states = self._resolve_req_hidden_states( + hidden_states_cpu, + combined_hidden_states, + rid, + start, + end, + ) + payload: dict[str, object] = {"hidden": req_hidden_states} + + mm_payload: dict[str, object] = {} + if combined_multimodal_outputs or mm_cpu: + if combined_multimodal_outputs: + # Prefix cache enabled; all items have already been processed + # and split apart for each request as needed, and all tensors + # have already been detached to the CPU. The only exception is + # lists, which we keep as passthrough data for consistent behavior + # in postprocess. + for mm_key in combined_multimodal_outputs.keys(): + value = combined_multimodal_outputs[mm_key][rid] + if isinstance(value, list): + mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] + else: + mm_payload[mm_key] = value + + else: + # Prefix cache disabled; we still need to process the data + for mm_key, mm_val in mm_cpu.items(): + mm_payload[mm_key] = to_payload_element( + element=mm_val, + idx=idx, + start=start, + end=end, + pass_lists_through=False, + seq_len=seq_len, + ) + payload.update(mm_payload) + # Flatten nested dicts to dotted keys so pooling_output + # stays dict[str, torch.Tensor] for msgspec serialization. + pooler_output.append(flatten_payload(payload)) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.routed_experts_initialized: capturer = RoutedExpertsCapturer.get_instance() @@ -926,7 +1013,11 @@ def propose_draft_token_ids(sampled_token_ids): sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + pooler_output=( + pooler_output + if (self.vllm_config.model_config.engine_output_type != "text" and needs_downstream_payload) + else None + ), kv_connector_output=kv_connector_output, ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, num_nans_in_logits=num_nans_in_logits, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index efad5f3d4d5..d78fd117ca8 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -592,6 +592,7 @@ def _dummy_run( cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, + force_single_prefill_request: bool = False, allow_microbatching: bool = True, skip_eplb: bool = False, is_profile: bool = False, @@ -617,6 +618,8 @@ def _dummy_run( force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. + force_single_prefill_request: If True, force one request to consume + all tokens. Useful to warm up long-prefill compile paths. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. create_mixed_batch: If True, create a mixed batch with both decode @@ -657,7 +660,7 @@ def _dummy_run( assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: - assert not uniform_decode + assert not uniform_decode and not force_single_prefill_request # Create mixed batch: # first half decode tokens, second half one prefill num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) @@ -669,11 +672,17 @@ def _dummy_run( # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: - assert not create_mixed_batch + assert not create_mixed_batch and not force_single_prefill_request num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len + elif force_single_prefill_request: + # Keep one request with a long prompt to proactively compile the + # longest prefill path that first-token traffic will hit. + num_reqs = 1 + num_scheduled_tokens_list = [num_tokens] + max_query_len = num_tokens else: num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs @@ -1056,6 +1065,7 @@ def _process_additional_information_updates( scheduler_output: "SchedulerOutput", combined_hidden_states: dict[str, torch.Tensor] | None = None, combined_multimodal_outputs: dict[str, object] | None = None, + req_ids_filter: set[str] | None = None, ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" try: @@ -1063,6 +1073,8 @@ def _process_additional_information_updates( # TODO(Peiqi): do we have a more elegant way to do this? if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: for req_index, req_id in enumerate(self.input_batch.req_ids): + if req_ids_filter is not None and req_id not in req_ids_filter: + continue req_infos = self.model_intermediate_buffer.get(req_id, {}) if combined_hidden_states: # Combined hidden states contains all hidden states for every request