diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py index 67a3b68e1b3..753a304226c 100644 --- a/vllm_ascend/distributed/kvpool/pool_scheduler.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -162,13 +162,14 @@ def build_connector_meta( self._request_trackers.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None) self._unfinished_request_ids.discard(finished_req_id) - + for req_id in scheduler_output.preempted_req_ids: self._preempted_req_ids.update(scheduler_output.preempted_req_ids) self._request_trackers.pop(req_id, None) self._unfinished_requests.pop(req_id, None) - meta = AscendConnectorMetadata(self._unfinished_request_ids, scheduler_output.preempted_req_ids) + meta = AscendConnectorMetadata(self._unfinished_request_ids, + scheduler_output.preempted_req_ids) for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests @@ -183,17 +184,17 @@ def build_connector_meta( else: unfolded_block_ids = request.block_ids[0].copy() request_tracker = RequestTracker( - req_id=request.req_id, - token_len=num_tokens_to_compute, - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) + req_id=request.req_id, + token_len=num_tokens_to_compute, + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) self._request_trackers[request.req_id] = request_tracker last_chunk_tokens_num = ((len(request.prompt_token_ids) // self._block_size * self._block_size) if self._discard_partial_chunks else len( request.prompt_token_ids)) - + req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, @@ -233,10 +234,11 @@ def build_connector_meta( num_saved_tokens=0, ) self._request_trackers[req_id] = request_tracker - last_chunk_tokens_num = ((len(request_real.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else len( - request_real.prompt_token_ids)) + last_chunk_tokens_num = ( + (len(request_real.prompt_token_ids) // + self._block_size * + self._block_size) if self._discard_partial_chunks else + len(request_real.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, @@ -247,17 +249,19 @@ def build_connector_meta( >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) - + # decode/chunked request else: request_tracker = self._request_trackers[req_id] - num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[ + req_id] req_tuple = self._unfinished_requests.get(req_id) if req_tuple: request = req_tuple[0] num_current_tokens = request_tracker.token_len new_token_ids = request.all_token_ids[ - num_current_tokens:num_current_tokens + num_new_tokens] + num_current_tokens:num_current_tokens + + num_new_tokens] request_tracker.token_len += len(new_token_ids) else: raise ValueError( @@ -269,9 +273,10 @@ def build_connector_meta( request_tracker.update(new_block_ids) last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(request.prompt_token_ids)) + self._block_size * + self._block_size) if + self._discard_partial_chunks else + len(request.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index 563dd6174c8..cc352198d9f 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -461,11 +461,13 @@ def store_layer( for layer_id in range(self.num_layers): yield - def get_finished(self, - finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]: + def get_finished( + self, finished_req_ids: set[str], + meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]: done_sending = ( self.get_and_clear_finished_requests( - finished_req_ids, meta # type: ignore[union-attr] + finished_req_ids, + meta # type: ignore[union-attr] ) if self.kv_role in ['kv_producer', 'kv_both'] or self.consumer_is_to_put else set()) @@ -480,7 +482,8 @@ def get_finished(self, self.tp_rank) return done_sending, done_recving - def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]: + def get_and_clear_finished_requests( + self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]: finished_sending = set() for req_id in meta.preempted_req_ids: self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 3b38aa431e5..16338ab0a57 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,7 +25,10 @@ DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb -from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, @@ -531,6 +534,37 @@ def forward(self, class AscendMRotaryEmbedding(MRotaryEmbedding): + def forward_triton(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None): + assert positions.ndim == 2 + assert key is not None + + self._match_cos_sin_cache_dtype(query) + cos_sin = self.cos_sin_cache[positions] # type: ignore + cos, sin = cos_sin.chunk(2, dim=-1) + self.cos = cos.contiguous() + self.sin = sin.contiguous() + query_shape = query.shape + key_shape = key.shape + + assert self.mrope_section + + q, k = triton_mrope( + query, + key, + self.cos, + self.sin, + self.mrope_section, + self.head_size, + self.rotary_dim, + self.mrope_interleaved, + ) + + return q.reshape(query_shape), k.reshape(key_shape) + def forward_oot( self, positions: torch.Tensor, @@ -538,7 +572,8 @@ def forward_oot( key: torch.Tensor, ): # use triton mrope for Qwen3-VL - if self.mrope_section == _QWEN3_VL_MROPE_SECTION: + if HAS_TRITON and positions.ndim == 2 and self.mrope_section == _QWEN3_VL_MROPE_SECTION: + # todo: need cann update in 8.5.0 return self.forward_triton(positions, query, key) if self.mrope_section != [16, 24, 24] or \ @@ -567,35 +602,6 @@ def forward_oot( return query, key - def forward_triton( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None = None, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - assert positions.ndim == 2 - assert key is not None - - self._match_cos_sin_cache_dtype(query) - cos_sin = self.cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - - assert self.mrope_section - q, k = triton_mrope( - query, - key, - cos, - sin, - self.mrope_section, - self.head_size, - self.rotary_dim, - self.mrope_interleaved, - ) - return q.reshape(query_shape), k.reshape(key_shape) - class AscendApplyRotaryEmb(ApplyRotaryEmb):