diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml index a02d84ed6d7..d021ef218ee 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml @@ -39,7 +39,7 @@ stage_args: worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.8 - enforce_eager: false + enforce_eager: true # haven't supported talker ACL graph on NPU trust_remote_code: true enable_prefix_caching: false engine_output_type: latent diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml index fdf81723941..f99ed22e878 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml @@ -44,7 +44,7 @@ stage_args: worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.2 - enforce_eager: false + enforce_eager: true # haven't supported talker ACL graph on NPU trust_remote_code: true engine_output_type: latent # Output codec codes for code2wav # tensor_parallel_size: 2 diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 0ba57e038f4..fbf9e632f78 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -35,7 +35,7 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.utils import ProfileExecuteDuration -from vllm_omni.core.sched.omni_ar_scheduler import KVCacheTransferData +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner @@ -65,7 +65,8 @@ def __init__(self, *args, **kwargs): # each model stage has their own hidden size self.hidden_size = self.model_config.hf_text_config.hidden_size self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) - self.omni_connector = None + # Initialize KV cache manager (preserve vllm_config fallback behavior) + self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) def _make_buffer(self, *size, dtype, numpy=True): # Prevent ray from pinning the buffer due to large size @@ -92,7 +93,13 @@ def execute_model( # -------------------------------------- Omni-new ------------------------------------------------- # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) - self.kv_extracted_req_ids = self._handle_finished_requests_kv_transfer(scheduler_output) + self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( + finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + kv_caches=self.kv_caches, + block_size=self.cache_config.block_size, + cache_dtype=str(self.cache_config.cache_dtype), + request_id_resolver=self._resolve_global_request_id, + ) # -------------------------------------- Omni-new ------------------------------------------------- with ProfileExecuteDuration().capture_async("prepare input"): @@ -499,161 +506,6 @@ def _generate_process_reqs_hidden_states(self, num_input_tokens, return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( hidden_states) - def _handle_finished_requests_kv_transfer(self, scheduler_output: SchedulerOutput) -> list[str]: - """Handle KV cache transfer for finished requests. - - Returns list of request IDs that were processed (for Scheduler to free blocks). - """ - finished_reqs = getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}) - if not finished_reqs: - return [] - - logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests") - - extracted_ids = [] - for req_id, data in finished_reqs.items(): - try: - seq_len = data.get("seq_len", 0) - block_ids = data.get("block_ids", []) - if not block_ids: - logger.warning(f"Request {req_id} has no block IDs, skipping") - continue - - # Extract KV cache from GPU blocks -> CPU tensors - kv_data = self._extract_kv_cache(req_id, block_ids, seq_len) - if kv_data: - # Transfer to downstream stage via connector - self._transfer_kv_cache(kv_data) - - except Exception as e: - logger.error(f"Failed KV transfer for {req_id}: {e}") - finally: - extracted_ids.append(req_id) - - return extracted_ids - - def _extract_kv_cache(self, req_id: str, block_ids: list[int], seq_len: int) -> KVCacheTransferData | None: - """Extract KV cache from GPU blocks for a single request.""" - num_layers = len(self.kv_caches) - key_cache = [None] * num_layers - value_cache = [None] * num_layers - - for layer_idx, kv_tensor in enumerate(self.kv_caches): - # Validate block IDs - max_block = kv_tensor.shape[1] - 1 - valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block] - if not valid_ids: - continue - - # Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim] - # -> [2, seq_len, n_heads, head_dim] - selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim] - n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape - flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head) - if seq_len < flat.shape[1]: - flat = flat[:, :seq_len] - - # Move to CPU - flat_cpu = flat.detach().cpu().contiguous() - key_cache[layer_idx] = flat_cpu[0] - value_cache[layer_idx] = flat_cpu[1] - - if not any(k is not None for k in key_cache): - return None - - return KVCacheTransferData( - request_id=req_id, - layer_blocks={"key_cache": key_cache, "value_cache": value_cache}, - block_ids=block_ids, - metadata={ - "block_size": self.cache_config.block_size, - "num_layers": num_layers, - "dtype": str(self.cache_config.cache_dtype), - "seq_len": seq_len, - }, - ) - - def _transfer_kv_cache(self, kv_data: KVCacheTransferData) -> None: - """Transfer KV cache data to downstream stage via OmniConnector.""" - connector = self._get_or_create_connector() - if not connector: - return - - # Resolve global request ID if available - transfer_req_id = self._resolve_global_request_id(kv_data.request_id) - from_stage, to_stage = self._detect_transfer_stages() - - # Prepare data and transfer with retry - data_dict = kv_data.to_dict() - data_dict["request_id"] = transfer_req_id - - success, size, _ = self._transfer_with_retry( - connector, from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict - ) - - if success: - logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes") - else: - logger.error(f"KV transfer FAILED: {transfer_req_id}") - - def _get_or_create_connector(self) -> Any | None: - """Get existing connector or create one from config.""" - if self.omni_connector: - return self.omni_connector - - from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory - from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec - - config = self._get_omni_connector_config() - if not config or not isinstance(config, dict): - logger.warning("No valid OmniConnector config found") - return None - - c_type = config.get("type") - if not c_type: - logger.error("OmniConnector config missing 'type' field") - return None - - c_extra = {k: v for k, v in config.items() if k != "type"} - self.omni_connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) - return self.omni_connector - - def _get_omni_connector_config(self) -> dict[str, Any] | None: - """Get OmniConnector configuration from model config.""" - # Primary: omni_kv_config from YAML - omni_kv = getattr(self.model_config, "omni_kv_config", None) - if isinstance(omni_kv, dict): - cfg = omni_kv.get("connector_config") - if isinstance(cfg, dict) and cfg: - return cfg - - # Fallback: kv_transfer_config - kv_cfg = getattr(self.vllm_config, "kv_transfer_config", None) - if kv_cfg: - direct = getattr(kv_cfg, "omni_connector_config", None) - if isinstance(direct, dict) and direct: - return direct - extra = getattr(kv_cfg, "kv_connector_extra_config", None) - if isinstance(extra, dict): - omni = extra.get("omni_connector_config") - if isinstance(omni, dict) and omni: - return omni - - return None - - def _detect_transfer_stages(self) -> tuple[str, str]: - """Detect source and target stages for KV transfer.""" - omni_kv = getattr(self.model_config, "omni_kv_config", None) - if isinstance(omni_kv, dict): - from_s = omni_kv.get("omni_from_stage") - to_s = omni_kv.get("omni_to_stage") - if from_s and to_s: - return str(from_s), str(to_s) - - raise ValueError( - "KV transfer stages not configured. Please set 'omni_from_stage' and 'omni_to_stage' in omni_kv_config." - ) - def _resolve_global_request_id(self, req_id: str) -> str: """Resolve global request ID from request state.""" req_state = self.requests.get(req_id) @@ -669,32 +521,3 @@ def _resolve_global_request_id(self, req_id: str) -> str: return global_id.decode("utf-8") return str(global_id) return req_id - - def _transfer_with_retry( - self, - connector: Any, - from_stage: str, - to_stage: str, - request_id: str, - data: dict[str, Any], - max_retries: int = 3, - ) -> tuple[bool, int, dict[str, Any] | None]: - """Transfer data with retry and exponential backoff.""" - import time - - for attempt in range(max_retries): - try: - put_key = f"omni_{from_stage}_to_{to_stage}_{request_id}" - success, size, metadata = connector.put( - from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data - ) - if success: - return success, size, metadata - logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") - except Exception as e: - logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") - - if attempt < max_retries - 1: - time.sleep(0.1 * (2**attempt)) - - return False, 0, None diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py index b0fbf7a4e2f..a9a376f95fa 100644 --- a/vllm_omni/platforms/npu/worker/npu_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py @@ -54,12 +54,13 @@ def load_model(self, *args, **kwargs) -> None: self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size - self.talker_mtp_input_ids = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) + self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) self.talker_mtp_inputs_embeds = self._make_buffer( - self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False + max_batch_size, hidden_size, dtype=self.dtype, numpy=False ) - self.last_talker_hidden = self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False) - self.text_step = self._make_buffer(self.max_num_reqs, hidden_size, dtype=self.dtype, numpy=False) + self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] @@ -590,12 +591,16 @@ def dummy_drafter_compute_logits(hidden_states): model_instance=self.model, ): if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + num_tokens_padded_talker_mtp = num_tokens_padded + if num_tokens_padded_talker_mtp == self.max_num_tokens: + num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0] hidden_states = self.talker_mtp( - self.talker_mtp_input_ids.gpu[:num_tokens_padded], - self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], - self.last_talker_hidden.gpu[:num_tokens_padded], - self.text_step.gpu[:num_tokens_padded], + self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp], + self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp], + self.text_step.gpu[:num_tokens_padded_talker_mtp], ) + self.compilation_config.cache_dir = None hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds )