Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
import vllm_ascend.patch.worker.patch_bert # noqa
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_mamba_utils # noqa
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
Expand Down
59 changes: 59 additions & 0 deletions vllm_ascend/patch/worker/patch_mamba_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors
# Adapted from vllm-project/vllm/vllm/v1/worker/mamba_utils.py
# Replace the CUDA batch_memcpy kernel with a Triton kernel for Ascend NPU.

import vllm
from vllm.triton_utils import tl, triton


@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    """Copy a batch of variable-sized byte buffers, one program per buffer.

    Args:
        src_ptrs: device array holding the source base address (as an
            integer) of each copy.
        dst_ptrs: device array holding the destination base address (as an
            integer) of each copy.
        sizes: device array holding the byte count of each copy.
        BLOCK_SIZE: compile-time tile width in bytes; must be a power of
            two (Triton requirement for tl.arange).
    """
    # One program instance handles exactly one (src, dst, size) triple.
    pid = tl.program_id(0)

    # Load this program's entries from the address/size tables.
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    # Walk the buffer in BLOCK_SIZE-byte tiles; the mask disables lanes
    # past the end of the buffer on the final (partial) tile.
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size

        # Reinterpret the integer addresses as uint8 pointers so the copy
        # proceeds byte-wise, independent of the buffers' element dtype.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    """Launch the Triton batched-memcpy kernel, one program per copy.

    Args:
        src_ptrs: 1-D tensor of source base addresses, one per copy.
        dst_ptrs: 1-D tensor of destination base addresses; must have the
            same length as ``src_ptrs``.
        sizes: 1-D tensor of byte counts; must have the same length as
            ``src_ptrs``.
    """
    num_copies = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == num_copies
    assert sizes.shape[0] == num_copies

    # NOTE: BLOCK_SIZE must be a power-of-2 constexpr for Triton.
    # 128 bytes per tile provides a reasonable balance between parallelism
    # and register pressure on Ascend NPU; adjust if profiling suggests
    # a different value.
    batch_memcpy_kernel[(num_copies,)](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=128)


# Monkey-patch vLLM's mamba_utils module so that code resolving these names
# through vllm.v1.worker.mamba_utils picks up the Triton-based
# implementations defined above instead of the upstream CUDA ones.
vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel
vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy
47 changes: 46 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import get_total_cp_world_size
from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
Expand Down Expand Up @@ -393,6 +395,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.cpu_slot_mapping = None
self.sampling_done_event: torch.npu.Event | None = None

self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None

@property
def use_cp(self) -> bool:
    """Whether context parallelism is active (pcp or dcp world size > 1)."""
    cp_world_size = self.pcp_size * self.dcp_size
    return cp_world_size > 1
Expand Down Expand Up @@ -1206,6 +1211,23 @@ def execute_model(

pad_attn = cudagraph_mode == CUDAGraphMode.FULL

# NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877,
# there should be a corresponding 'postprocess_mamba'. However, it is called inside
# '_update_states_after_model_execute', which is not overridden in vLLM-Ascend.
# We simply utilize the implementation in vLLM.
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.cache_config,
self.mamba_state_idx,
self.input_batch,
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

Expand Down Expand Up @@ -2518,6 +2540,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
Expand Down Expand Up @@ -2903,6 +2926,27 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
# of mamba block. In this case, BlockTable.block_size will never equal
# to kernel_block_sizes[0]
kernel_block_sizes.append([0])

max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
max_num_blocks_per_req = cdiv(
max_model_len, block_sizes[i] * get_total_cp_world_size()
)
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req
if self.cache_config.enable_prefix_caching
else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks

max_num_blocks_per_req = max(
max_num_blocks_per_req, mamba_blocks_per_req
)
max_num_blocks.append(max_num_blocks_per_req)

if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [[self.cache_config.block_size]]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
Expand All @@ -2911,7 +2955,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
)
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
Expand All @@ -2926,6 +2970,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
else 0
),
kernel_block_sizes=kernel_block_sizes,
max_num_blocks_per_req=max_num_blocks,
)

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
Expand Down
2 changes: 2 additions & 0 deletions vllm_ascend/worker/npu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[list[int]],
max_num_blocks_per_req: list[int] | None = None,
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
max_num_blocks=max_num_blocks_per_req,
num_speculative_tokens=num_speculative_tokens,
kernel_sizes=kernel_block_sizes,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
Expand Down