diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 2ec54b365ee..2a0c92e9d07 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -23,7 +23,9 @@ from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) -from vllm_omni.engine.serialization import deserialize_additional_information +from vllm_omni.worker_v2.model_states.intermediate_buffer import ( + _resolve_additional_information, +) logger = init_logger(__name__) @@ -461,6 +463,9 @@ def update_from_output( ) ) if self.chunk_transfer_adapter is not None: + # Only clean receiver-side state here. Sender-side + # cleanup (cleanup_sender) is unsafe while save_async() + # background threads may still reference sender dicts. self.chunk_transfer_adapter.cleanup_receiver( request.request_id, ) @@ -616,14 +621,15 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di } # Also update request.additional_information for good measure add_info = getattr(request, "additional_information", None) - # If additional_information is an AdditionalInformationPayload-like object, - # unpack it into a plain dict. + # If additional_information is an AdditionalInformationPayload-like + # object, fully resolve it into a plain dict (tensor_data → Tensor, + # list_data → list, scalar_data → scalar). if ( add_info is not None and hasattr(add_info, "entries") and isinstance(getattr(add_info, "entries"), dict) ): - request.additional_information = deserialize_additional_information(add_info) + request.additional_information = _resolve_additional_information(add_info) add_info = request.additional_information if add_info is None: request.additional_information = {} diff --git a/vllm_omni/worker_v2/model_states/intermediate_buffer.py b/vllm_omni/worker_v2/model_states/intermediate_buffer.py index 7057cb43107..64d62f6aad6 100644 --- a/vllm_omni/worker_v2/model_states/intermediate_buffer.py +++ b/vllm_omni/worker_v2/model_states/intermediate_buffer.py @@ -53,8 +53,12 @@ def _resolve_additional_information(payload: Any) -> dict[str, Any]: arr = np.frombuffer(tensor_data, dtype=dt) arr = arr.reshape(getattr(entry, "tensor_shape", ())) info[k] = torch.from_numpy(arr.copy()) + elif getattr(entry, "list_data", None) is not None: + info[k] = entry.list_data + elif getattr(entry, "scalar_data", None) is not None: + info[k] = entry.scalar_data else: - info[k] = getattr(entry, "list_data", None) + info[k] = None return info except Exception: logger.exception("Failed to decode additional_information payload") diff --git a/vllm_omni/worker_v2/model_states/omni_model_state.py b/vllm_omni/worker_v2/model_states/omni_model_state.py index a801c59e3e0..35e0b93d945 100644 --- a/vllm_omni/worker_v2/model_states/omni_model_state.py +++ b/vllm_omni/worker_v2/model_states/omni_model_state.py @@ -83,42 +83,63 @@ def __init__( OmniModelState._rope_patch_lock = threading.Lock() def _safe_get_rope(model_config: Any, mdl: Any, **kwargs: Any) -> Any: + result = None + needs_mrope_override = False try: - return _orig_get_rope(model_config, mdl, **kwargs) + result = _orig_get_rope(model_config, mdl, **kwargs) except (AssertionError, TypeError): - if not model_config.uses_mrope: - return None - logger.info( - "Model uses M-RoPE (config) but does not implement SupportsMRoPE; creating RopeState(num_dims=3)." - ) - # Add get_mrope_input_positions if missing. - # Returns 3D sequential positions with delta=0 - # (pure text, no vision token offsets). - if not hasattr(mdl, "get_mrope_input_positions"): - - def _default_mrope_positions( - self_model: Any, - input_tokens: list[int], - mm_features: list, - ) -> tuple[torch.Tensor, int]: - """Return 3D sequential positions with zero delta. - - For non-vision Omni models (e.g. TTS Talker), - all 3 M-RoPE dimensions use the same sequential - positions. Delta=0 means decode-step positions - are simply ``num_computed + offset``, identical - to the 1D case but broadcast to 3 dims. - """ - n = len(input_tokens) - pos = torch.arange(n, dtype=torch.long) - return pos.unsqueeze(0).expand(3, -1), 0 - - mdl.get_mrope_input_positions = types.MethodType(_default_mrope_positions, mdl) - # has_delta=True is required so init_prefill_positions - # calls get_mrope_input_positions (not the XD-RoPE - # path). delta=0 (returned above) means no offset is - # applied during decode — positions stay sequential. - return RopeState(num_dims=3, has_delta=True, **kwargs) + # Model does not implement SupportsMRoPE — may still + # need M-RoPE if config declares mrope_section. + needs_mrope_override = model_config.uses_mrope + + if result is not None and not needs_mrope_override: + # Upstream returned a rope but check dimensionality: + # config has mrope_section but upstream returned a 1D + # rope (e.g. rope_type="default" with mrope_section). + if model_config.uses_mrope and getattr(result, "num_dims", 0) < 3: + logger.info( + "Upstream returned %dD rope but config has mrope_section; " + "overriding with RopeState(num_dims=3).", + getattr(result, "num_dims", 0), + ) + needs_mrope_override = True + else: + return result + + if not needs_mrope_override: + return None + + logger.info( + "Model uses M-RoPE (config) but does not implement SupportsMRoPE; creating RopeState(num_dims=3)." + ) + # Add get_mrope_input_positions if missing. + # Returns 3D sequential positions with delta=0 + # (pure text, no vision token offsets). + if not hasattr(mdl, "get_mrope_input_positions"): + + def _default_mrope_positions( + self_model: Any, + input_tokens: list[int], + mm_features: list, + ) -> tuple[torch.Tensor, int]: + """Return 3D sequential positions with zero delta. + + For non-vision Omni models (e.g. TTS Talker), + all 3 M-RoPE dimensions use the same sequential + positions. Delta=0 means decode-step positions + are simply ``num_computed + offset``, identical + to the 1D case but broadcast to 3 dims. + """ + n = len(input_tokens) + pos = torch.arange(n, dtype=torch.long) + return pos.unsqueeze(0).expand(3, -1), 0 + + mdl.get_mrope_input_positions = types.MethodType(_default_mrope_positions, mdl) + # has_delta=True is required so init_prefill_positions + # calls get_mrope_input_positions (not the XD-RoPE + # path). delta=0 (returned above) means no offset is + # applied during decode — positions stay sequential. + return RopeState(num_dims=3, has_delta=True, **kwargs) with OmniModelState._rope_patch_lock: _orig_get_rope = _default_mod.get_rope_state diff --git a/vllm_omni/worker_v2/omni_generation_model_runner.py b/vllm_omni/worker_v2/omni_generation_model_runner.py index b42b4fde5f4..704f711ada4 100644 --- a/vllm_omni/worker_v2/omni_generation_model_runner.py +++ b/vllm_omni/worker_v2/omni_generation_model_runner.py @@ -22,7 +22,7 @@ get_uniform_token_count, ) -from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData +from vllm_omni.core.sched.output import OmniCachedRequestData from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker_v2.omni_model_runner import OmniGPUModelRunner @@ -52,7 +52,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # ------------------------------------------------------------------ def _handle_async_chunk_updates(self, scheduler_output: SchedulerOutput) -> None: - """Re-initialize cached requests whose prompt_token_ids changed. + """In-place update cached requests whose prompt_token_ids changed. In async_chunk mode, the ``ChunkTransferAdapter`` replaces ``Request.prompt_token_ids`` with new codec frames for each @@ -60,14 +60,17 @@ def _handle_async_chunk_updates(self, scheduler_output: SchedulerOutput) -> None propagates the new ``prompt_token_ids`` via ``OmniCachedRequestData``. - Upstream V2 ``update_requests`` only handles block allocation - and does **not** replace token state. This method removes the - stale request from ``req_states`` and re-adds it with the - updated tokens, mirroring V1's ``_update_request_states``. - - Note: ``finish_requests`` / ``free_states`` (called before us) - already handle unscheduled request cleanup, so we only need to - process requests with new prompt_token_ids here. + Instead of remove + re-add (which involves free_indices churn + and redundant model_state init), we update the existing slot + in-place. This is safe for Code2Wav because: + - No KV cache / rope state to reinitialize + - staged writes are applied once at the end + + ``additional_information`` is NOT merged here — the inherited + ``OmniGPUModelRunner.update_requests`` (called right after this + method in ``execute_model``) is the single source of truth for + ``intermediate_buffer`` updates. Doing it in both places would + clone every tensor to CPU twice per step. """ cached = scheduler_output.scheduled_cached_reqs if not cached.req_ids: @@ -80,68 +83,31 @@ def _handle_async_chunk_updates(self, scheduler_output: SchedulerOutput) -> None if not new_prompt_ids: return - addl_info = cached.additional_information + updated = False - # Phase 1: remove all stale states, collecting re-add data. - updates: list[tuple[str, list[int], Any]] = [] for req_id in cached.req_ids: new_ids = new_prompt_ids.get(req_id) if new_ids is None: continue - if req_id not in self.req_states.req_id_to_index: + req_idx = self.req_states.req_id_to_index.get(req_id) + if req_idx is None: continue - old_index = self.req_states.req_id_to_index[req_id] - - self.req_states.remove_request(req_id) - self.model_state.remove_request(old_index) - - updates.append((req_id, new_ids, addl_info.get(req_id))) - - if not updates: - return - - # Phase 2: re-add all requests with new tokens, then batch-apply. - for req_id, new_ids, info in updates: - # req_id_to_index is updated eagerly by add_request() (not - # staged), so new_index is valid before apply_staged_writes(). - self.req_states.add_request( - req_id=req_id, - prompt_len=len(new_ids), - all_token_ids=new_ids, - num_computed_tokens=0, - ) + # In-place update token state — same slot, no remove/re-add. + # .np[] = direct write (no GPU buffer); stage_write = GPU-synced. + n = len(new_ids) + self.req_states.prompt_len.np[req_idx] = n + self.req_states.prefill_len.np[req_idx] = n + self.req_states.total_len.stage_write_elem(req_idx, n) + self.req_states.all_token_ids.stage_write(req_idx, 0, new_ids) + self.req_states.num_computed_tokens.stage_write_elem(req_idx, 0) + self.req_states.num_computed_prefill_tokens[req_idx] = 0 - new_index = self.req_states.req_id_to_index[req_id] - - # Build a synthetic NewRequestData so model_state.add_request - # can run _resolve_additional_information and notify plugins. - # sampling_params is None because the generation runner does - # not sample tokens — Code2Wav output goes directly to - # pooler_output. block_ids is empty because generation - # models have no KV cache. - synthetic = OmniNewRequestData( - req_id=req_id, - prompt_token_ids=new_ids, - mm_features=None, - sampling_params=None, - pooling_params=None, - block_ids=([],), - num_computed_tokens=0, - lora_request=None, - prefill_token_ids=new_ids, - additional_information=info, - ) - # model_state.add_request calls super().add_request() - # (DefaultModelState) internally. This is safe for Code2Wav - # because generation models have no attention/rope state to - # initialize — the super() call only touches intermediate - # buffer and encoder cache, both of which are idempotent - # overwrites on the same slot. - self.model_state.add_request(new_index, synthetic) + updated = True - self.req_states.apply_staged_writes() + if updated: + self.req_states.apply_staged_writes() # ------------------------------------------------------------------ # profile / warmup — skip sampler since there are no logits @@ -175,8 +141,8 @@ def execute_model( self.finish_requests(scheduler_output) self.free_states(scheduler_output) # Handle async_chunk prompt_token_ids replacement for cached - # requests BEFORE add/update — so the stale request state is - # removed and re-created with the new chunk's tokens. + # requests BEFORE add/update — update the existing slot + # in-place with the new chunk's tokens. self._handle_async_chunk_updates(scheduler_output) self.add_requests(scheduler_output) self.update_requests(scheduler_output) diff --git a/vllm_omni/worker_v2/omni_model_runner.py b/vllm_omni/worker_v2/omni_model_runner.py index e63be5636fc..a6ba0f5f6e8 100644 --- a/vllm_omni/worker_v2/omni_model_runner.py +++ b/vllm_omni/worker_v2/omni_model_runner.py @@ -32,6 +32,9 @@ from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.worker_v2.model_states import init_omni_model_state +from vllm_omni.worker_v2.model_states.intermediate_buffer import ( + _resolve_additional_information, +) from vllm_omni.worker_v2.model_states.omni_model_state import OmniModelState logger = init_logger(__name__) @@ -305,6 +308,35 @@ def execute_model( assert isinstance(hidden_states, torch.Tensor) return None + # ------------------------------------------------------------------ + # Request lifecycle: update intermediate buffer from cached requests + # ------------------------------------------------------------------ + + def update_requests(self, scheduler_output: SchedulerOutput) -> None: + """Merge updated additional_information into intermediate_buffer. + + In async_chunk mode, chunk_transfer_adapter attaches updated + additional_information (e.g. thinker_decode_embeddings) to + OmniCachedRequestData for cached requests every schedule step. + Upstream GPUModelRunner.update_requests does not handle this + field, so we merge it into the intermediate buffer here. + """ + super().update_requests(scheduler_output) + + cached = scheduler_output.scheduled_cached_reqs + addl_info = getattr(cached, "additional_information", None) + if not addl_info: + return + for req_id, info in addl_info.items(): + if info is None: + continue + req_idx = self.req_states.req_id_to_index.get(req_id) + if req_idx is None: + continue + resolved = _resolve_additional_information(info) + if resolved: + self.model_state.intermediate_buffer.update(req_idx, resolved) + # ------------------------------------------------------------------ # Request lifecycle: clean up intermediate buffer on finish # ------------------------------------------------------------------