From 75f730a2b631cc5a9129190feccebf6d49af12ba Mon Sep 17 00:00:00 2001 From: amy-why-3459 Date: Sat, 14 Feb 2026 15:53:30 +0800 Subject: [PATCH] Qwen3-Omni performance optimization Signed-off-by: amy-why-3459 --- .../test_chunk_transfer_adapter.py | 16 +-- .../connectors/shm_connector.py | 87 +++++------- .../omni_connectors/transfer_adapter/base.py | 39 +++--- .../chunk_transfer_adapter.py | 124 ++++++++---------- .../models/qwen3_omni/qwen3_omni.py | 10 +- .../stage_input_processors/qwen3_omni.py | 18 ++- .../stage_input_processors/qwen3_tts.py | 12 +- .../worker/gpu_generation_model_runner.py | 2 +- vllm_omni/worker/gpu_model_runner.py | 14 +- 9 files changed, 142 insertions(+), 180 deletions(-) diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index a5ebc2227f4..17e3c4a143a 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -35,6 +35,7 @@ def _req(req_id: str, status: RequestStatus, external_req_id: str | None = None) prompt_token_ids=[], num_computed_tokens=0, additional_information=None, + is_finished=lambda: status == RequestStatus.FINISHED_STOPPED, ) @@ -48,9 +49,9 @@ def _build(*, stage_id: int = 1, model_mode: str = "ar", max_num_seqs: int = 2): def _fake_base_init(self, config): self.config = config - self._pending_load_reqs = {} + self._pending_load_reqs = deque() self._finished_load_reqs = set() - self._pending_save_reqs = {} + self._pending_save_reqs = deque() self._finished_save_reqs = set() self.stop_event = threading.Event() self.lock = threading.Lock() @@ -108,9 +109,8 @@ def test_load_poll(build_adapter): adapter.load_async(request) payload = {"code_predictor_codes": [[1]], "hidden_states": torch.tensor([[2.0]]), "finished": True} connector.get.return_value = (payload, 16) - adapter._poll_single_request("req-1") + adapter._poll_single_request(request) - connector.get.assert_called_once_with("1", "2", "external-1_1_0") assert request.additional_information == payload assert adapter.get_req_chunk["req-1"] == 1 assert "req-1" in adapter._finished_load_reqs @@ -120,17 +120,15 @@ def test_load_poll(build_adapter): def test_save_async(build_adapter): adapter, _ = build_adapter(stage_id=1) - request = SimpleNamespace(external_req_id="external-1") + request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1") adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False} adapter.save_async(pooling_output=None, request=request) adapter.custom_process_next_stage_input_func = lambda **kwargs: {} adapter.save_async(pooling_output=None, request=request) - assert adapter.put_req_chunk["external-1"] == 1 - queued = adapter._pending_save_reqs["external-1"] - assert len(queued) == 1 - assert queued[0]["put_key"] == "external-1_1_0" + task = adapter._pending_save_reqs.popleft() + assert task["is_finished"] is False def test_update_request_payload(build_adapter): diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 812d8b9d5fa..3439e3a39e8 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -3,7 +3,7 @@ import fcntl import os -import time +from multiprocessing import shared_memory as shm_pkg from typing import Any from vllm_omni.entrypoints.stage_utils import shm_read_bytes, shm_write_bytes @@ -51,7 +51,7 @@ def put( if True: # Use Shared Memory lock_file = f"/dev/shm/shm_{put_key}_lockfile.lock" - with open(lock_file, "w") as lockf: + with open(lock_file, "wb+") as lockf: fcntl.flock(lockf, fcntl.LOCK_EX) meta = shm_write_bytes(payload, name=put_key) fcntl.flock(lockf, fcntl.LOCK_UN) @@ -75,6 +75,23 @@ def put( logger.error(f"SharedMemoryConnector put failed for req {put_key}: {e}") return False, 0, None + def _get_data_with_lock(self, lock_file: str, shm_handle: dict): + obj = None + try: + with open(lock_file, "rb+") as lockf: + fcntl.flock(lockf, fcntl.LOCK_EX) + data_bytes = shm_read_bytes(shm_handle) + fcntl.flock(lockf, fcntl.LOCK_UN) + obj = self.deserialize_obj(data_bytes) + return obj, int(shm_handle.get("size", 0)) + except Exception as e: + logger.error(f"SharedMemoryConnector shm get failed for req : {e}") + return None, 0 + finally: + # If data has been received, delete lock_file. + if obj and os.path.exists(lock_file): + os.remove(lock_file) + def get( self, from_stage: str, @@ -88,7 +105,7 @@ def get( metadata = metadata.get(get_key) if not isinstance(metadata, dict): - return None + return None, 0 if "inline_bytes" in metadata: try: @@ -96,63 +113,27 @@ def get( return obj, int(metadata.get("size", 0)) except Exception as e: logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}") - return None + return None, 0 if "shm" in metadata: - try: - shm_handle = metadata["shm"] - lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock" - with open(lock_file, "w") as lockf: - fcntl.flock(lockf, fcntl.LOCK_SH) - data_bytes = shm_read_bytes(shm_handle) - fcntl.flock(lockf, fcntl.LOCK_UN) - if os.path.exists(lock_file): - os.remove(lock_file) - obj = self.deserialize_obj(data_bytes) - return obj, int(metadata.get("size", 0)) - except Exception as e: - logger.error(f"SharedMemoryConnector shm get failed for req {get_key}: {e}") - return None - - return None - - from multiprocessing import shared_memory as shm_pkg + shm_handle = metadata["shm"] + lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock" + return self._get_data_with_lock(lock_file, shm_handle) - # Wait for shared memory to be available (with retry logic) - max_retries = 30 - retry_delay = 0.1 # 100ms between retries + return None, 0 shm = None - - for attempt in range(max_retries): - try: - shm = shm_pkg.SharedMemory(name=get_key) - break # Successfully opened, exit retry loop - except FileNotFoundError: - if attempt < max_retries - 1: - time.sleep(retry_delay) - else: - # Max retries reached, return None - logger.warning(f"Shared memory '{get_key}' not found after {max_retries} retries") - return None - - if shm is None: - return None - try: + shm = shm_pkg.SharedMemory(name=get_key) + if shm is None or shm.size == 0: + return None, 0 lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" - with open(lock_file) as lockf: - fcntl.flock(lockf, fcntl.LOCK_SH) - data_bytes = shm_read_bytes({"name": get_key, "size": shm.size}) - fcntl.flock(lockf, fcntl.LOCK_UN) - # Clean up the temporary file if it still exists. - if os.path.exists(lock_file): - os.remove(lock_file) - obj = self.deserialize_obj(data_bytes) - return obj, shm.size + shm_handle = {"name": get_key, "size": shm.size} + return self._get_data_with_lock(lock_file, shm_handle) + except Exception: + return None, 0 finally: - shm.close() - - # TODO: update another read method + if shm: + shm.close() def cleanup(self, request_id: str) -> None: # SHM segments are automatically unlinked during 'get' (shm_read_bytes). diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/base.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/base.py index 0962808ed16..c64bcad69cc 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/base.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/base.py @@ -3,6 +3,7 @@ import threading import time +from collections import deque from typing import Any from ..utils.logging import get_connector_logger @@ -22,17 +23,16 @@ def __init__(self, config: Any): if not hasattr(self, "connector"): self.connector = None # Requests that are waiting to be polled - self._pending_load_reqs = {} + self._pending_load_reqs = deque() # Requests that have successfully retrieved data self._finished_load_reqs = set() # Requests that are waiting to be saved - self._pending_save_reqs = {} + self._pending_save_reqs = deque() # Requests that have successfully saved data self._finished_save_reqs = set() self.stop_event = threading.Event() - self.lock = threading.Lock() self.recv_thread = threading.Thread(target=self.recv_loop, daemon=True) self.recv_thread.start() @@ -48,37 +48,30 @@ def recv_loop(self): """Loop to poll for incoming data.""" while not self.stop_event.is_set(): # Iterate over a snapshot of pending requests - with self.lock: - pending_reqs_ids = list(self._pending_load_reqs.keys()) - - for req_id in pending_reqs_ids: + while self._pending_load_reqs: + request = self._pending_load_reqs.popleft() + request_id = request.request_id + self.request_ids_mapping[request_id] = request.external_req_id try: - self._poll_single_request(req_id) + is_success = self._poll_single_request(request) + if not is_success: + self._pending_load_reqs.append(request) except Exception as e: - logger.warning(f"Error receiving data for {req_id}: {e}") + self._pending_load_reqs.append(request) + logger.warning(f"Error receiving data for {request_id}: {e}") time.sleep(0.001) def save_loop(self): """Loop to send outgoing data.""" while not self.stop_event.is_set(): - task = None - with self.lock: - pending_save_reqs_ids = list(self._pending_save_reqs.keys()) - for req_id in pending_save_reqs_ids: - if self._pending_save_reqs[req_id]: - task = self._pending_save_reqs[req_id].popleft() - if not self._pending_save_reqs[req_id]: - del self._pending_save_reqs[req_id] - break - - if task: + while self._pending_save_reqs: + task = self._pending_save_reqs.popleft() try: self._send_single_request(task) except Exception as e: - logger.error(f"Error saving data for {task.get('request_id')}: {e}") - else: - time.sleep(0.001) + logger.warning(f"Error saving data for {task.get('request_id')}: {e}") + time.sleep(0.001) def _poll_single_request(self, *args, **kwargs): """Poll connector for a single request task. diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index a6afb97bd4c..b46191539fa 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -89,15 +89,12 @@ def load_async(self, request: Request): request: The request object needing data. """ stage_id = self.connector.stage_id - request_id = request.request_id - self.request_ids_mapping[request_id] = request.external_req_id if stage_id == 0: return if not hasattr(request, "additional_information"): request.additional_information = None - with self.lock: - self._pending_load_reqs[request_id] = request + self._pending_load_reqs.append(request) def save_async(self, pooling_output: torch.Tensor | None = None, request: Request | None = None): """Build and enqueue one chunk for asynchronous sending. @@ -109,93 +106,61 @@ def save_async(self, pooling_output: torch.Tensor | None = None, request: Reques pooling_output: Partial pooling output dictionary request: Request object """ - stage_id = self.connector.stage_id - next_stage_id = stage_id + 1 - request_id = request.external_req_id - chunk_id = self.put_req_chunk[request_id] - - # Process payload in main thread to avoid race conditions on request state - payload_data = None - if self.custom_process_next_stage_input_func: - try: - payload_data = self.custom_process_next_stage_input_func( - transfer_manager=self, - pooling_output=pooling_output, - request=request, - ) - - except Exception as e: - logger.error(f"Failed to use custom_process_input_func for payload extraction: {e}") - - if not payload_data: - return - - # Increment chunk_id - self.put_req_chunk[request_id] += 1 - connector_put_key = f"{request_id}_{stage_id}_{chunk_id}" - task = { - "stage_id": stage_id, - "next_stage_id": next_stage_id, - "put_key": connector_put_key, - "data": payload_data, - "request_id": request_id, + "pooling_output": pooling_output, + "request": request, + "is_finished": request.is_finished(), } + self._pending_save_reqs.append(task) - with self.lock: - if request_id not in self._pending_save_reqs: - self._pending_save_reqs[request_id] = deque() - self._pending_save_reqs[request_id].append(task) - - def _poll_single_request(self, req_id: str): + def _poll_single_request(self, request: Request): stage_id = self.connector.stage_id target_stage_id = stage_id - 1 + req_id = request.request_id chunk_id = self.get_req_chunk[req_id] external_req_id = self.request_ids_mapping.get(req_id, req_id) connector_get_key = f"{external_req_id}_{target_stage_id}_{chunk_id}" # Use timeout=0 for non-blocking poll - result = self.connector.get( - str(target_stage_id), - str(stage_id), - connector_get_key, - ) - - if result is None: - return + try: + result = self.connector.get( + str(target_stage_id), + str(stage_id), + connector_get_key, + ) + except Exception as e: + logger.error(f"SharedMemoryConnector get failed for req {connector_get_key}: {e}") + return False payload_data, size = result if payload_data: # Update connector state self.get_req_chunk[req_id] += 1 - req = self._pending_load_reqs[req_id] if self.model_mode == "ar": self._update_request_payload(external_req_id, payload_data) - req.additional_information = payload_data + request.additional_information = payload_data if payload_data.get("finished"): self.finished_requests.add(req_id) else: if payload_data.get("finished"): self.finished_requests.add(req_id) - # req.prompt_token_ids = payload_data.get("code_predictor_codes", []) - # req.num_computed_tokens = 0 new_ids = payload_data.get("code_predictor_codes", []) - req.prompt_token_ids = new_ids - req.num_computed_tokens = 0 + request.prompt_token_ids = new_ids + request.num_computed_tokens = 0 # Empty chunk with more data expected: keep polling. if not new_ids and not payload_data.get("finished"): - return + return True # Mark as finished for consumption - with self.lock: - self._finished_load_reqs.add(req_id) - if req_id in self._pending_load_reqs: - del self._pending_load_reqs[req_id] + self._finished_load_reqs.add(req_id) logger.debug(f"[Stage-{stage_id}] Received one chunk for key {connector_get_key}") + return True + + return False def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: """Update the payload data for a request in the connector. @@ -221,10 +186,30 @@ def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> return payload_data def _send_single_request(self, task: dict): - connector_put_key = task["put_key"] - stage_id = task["stage_id"] - next_stage_id = task["next_stage_id"] - payload_data = task["data"] + pooling_output = task["pooling_output"] + request = task["request"] + is_finished = task["is_finished"] + stage_id = self.connector.stage_id + next_stage_id = stage_id + 1 + request_id = request.external_req_id + chunk_id = self.put_req_chunk[request_id] + connector_put_key = f"{request_id}_{stage_id}_{chunk_id}" + # Process payload in main thread to avoid race conditions on request state + payload_data = None + if self.custom_process_next_stage_input_func: + try: + payload_data = self.custom_process_next_stage_input_func( + transfer_manager=self, + pooling_output=pooling_output, + request=request, + is_finished=is_finished, + ) + + except Exception as e: + logger.error(f"Failed to use custom_process_input_func for payload extraction: {e}") + + if not payload_data: + return success, size, metadata = self.connector.put( from_stage=str(stage_id), @@ -234,6 +219,7 @@ def _send_single_request(self, task: dict): ) if success: + self.put_req_chunk[request_id] += 1 logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}") ######################################################################## @@ -250,12 +236,11 @@ def process_pending_chunks( """ if self.connector.stage_id == 0: return - finished_load_reqs = self.get_finished_requests() self._process_chunk_queue( - waiting_queue, self.waiting_for_chunk_waiting_requests, RequestStatus.WAITING, finished_load_reqs + waiting_queue, self.waiting_for_chunk_waiting_requests, RequestStatus.WAITING, self._finished_load_reqs ) self._process_chunk_queue( - running_queue, self.waiting_for_chunk_running_requests, RequestStatus.RUNNING, finished_load_reqs + running_queue, self.waiting_for_chunk_running_requests, RequestStatus.RUNNING, self._finished_load_reqs ) while len(running_queue) > self.scheduler_max_num_seqs: request = running_queue.pop() @@ -321,6 +306,7 @@ def _process_chunk_queue( else: if request.request_id in finished_load_reqs: request.status = target_status + finished_load_reqs.remove(request.request_id) self.requests_with_ready_chunks.add(request.request_id) continue queue.remove(request) @@ -336,9 +322,3 @@ def _clear_chunk_ready(self, scheduler_output: Any) -> None: for req_id in scheduler_output.scheduled_cached_reqs.req_ids: if req_id in self.requests_with_ready_chunks: self.requests_with_ready_chunks.remove(req_id) - - def get_finished_requests(self): - with self.lock: - finished_load = set(self._finished_load_reqs) - self._finished_load_reqs = set() - return finished_load diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 4bd9bd3ba2b..46b07f9deb5 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -155,6 +155,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # for CI: Initialize special tokens embeddings early to avoid AttributeError when loading dummy weights self._init_special_tokens_embeddings() + # suppress tokens by setting their probability to ~1e-9 (finite very small) + self.suppressed_tokens = self._get_talker_suppressed_tokens() self.requires_raw_input_tokens = True elif self.model_stage == "code2wav": @@ -1067,18 +1069,16 @@ def compute_logits( # implemented by assigning their logits to log(1e-9). if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor): - # suppress tokens by setting their probability to ~1e-9 (finite very small) - suppressed_tokens = self._get_talker_suppressed_tokens() try: logits_cpu = logits.cpu() - logits_cpu[:, suppressed_tokens] = -1e9 + logits_cpu[:, self.suppressed_tokens] = -1e9 logits = logits_cpu.to(logits.device) except Exception as e: print(f"Error in logits suppression: {e}") print(f"logits.shape: {logits.shape}") - print(f"suppressed_tokens: {suppressed_tokens}") + print(f"self.suppressed_tokens: {self.suppressed_tokens}") raise e - logits[:, suppressed_tokens] = -1e9 + logits[:, self.suppressed_tokens] = -1e9 return logits def sample( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 11aea4f0410..3a42159a8ff 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -87,6 +87,7 @@ def thinker2talker_async_chunk( transfer_manager: Any, pooling_output: dict[str, Any], request: OmniEngineCoreRequest, + is_finished: bool = False, ) -> list[dict[str, Any]]: """ Process thinker outputs to create talker inputs. @@ -112,10 +113,10 @@ def thinker2talker_async_chunk( "tts_bos_embed": pooling_output.get("tts_bos_embed").detach().cpu(), "tts_eos_embed": pooling_output.get("tts_eos_embed").detach().cpu(), "tts_pad_embed": pooling_output.get("tts_pad_embed").detach().cpu(), - "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + "finished": torch.tensor(is_finished, dtype=torch.bool), } if transfer_manager.request_payload.get(request_id) is None: - if not request.is_finished(): + if not is_finished: transfer_manager.request_payload[request_id] = talker_additional_info return None else: @@ -134,10 +135,12 @@ def thinker2talker_async_chunk( talker_additional_info = { "thinker_embeddings": pooling_output.get("0").detach().cpu(), - "thinker_hidden_states": pooling_output.get("24").detach().cpu(), - "thinker_sequences": output_token_ids, - "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + "finished": torch.tensor(is_finished, dtype=torch.bool), } + + if not output_token_ids: + # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. + talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu() return talker_additional_info @@ -209,6 +212,7 @@ def talker2code2wav_async_chunk( transfer_manager: Any, pooling_output: dict[str, Any], request: OmniEngineCoreRequest, + is_finished: bool = False, ): """ Pooling version. @@ -244,7 +248,7 @@ def talker2code2wav_async_chunk( transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) length = len(transfer_manager.code_prompt_token_ids[request_id]) chunk_length = length % chunk_size - if chunk_length != 0 and not request.is_finished(): + if chunk_length != 0 and not is_finished: return None context_length = chunk_length if chunk_length != 0 else chunk_size @@ -257,7 +261,7 @@ def talker2code2wav_async_chunk( .reshape(-1) .tolist() ), - "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + "finished": torch.tensor(is_finished, dtype=torch.bool), } return info diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 8599ea2e3e8..3d9617bf3a6 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -23,6 +23,7 @@ def talker2code2wav_async_chunk( transfer_manager: Any, pooling_output: dict[str, Any], request: Any, + is_finished: bool = False, ) -> dict[str, Any] | None: if not isinstance(pooling_output, dict): return None @@ -40,8 +41,6 @@ def talker2code2wav_async_chunk( f"codec_left_context_frames={left_context_size}" ) - finished = bool(request.is_finished()) - frame = _extract_last_frame(pooling_output) if frame is not None: codec_codes = frame.cpu().tolist() @@ -50,16 +49,13 @@ def talker2code2wav_async_chunk( length = len(transfer_manager.code_prompt_token_ids[request_id]) chunk_length = length % chunk_size - if chunk_length != 0 and not finished: + if chunk_length != 0 and not is_finished: return None context_length = chunk_length if chunk_length != 0 else chunk_size if length <= 0: - return { - "code_predictor_codes": [], - "finished": torch.tensor(bool(finished), dtype=torch.bool), - } + return None end_index = min(length, left_context_size + context_length) ctx_frames = max(0, int(end_index - context_length)) @@ -72,5 +68,5 @@ def talker2code2wav_async_chunk( # The model expects input_ids layout: [ctx_frames, *flat_codes]. return { "code_predictor_codes": [int(ctx_frames)] + code_predictor_codes, - "finished": torch.tensor(bool(finished), dtype=torch.bool), + "finished": torch.tensor(is_finished, dtype=torch.bool), } diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index fc1362f883d..3629cb2a99f 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -100,7 +100,7 @@ def execute_model( record_function_or_nullcontext("gpu_model_runner: preprocess"), self.synchronize_input_prep(), ): - if self.model_config.async_chunk: + if self.model_config.async_chunk and num_scheduled_tokens: self._update_request_states(scheduler_output) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index ba6f9e36791..32ce8a8894f 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -924,17 +924,27 @@ def _collect_additional_information_for_prefill( start_offset = int(self.query_start_loc.cpu[req_index]) self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + def _update_request_information(self, request_id: str, payload_info: dict) -> None: + """Update per-request additional_information stored in request state.""" + req_state = self.requests.get(request_id) + if req_state is None: + return + + info_dict = getattr(req_state, "additional_information_cpu", None) + if isinstance(payload_info, dict) and info_dict is not None: + info_dict.update(payload_info) + def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None: for new_req in scheduler_output.scheduled_new_reqs: payload_info = getattr(new_req, "additional_information", None) if isinstance(payload_info, dict): - self._merge_additional_information_update(new_req.req_id, payload_info) + self._update_request_information(new_req.req_id, payload_info) if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"): cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {}) if isinstance(cached_infos, dict): for req_id, req_infos in cached_infos.items(): - self._merge_additional_information_update(req_id, req_infos) + self._update_request_information(req_id, req_infos) def _preprocess( self,