diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index f5a50aafb58..78a2abc8619 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -25,6 +25,7 @@
 import vllm_ascend.patch.worker.patch_unquantized_gemm  # noqa
 import vllm_ascend.patch.worker.patch_bert  # noqa
 import vllm_ascend.patch.worker.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_mamba_utils  # noqa
 import vllm_ascend.patch.worker.patch_multimodal_merge  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py
new file mode 100644
index 00000000000..d44d5583a70
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_mamba_utils.py
@@ -0,0 +1,59 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# mypy: ignore-errors
+# Adapted from vllm-project/vllm/vllm/v1/worker/mamba_utils.py
+# Replace the CUDA batch_memcpy kernel with a Triton kernel for Ascend NPU.
+
+import vllm
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+
+    src_ptr = tl.load(src_ptrs + pid)
+    dst_ptr = tl.load(dst_ptrs + pid)
+    size = tl.load(sizes + pid)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    for i in range(0, size, BLOCK_SIZE):
+        mask = (i + offsets) < size
+
+        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
+        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
+
+        data = tl.load(curr_src_ptr, mask=mask)
+        tl.store(curr_dst_ptr, data, mask=mask)
+
+
+def batch_memcpy(src_ptrs, dst_ptrs, sizes):
+    batch = src_ptrs.shape[0]
+    assert dst_ptrs.shape[0] == batch
+    assert sizes.shape[0] == batch
+
+    grid = (batch,)
+    # NOTE: BLOCK_SIZE must be a power-of-2 constexpr for Triton.
+    # 128 bytes per tile provides a reasonable balance between parallelism
+    # and register pressure on Ascend NPU; adjust if profiling suggests
+    # a different value.
+    BLOCK_SIZE = 128
+    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
+
+
+vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel
+vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1bb40291ed2..42de48eff31 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -77,6 +77,8 @@
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
+from vllm.v1.worker import mamba_utils
+from vllm.v1.worker.cp_utils import get_total_cp_world_size
 from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
@@ -393,6 +395,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.cpu_slot_mapping = None
         self.sampling_done_event: torch.npu.Event | None = None
 
+        self.mamba_state_idx: dict[str, int] = {}
+        self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
+
     @property
     def use_cp(self) -> bool:
         return self.pcp_size * self.dcp_size > 1
@@ -1206,6 +1211,23 @@ def execute_model(
 
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL
 
+        # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877,
+        # there should be a corresponding 'postprocess_mamba'. However, it is called inside
+        # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend.
+        # We simply utilize the implementation in vLLM.
+        if self.cache_config.mamba_cache_mode == "align":
+            mamba_utils.preprocess_mamba(
+                scheduler_output,
+                self.kv_cache_config,
+                self.cache_config,
+                self.mamba_state_idx,
+                self.input_batch,
+                self.requests,
+                self.compilation_config.static_forward_context,
+                self.model.get_mamba_state_copy_func(),
+                self._get_mamba_copy_bufs(),
+            )
+
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
 
         ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -2518,6 +2540,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self._mamba_copy_bufs = None
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         # NOTE(cmq): initialize_attn_backend must before using self.attn_groups
@@ -2903,6 +2926,27 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
                 # of mamba block. In this case, BlockTable.block_size will never equal
                 # to kernel_block_sizes[0]
                 kernel_block_sizes.append([0])
+
+        max_num_blocks = []
+        max_model_len = max(self.max_model_len, self.max_encoder_len)
+        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+                continue
+            max_num_blocks_per_req = cdiv(
+                max_model_len, block_sizes[i] * get_total_cp_world_size()
+            )
+            if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
+                mamba_blocks_per_req = (
+                    max_num_blocks_per_req
+                    if self.cache_config.enable_prefix_caching
+                    else 1
+                ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
+
+                max_num_blocks_per_req = max(
+                    max_num_blocks_per_req, mamba_blocks_per_req
+                )
+            max_num_blocks.append(max_num_blocks_per_req)
+
         if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [[self.cache_config.block_size]]:
             assert self.cache_config.cpu_offload_gb == 0, (
                 "Cannot re-initialize the input batch when CPU weight "
@@ -2911,7 +2955,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
             )
             self.input_batch = NPUInputBatch(
                 max_num_reqs=self.max_num_reqs,
-                max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
+                max_model_len=max_model_len,
                 max_num_batched_tokens=self.max_num_tokens,
                 device=self.device,
                 pin_memory=self.pin_memory,
@@ -2926,6 +2970,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
                     else 0
                 ),
                 kernel_block_sizes=kernel_block_sizes,
+                max_num_blocks_per_req=max_num_blocks,
             )
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index 2d7a7c8b062..b5f21a6a986 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -40,6 +40,7 @@ def __init__(
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[list[int]],
+        max_num_blocks_per_req: list[int] | None = None,
         logitsprocs: LogitsProcessors | None = None,
         logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
@@ -97,6 +98,7 @@ def __init__(
             pin_memory=pin_memory,
             device=device,
             block_sizes=block_sizes,
+            max_num_blocks=max_num_blocks_per_req,
             num_speculative_tokens=num_speculative_tokens,
             kernel_sizes=kernel_block_sizes,
             cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
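Below is a minimal reviewer-side sanity check for the Triton batch_memcpy replacement above. It is a sketch, not part of the patch: the tensor sizes, the "npu" device string, and the use of raw data_ptr() addresses plus byte counts as kernel inputs are assumptions for illustration (the uint8 pointer cast inside batch_memcpy_kernel suggests byte-addressed pointers).

# Illustrative sanity check for the patched batch_memcpy (assumes an Ascend NPU
# is available and that importing vllm_ascend has already set up torch_npu).
import torch
from vllm.v1.worker import mamba_utils

import vllm_ascend.patch.worker.patch_mamba_utils as patch_mamba_utils

# Importing the patch module should have replaced the CUDA-based helper
# on the vLLM module with the Triton version defined in the patch.
assert mamba_utils.batch_memcpy is patch_mamba_utils.batch_memcpy

device = "npu"
srcs = [torch.randn(n, device=device) for n in (3, 17, 256)]
dsts = [torch.empty_like(s) for s in srcs]

# One entry per copy: raw device addresses and sizes in bytes.
src_ptrs = torch.tensor([s.data_ptr() for s in srcs], dtype=torch.int64, device=device)
dst_ptrs = torch.tensor([d.data_ptr() for d in dsts], dtype=torch.int64, device=device)
sizes = torch.tensor(
    [s.numel() * s.element_size() for s in srcs], dtype=torch.int64, device=device
)

patch_mamba_utils.batch_memcpy(src_ptrs, dst_ptrs, sizes)
for s, d in zip(srcs, dsts):
    assert torch.equal(s, d)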