From 589278cdd76eea7e7aecf54f4b037f97673fb4b3 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 6 Mar 2026 16:55:57 +0800 Subject: [PATCH 01/12] npu support mamba prefix cache Signed-off-by: Angazenn --- vllm_ascend/patch/worker/__init__.py | 1 + vllm_ascend/patch/worker/patch_mamba_utils.py | 39 +++++++++++++++ vllm_ascend/worker/block_table.py | 37 +++++++++------ vllm_ascend/worker/model_runner_v1.py | 47 +++++++++++++++++-- vllm_ascend/worker/npu_input_batch.py | 2 + 5 files changed, 109 insertions(+), 17 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_mamba_utils.py diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index f7d509a24c0..06aa5d2a7b1 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -28,6 +28,7 @@ import vllm_ascend.patch.worker.patch_distributed # noqa import vllm_ascend.patch.worker.patch_minimax_m2 # noqa import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa +import vllm_ascend.patch.worker.patch_mamba_utils # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py new file mode 100644 index 00000000000..29f9f7a8b9d --- /dev/null +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -0,0 +1,39 @@ +# mypy: ignore-errors + +import torch + +import vllm +from vllm.triton_utils import tl, triton + + +@triton.jit +def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + src_ptr = tl.load(src_ptrs + pid) + dst_ptr = tl.load(dst_ptrs + pid) + size = tl.load(sizes + pid) + + offsets = tl.arange(0, BLOCK_SIZE) + for i in range(0, size, BLOCK_SIZE): + mask = (i + offsets) < size + + curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) + curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) + + data = tl.load(curr_src_ptr, mask=mask) + tl.store(curr_dst_ptr, data, mask=mask) + + +def batch_memcpy(src_ptrs, dst_ptrs, sizes): + batch = src_ptrs.shape[0] + assert dst_ptrs.shape[0] == batch + assert sizes.shape[0] == batch + + grid = (batch,) + BLOCK_SIZE = 1 + batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE) + + +vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel +vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy \ No newline at end of file diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 4ffc7df6310..d38f2abd804 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -3,6 +3,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.cp_utils import get_total_cp_world_size class BlockTable: @@ -239,21 +240,10 @@ def __init__( device: torch.device, block_sizes: list[int], num_speculative_tokens: int = 0, + max_num_blocks: list[int] | None = None, kernel_sizes: list[list[int]] | None = None, cp_kv_cache_interleave_size: int = 1, ) -> None: - # Note(hc): each dcp rank only store - # (max_model_len//dcp_world_size) tokens in kvcache, - # so the block_size which used for calc max_num_blocks_per_req - # must be multiplied by dcp_world_size. 
- try: - dcp_world_size = get_dcp_group().world_size - pcp_world_size = get_pcp_group().world_size - except AssertionError: - # DCP might not be initialized in testing - dcp_world_size = 1 - pcp_world_size = 1 - if kernel_sizes is None: kernel_sizes = [[0]] * len(block_sizes) # Ensure kernel_sizes matches block_sizes length @@ -264,12 +254,29 @@ def __init__( f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})" ) + if max_num_blocks is None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. + total_cp_world_size = get_total_cp_world_size() + max_num_blocks = [ + cdiv(max_model_len, block_size * total_cp_world_size) + for block_size in block_sizes + ] + + if len(max_num_blocks) != len(block_sizes): + raise ValueError( + f"max_num_blocks length ({len(max_num_blocks)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) + # Use zip to pair block_sizes with kernel_sizes one-to-one self.block_tables = [ BlockTable( block_size, max_num_reqs, - max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens), + max_num_blocks_per_req, max_num_batched_tokens, pin_memory, device, @@ -277,7 +284,9 @@ def __init__( cp_kv_cache_interleave_size, num_speculative_tokens, ) - for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + for block_size, kernel_size_list, max_num_blocks_per_req in zip( + block_sizes, kernel_sizes, max_num_blocks + ) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8780b4d0fe8..2a74363bae9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -78,6 +78,11 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + check_attention_cp_compatibility, + get_total_cp_world_size, +) from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner from vllm.v1.worker.ubatch_utils import ( UBatchSlices, @@ -417,6 +422,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) else: self.cudagraph_batch_sizes = [] + self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None @property def use_cp(self) -> bool: @@ -1251,6 +1258,19 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -2543,6 +2563,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None 
self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) # NOTE(cmq): initialize_attn_backend must before using self.attn_groups @@ -2984,6 +3005,27 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: # of mamba block. In this case, BlockTable.block_size will never equal # to kernel_block_sizes[0] self.kernel_block_sizes.append([0]) + + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + max_num_blocks_per_req = cdiv( + max_model_len, block_sizes[i] * get_total_cp_world_size() + ) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + mamba_blocks_per_req = ( + max_num_blocks_per_req + if self.cache_config.enable_prefix_caching + else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + + max_num_blocks_per_req = max( + max_num_blocks_per_req, mamba_blocks_per_req + ) + max_num_blocks.append(max_num_blocks_per_req) + if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -2992,7 +3034,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: ) self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -3007,6 +3049,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: else 0 ), kernel_block_sizes=self.kernel_block_sizes, + max_num_blocks_per_req=max_num_blocks, ) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -3172,8 +3215,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers[layer_name] = attn_module if len(mamba_layers) > 0: - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError("Prefix caching is not supported for Mamba yet.") mamba_page_size_padded = 0 for layer_name, mamba_module in mamba_layers.items(): if spec := mamba_module.get_kv_cache_spec(self.vllm_config): diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 2d7a7c8b062..b5f21a6a986 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -40,6 +40,7 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[list[int]], + max_num_blocks_per_req: list[int] | None = None, logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, @@ -97,6 +98,7 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, + max_num_blocks=max_num_blocks_per_req, num_speculative_tokens=num_speculative_tokens, kernel_sizes=kernel_block_sizes, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, From 2fcfbdb499a3ad39983052c137335deaf4a03019 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Tue, 10 Mar 2026 11:42:58 +0800 Subject: [PATCH 02/12] add comments Signed-off-by: Angazenn --- vllm_ascend/worker/model_runner_v1.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py 
b/vllm_ascend/worker/model_runner_v1.py index 2a74363bae9..1197ad88063 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1258,6 +1258,10 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877, + # there should be a corresponding 'postprocess_mamba'. However, it is called inside + # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. + # We simply utilize the implementation in vLLM. if self.cache_config.mamba_cache_mode == "align": mamba_utils.preprocess_mamba( scheduler_output, From b60aee131696763c8262d40525394a27bfc79b3f Mon Sep 17 00:00:00 2001 From: Angazenn Date: Tue, 10 Mar 2026 15:29:21 +0800 Subject: [PATCH 03/12] fix lint Signed-off-by: Angazenn --- vllm_ascend/patch/worker/patch_mamba_utils.py | 3 +-- vllm_ascend/worker/block_table.py | 12 +++--------- vllm_ascend/worker/model_runner_v1.py | 13 +++---------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 29f9f7a8b9d..2a4481357a6 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -1,6 +1,5 @@ # mypy: ignore-errors -import torch import vllm from vllm.triton_utils import tl, triton @@ -36,4 +35,4 @@ def batch_memcpy(src_ptrs, dst_ptrs, sizes): vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel -vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy \ No newline at end of file +vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index d38f2abd804..3c812aa4432 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -260,15 +260,11 @@ def __init__( # so the block_size which used for calc max_num_blocks_per_req # must be multiplied by dcp_world_size. 
total_cp_world_size = get_total_cp_world_size() - max_num_blocks = [ - cdiv(max_model_len, block_size * total_cp_world_size) - for block_size in block_sizes - ] + max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes] if len(max_num_blocks) != len(block_sizes): raise ValueError( - f"max_num_blocks length ({len(max_num_blocks)}) " - f"must match block_sizes length ({len(block_sizes)})" + f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})" ) # Use zip to pair block_sizes with kernel_sizes one-to-one @@ -284,9 +280,7 @@ def __init__( cp_kv_cache_interleave_size, num_speculative_tokens, ) - for block_size, kernel_size_list, max_num_blocks_per_req in zip( - block_sizes, kernel_sizes, max_num_blocks - ) + for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1197ad88063..32cc8c15357 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -80,7 +80,6 @@ from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker import mamba_utils from vllm.v1.worker.cp_utils import ( - check_attention_cp_compatibility, get_total_cp_world_size, ) from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner @@ -3015,19 +3014,13 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): continue - max_num_blocks_per_req = cdiv( - max_model_len, block_sizes[i] * get_total_cp_world_size() - ) + max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size()) if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): mamba_blocks_per_req = ( - max_num_blocks_per_req - if self.cache_config.enable_prefix_caching - else 1 + max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1 ) + kv_cache_group.kv_cache_spec.num_speculative_blocks - max_num_blocks_per_req = max( - max_num_blocks_per_req, mamba_blocks_per_req - ) + max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req) max_num_blocks.append(max_num_blocks_per_req) if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: From 8093dd3bd2d12e0cc7344af937f792e8ebe6dae1 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Wed, 11 Mar 2026 10:32:14 +0800 Subject: [PATCH 04/12] update kernel Signed-off-by: Angazenn --- vllm_ascend/patch/worker/patch_mamba_utils.py | 94 +++++++++++++++---- 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 2a4481357a6..00b48c0f13d 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -1,38 +1,92 @@ # mypy: ignore-errors -import vllm from vllm.triton_utils import tl, triton @triton.jit -def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) +def batch_memcpy_kernel( + src_ptrs, + dst_ptrs, + sizes, + BLOCK_SIZE: tl.constexpr, + ELEMENT_SIZE: tl.constexpr, # bytes per load element (4=uint32, 8=uint64) +): + # 2D grid: axis-0 = batch index, axis-1 = chunk index 
within that tensor + batch_id = tl.program_id(0) + chunk_id = tl.program_id(1) - src_ptr = tl.load(src_ptrs + pid) - dst_ptr = tl.load(dst_ptrs + pid) - size = tl.load(sizes + pid) + # Load pointers and size for this batch entry + src_base = tl.load(src_ptrs + batch_id) + dst_base = tl.load(dst_ptrs + batch_id) + size_bytes = tl.load(sizes + batch_id).to(tl.int64) # enforce int64 + # Work in units of ELEMENT_SIZE bytes for wider memory transactions + chunk_bytes = BLOCK_SIZE * ELEMENT_SIZE + start_byte = chunk_id * chunk_bytes + + # Early exit if this chunk is entirely out of range + if start_byte >= size_bytes: + return + + # Cast base pointers once, outside any inner loop + src_ptr = src_base.to(tl.pointer_type(tl.uint32 if ELEMENT_SIZE == 4 else tl.uint64)) + dst_ptr = dst_base.to(tl.pointer_type(tl.uint32 if ELEMENT_SIZE == 4 else tl.uint64)) + + # Element-level offsets within this chunk offsets = tl.arange(0, BLOCK_SIZE) - for i in range(0, size, BLOCK_SIZE): - mask = (i + offsets) < size + start_elem = start_byte // ELEMENT_SIZE + size_elems = (size_bytes + ELEMENT_SIZE - 1) // ELEMENT_SIZE # ceil div - curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) - curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) + elem_offsets = start_elem + offsets + mask = elem_offsets < size_elems - data = tl.load(curr_src_ptr, mask=mask) - tl.store(curr_dst_ptr, data, mask=mask) + # Wide load → store with streaming cache hint to avoid polluting L1 + data = tl.load( + src_ptr + elem_offsets, + mask=mask, + cache_modifier=".cg", # cache-global: bypass L1 for streaming data + ) + tl.store( + dst_ptr + elem_offsets, + data, + mask=mask, + cache_modifier=".cg", + ) -def batch_memcpy(src_ptrs, dst_ptrs, sizes): +# --------------------------------------------------------------------------- +# Python launcher +# --------------------------------------------------------------------------- +def batch_memcpy(src_ptrs, dst_ptrs, sizes, max_size: int | None = None): + """ + Copy each src_ptrs[i] → dst_ptrs[i] for sizes[i] bytes. + + Args: + src_ptrs: 1-D int64 tensor of source pointers, shape (B,) + dst_ptrs: 1-D int64 tensor of destination pointers, shape (B,) + sizes: 1-D int64 tensor of byte counts, shape (B,) + max_size: optional override for the largest tensor size (bytes). + If None, computed from sizes.max() — requires a device sync. + """ batch = src_ptrs.shape[0] - assert dst_ptrs.shape[0] == batch - assert sizes.shape[0] == batch - grid = (batch,) - BLOCK_SIZE = 1 - batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE) + if dst_ptrs.shape[0] != batch: + raise ValueError(f"dst_ptrs batch dim {dst_ptrs.shape[0]} != {batch}") + if sizes.shape[0] != batch: + raise ValueError(f"sizes batch dim {sizes.shape[0]} != {batch}") + + # Determine the maximum number of chunks needed across all tensors. + # Using a caller-supplied max_size avoids a device→host sync on the hot path. + if max_size is None: + max_size = int(sizes.max().item()) + # ELEMENT_SIZE=8 (uint64) gives the widest loads; adjust if alignment differs. + ELEMENT_SIZE = 8 + # Compute max chunks conservatively using the largest possible BLOCK_SIZE. 
+ BLOCK_SIZE = 4096 + max_chunks = (max_size + BLOCK_SIZE * ELEMENT_SIZE - 1) // (BLOCK_SIZE * ELEMENT_SIZE) + max_chunks = max(max_chunks, 1) -vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel -vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy + grid = (batch, max_chunks) + batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE, ELEMENT_SIZE) From 995e379165e55708969b9cda1a9e5acbbf4c0057 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Wed, 11 Mar 2026 10:49:56 +0800 Subject: [PATCH 05/12] bugfix Signed-off-by: Angazenn --- vllm_ascend/patch/worker/patch_mamba_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 00b48c0f13d..f6f6f340f66 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -1,6 +1,7 @@ # mypy: ignore-errors +import vllm from vllm.triton_utils import tl, triton @@ -90,3 +91,7 @@ def batch_memcpy(src_ptrs, dst_ptrs, sizes, max_size: int | None = None): grid = (batch, max_chunks) batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE, ELEMENT_SIZE) + + +vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel +vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy From 4eb77bfb279e0a60f19389410333bc114d2990bb Mon Sep 17 00:00:00 2001 From: Angazenn Date: Wed, 11 Mar 2026 16:01:50 +0800 Subject: [PATCH 06/12] update triton kernel Signed-off-by: Angazenn --- vllm_ascend/patch/worker/patch_mamba_utils.py | 98 +++++-------------- 1 file changed, 23 insertions(+), 75 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index f6f6f340f66..700922b75e2 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -6,91 +6,39 @@ @triton.jit -def batch_memcpy_kernel( - src_ptrs, - dst_ptrs, - sizes, - BLOCK_SIZE: tl.constexpr, - ELEMENT_SIZE: tl.constexpr, # bytes per load element (4=uint32, 8=uint64) -): - # 2D grid: axis-0 = batch index, axis-1 = chunk index within that tensor - batch_id = tl.program_id(0) - chunk_id = tl.program_id(1) +def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) - # Load pointers and size for this batch entry - src_base = tl.load(src_ptrs + batch_id) - dst_base = tl.load(dst_ptrs + batch_id) - size_bytes = tl.load(sizes + batch_id).to(tl.int64) # enforce int64 + src_ptr = tl.load(src_ptrs + pid) + dst_ptr = tl.load(dst_ptrs + pid) + size = tl.load(sizes + pid) - # Work in units of ELEMENT_SIZE bytes for wider memory transactions - chunk_bytes = BLOCK_SIZE * ELEMENT_SIZE - start_byte = chunk_id * chunk_bytes + # We need to mv pointer_type cast outside the loop. + # Otherwise it causes potential bugs. 
+ src_ptr = src_ptr.to(tl.pointer_type(tl.uint8)) + dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8)) - # Early exit if this chunk is entirely out of range - if start_byte >= size_bytes: - return - - # Cast base pointers once, outside any inner loop - src_ptr = src_base.to(tl.pointer_type(tl.uint32 if ELEMENT_SIZE == 4 else tl.uint64)) - dst_ptr = dst_base.to(tl.pointer_type(tl.uint32 if ELEMENT_SIZE == 4 else tl.uint64)) - - # Element-level offsets within this chunk offsets = tl.arange(0, BLOCK_SIZE) - start_elem = start_byte // ELEMENT_SIZE - size_elems = (size_bytes + ELEMENT_SIZE - 1) // ELEMENT_SIZE # ceil div - - elem_offsets = start_elem + offsets - mask = elem_offsets < size_elems + for i in range(0, size, BLOCK_SIZE): + mask = (i + offsets) < size - # Wide load → store with streaming cache hint to avoid polluting L1 - data = tl.load( - src_ptr + elem_offsets, - mask=mask, - cache_modifier=".cg", # cache-global: bypass L1 for streaming data - ) - tl.store( - dst_ptr + elem_offsets, - data, - mask=mask, - cache_modifier=".cg", - ) + curr_src_ptr = src_ptr + i + offsets + curr_dst_ptr = dst_ptr + i + offsets + # cache_modifier=".cg" bypasses L1 cache for streaming data. + data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg") + tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg") -# --------------------------------------------------------------------------- -# Python launcher -# --------------------------------------------------------------------------- -def batch_memcpy(src_ptrs, dst_ptrs, sizes, max_size: int | None = None): - """ - Copy each src_ptrs[i] → dst_ptrs[i] for sizes[i] bytes. - Args: - src_ptrs: 1-D int64 tensor of source pointers, shape (B,) - dst_ptrs: 1-D int64 tensor of destination pointers, shape (B,) - sizes: 1-D int64 tensor of byte counts, shape (B,) - max_size: optional override for the largest tensor size (bytes). - If None, computed from sizes.max() — requires a device sync. - """ +def batch_memcpy(src_ptrs, dst_ptrs, sizes): batch = src_ptrs.shape[0] + assert dst_ptrs.shape[0] == batch + assert sizes.shape[0] == batch - if dst_ptrs.shape[0] != batch: - raise ValueError(f"dst_ptrs batch dim {dst_ptrs.shape[0]} != {batch}") - if sizes.shape[0] != batch: - raise ValueError(f"sizes batch dim {sizes.shape[0]} != {batch}") - - # Determine the maximum number of chunks needed across all tensors. - # Using a caller-supplied max_size avoids a device→host sync on the hot path. - if max_size is None: - max_size = int(sizes.max().item()) - - # ELEMENT_SIZE=8 (uint64) gives the widest loads; adjust if alignment differs. - ELEMENT_SIZE = 8 - # Compute max chunks conservatively using the largest possible BLOCK_SIZE. - BLOCK_SIZE = 4096 - max_chunks = (max_size + BLOCK_SIZE * ELEMENT_SIZE - 1) // (BLOCK_SIZE * ELEMENT_SIZE) - max_chunks = max(max_chunks, 1) - - grid = (batch, max_chunks) - batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE, ELEMENT_SIZE) + grid = (batch,) + # using larger block_size to accelerate copy. 
BLOCK_SIZE = 8192
+    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
 
 
 vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel

From 0cfa290623a5c25530ed4a6cf513b740cbf4d882 Mon Sep 17 00:00:00 2001
From: Angazenn
Date: Fri, 13 Mar 2026 16:55:51 +0800
Subject: [PATCH 07/12] move batch_memcpy kernel into ops/triton

Signed-off-by: Angazenn
---
 vllm_ascend/ops/triton/batch_memcpy.py        | 31 +++++++++++++++
 vllm_ascend/patch/__init__.py                 | 22 +++++++++++++
 vllm_ascend/patch/worker/patch_mamba_utils.py | 26 +---------------
 3 files changed, 54 insertions(+), 25 deletions(-)
 create mode 100644 vllm_ascend/ops/triton/batch_memcpy.py

diff --git a/vllm_ascend/ops/triton/batch_memcpy.py b/vllm_ascend/ops/triton/batch_memcpy.py
new file mode 100644
index 00000000000..0bd576d3876
--- /dev/null
+++ b/vllm_ascend/ops/triton/batch_memcpy.py
@@ -0,0 +1,31 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+
+    src_ptr = tl.load(src_ptrs + pid)
+    dst_ptr = tl.load(dst_ptrs + pid)
+    size = tl.load(sizes + pid)
+
+    # We need to move the pointer_type cast outside the loop.
+    # Otherwise it causes potential bugs.
+    src_ptr = src_ptr.to(tl.pointer_type(tl.uint8))
+    dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8))
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    for i in range(0, size, BLOCK_SIZE):
+        mask = (i + offsets) < size
+
+        curr_src_ptr = src_ptr + i + offsets
+        curr_dst_ptr = dst_ptr + i + offsets
+
+        # cache_modifier=".cg" bypasses L1 cache for streaming data.
+        data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg")
+        tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg")
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 4eb74013467..41463c7b594 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -506,3 +506,25 @@
 #    Rotary quant is a unique feature of vllm-ascend.
 #    Future Plan:
 #    Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
+# ** 22. File: worker/patch_mamba_utils.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel`
+#    Why:
+#    The original batch_memcpy_kernel implemented in vLLM might encounter bugs when
+#    running on Ascend hardware.
+#    How:
+#    Patch the kernel to fix the related bugs.
+#    Future Plan:
+#    Remove this patch when:
+#    (1) the original batch_memcpy_kernel can run on Ascend hardware,
+#    or
+#    (2) a dispatch mechanism for batch_memcpy_kernel is designed.
+# 2. `vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy`
+#    Why:
+#    vLLM uses BLOCK_SIZE 1024 for batch_memcpy_kernel, which results in suboptimal
+#    performance on Ascend hardware.
+#    How:
+#    Patch batch_memcpy to use BLOCK_SIZE 8192 instead.
+#    Future Plan:
+#    Remove this patch when:
+#    a dispatch mechanism for batch_memcpy_kernel is designed.
diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 700922b75e2..8b899c6f850 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -2,32 +2,8 @@ import vllm -from vllm.triton_utils import tl, triton - -@triton.jit -def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - - src_ptr = tl.load(src_ptrs + pid) - dst_ptr = tl.load(dst_ptrs + pid) - size = tl.load(sizes + pid) - - # We need to mv pointer_type cast outside the loop. - # Otherwise it causes potential bugs. - src_ptr = src_ptr.to(tl.pointer_type(tl.uint8)) - dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8)) - - offsets = tl.arange(0, BLOCK_SIZE) - for i in range(0, size, BLOCK_SIZE): - mask = (i + offsets) < size - - curr_src_ptr = src_ptr + i + offsets - curr_dst_ptr = dst_ptr + i + offsets - - # cache_modifier=".cg" bypasses L1 cache for streaming data. - data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg") - tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg") +from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel def batch_memcpy(src_ptrs, dst_ptrs, sizes): From 311b096eb7b4fbe94e8599711f9e8717dc414abc Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 13 Mar 2026 17:56:22 +0800 Subject: [PATCH 08/12] add test Signed-off-by: Angazenn --- .../triton/test_batch_memcpy.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py new file mode 100644 index 00000000000..a86016425d0 --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py @@ -0,0 +1,37 @@ +import pytest +import torch + +from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_batch_memcpy(dtype): + device = "npu:0" + # this is a typical case when used in mamba states copy. + sizes = torch.tensors([24576, 262144, 24576, 262144], device=device, dtype=torch.int32) + + src_tensors_list = [] + src_addr_list = [] + dst_tensors_list = [] + dst_addr_list = [] + for i in range(len(sizes)): + src_tensors_list.append( + torch.rand(sizes[i].item(), dtype=dtype, device=device) + ) + src_addr_list.append(src_tensors_list[-1].data_ptr()) + dst_tensors_list.append( + torch.empty(sizes[i].item(), dtype=dtype, device=device) + ) + dst_addr_list.append(dst_tensors_list[-1].data_ptr()) + + src_addr_list = torch.tensor(src_addr_list, dtype=torch.int64, device=device) + dst_addr_list = torch.tensor(dst_addr_list, dtype=torch.int64, device=device) + + batch = src_addr_list.shape[0] + + grid = (batch,) + # using larger block_size to accelerate copy. 
+ BLOCK_SIZE = 8192 + batch_memcpy_kernel[grid](src_addr_list, dst_addr_list, sizes, BLOCK_SIZE=BLOCK_SIZE) + + for i in range(len(sizes)): + assert src_tensors_list[i] == dst_tensors_list[i] From 07cc14bf8a3e2018dfb9166cca5c83c6ce10820e Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 13 Mar 2026 18:12:15 +0800 Subject: [PATCH 09/12] fix Signed-off-by: Angazenn --- vllm_ascend/patch/worker/patch_mamba_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 8b899c6f850..0b3c550c2fa 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -2,6 +2,7 @@ import vllm +from vllm.v1.worker import mamba_utils from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel @@ -17,5 +18,5 @@ def batch_memcpy(src_ptrs, dst_ptrs, sizes): batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE) -vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel -vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy +mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel +mamba_utils.batch_memcpy = batch_memcpy From 487eab35bd079c8c32b3bc8aa59e0db6e3bead10 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 13 Mar 2026 18:22:24 +0800 Subject: [PATCH 10/12] bugfix Signed-off-by: Angazenn --- .../ops/singlecard_ops/triton/test_batch_memcpy.py | 7 ++++--- vllm_ascend/patch/worker/patch_mamba_utils.py | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py index a86016425d0..d98fe0f921e 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py @@ -5,9 +5,10 @@ @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_batch_memcpy(dtype): + element_size = 2 if dtype == torch.bfloat16 else 4 device = "npu:0" # this is a typical case when used in mamba states copy. 
- sizes = torch.tensors([24576, 262144, 24576, 262144], device=device, dtype=torch.int32) + sizes = torch.tensor([24576, 262144, 24576, 262144], device=device, dtype=torch.int32) src_tensors_list = [] src_addr_list = [] @@ -15,11 +16,11 @@ def test_batch_memcpy(dtype): dst_addr_list = [] for i in range(len(sizes)): src_tensors_list.append( - torch.rand(sizes[i].item(), dtype=dtype, device=device) + torch.rand(sizes[i].item() // element_size, dtype=dtype, device=device) ) src_addr_list.append(src_tensors_list[-1].data_ptr()) dst_tensors_list.append( - torch.empty(sizes[i].item(), dtype=dtype, device=device) + torch.empty(sizes[i].item() // element_size, dtype=dtype, device=device) ) dst_addr_list.append(dst_tensors_list[-1].data_ptr()) diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py index 0b3c550c2fa..063789bfd4c 100644 --- a/vllm_ascend/patch/worker/patch_mamba_utils.py +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -1,8 +1,7 @@ # mypy: ignore-errors -import vllm -from vllm.v1.worker import mamba_utils +from vllm.v1.worker import mamba_utils from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel From 1a95456b0bd1881ede0d5853ba360eaa5f6b9cdb Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 13 Mar 2026 21:48:45 +0800 Subject: [PATCH 11/12] fix lint Signed-off-by: Angazenn --- .../single_node/ops/singlecard_ops/triton/test_batch_memcpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py index d98fe0f921e..54a8e2f12e1 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py @@ -35,4 +35,4 @@ def test_batch_memcpy(dtype): batch_memcpy_kernel[grid](src_addr_list, dst_addr_list, sizes, BLOCK_SIZE=BLOCK_SIZE) for i in range(len(sizes)): - assert src_tensors_list[i] == dst_tensors_list[i] + torch.testing.assert_close(src_tensors_list[i], dst_tensors_list[i], rtol=0, atol=0) From b14853a7f35fafaab07de09a68ef71894e1120cf Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 13 Mar 2026 23:18:14 +0800 Subject: [PATCH 12/12] fix Signed-off-by: Angazenn --- .../single_node/ops/singlecard_ops/triton/test_batch_memcpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py index 54a8e2f12e1..b5162a6d96e 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py @@ -27,7 +27,7 @@ def test_batch_memcpy(dtype): src_addr_list = torch.tensor(src_addr_list, dtype=torch.int64, device=device) dst_addr_list = torch.tensor(dst_addr_list, dtype=torch.int64, device=device) - batch = src_addr_list.shape[0] + batch = sizes.shape[0] grid = (batch,) # using larger block_size to accelerate copy.
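
A minimal usage sketch of the patched kernel, mirroring the unit test above. The
device string, tensor sizes, and dtype here are illustrative assumptions; in
serving, this kernel is driven by vLLM's preprocess_mamba through the patched
vllm.v1.worker.mamba_utils.batch_memcpy rather than called directly. Note that
`sizes` holds byte counts, so an N-element bf16 tensor contributes N * 2 bytes,
and the source and destination tensors must stay alive while the copy runs,
because the kernel only sees raw device pointers.

import torch

from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel

device = "npu:0"  # assumes an Ascend NPU with a working Triton backend
dtype = torch.bfloat16

# Hypothetical per-request mamba states: 24576 and 262144 bytes, as in the test.
srcs = [torch.rand(12288, dtype=dtype, device=device),
        torch.rand(131072, dtype=dtype, device=device)]
dsts = [torch.empty_like(t) for t in srcs]

# The kernel consumes raw pointers (int64) and per-entry byte counts.
src_ptrs = torch.tensor([t.data_ptr() for t in srcs], dtype=torch.int64, device=device)
dst_ptrs = torch.tensor([t.data_ptr() for t in dsts], dtype=torch.int64, device=device)
sizes = torch.tensor([t.numel() * t.element_size() for t in srcs],
                     dtype=torch.int32, device=device)

# One Triton program per (src, dst) pair; BLOCK_SIZE=8192 matches the patched launcher.
batch_memcpy_kernel[(sizes.shape[0],)](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=8192)

for s, d in zip(srcs, dsts):
    torch.testing.assert_close(s, d, rtol=0, atol=0)  # copies must be bit-exact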