diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py
new file mode 100644
index 00000000000..b5162a6d96e
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py
@@ -0,0 +1,38 @@
+import pytest
+import torch
+
+from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+def test_batch_memcpy(dtype):
+    element_size = 2 if dtype == torch.bfloat16 else 4
+    device = "npu:0"
+    # These byte counts are a typical case for mamba state copies.
+    sizes = torch.tensor([24576, 262144, 24576, 262144], device=device, dtype=torch.int32)
+
+    src_tensors_list = []
+    src_addr_list = []
+    dst_tensors_list = []
+    dst_addr_list = []
+    for i in range(len(sizes)):
+        src_tensors_list.append(
+            torch.rand(sizes[i].item() // element_size, dtype=dtype, device=device)
+        )
+        src_addr_list.append(src_tensors_list[-1].data_ptr())
+        dst_tensors_list.append(
+            torch.empty(sizes[i].item() // element_size, dtype=dtype, device=device)
+        )
+        dst_addr_list.append(dst_tensors_list[-1].data_ptr())
+
+    src_addr_list = torch.tensor(src_addr_list, dtype=torch.int64, device=device)
+    dst_addr_list = torch.tensor(dst_addr_list, dtype=torch.int64, device=device)
+
+    batch = sizes.shape[0]
+
+    grid = (batch,)
+    # Use a larger block size to accelerate the copy.
+    BLOCK_SIZE = 8192
+    batch_memcpy_kernel[grid](src_addr_list, dst_addr_list, sizes, BLOCK_SIZE=BLOCK_SIZE)
+
+    for i in range(len(sizes)):
+        torch.testing.assert_close(src_tensors_list[i], dst_tensors_list[i], rtol=0, atol=0)
diff --git a/vllm_ascend/ops/triton/batch_memcpy.py b/vllm_ascend/ops/triton/batch_memcpy.py
new file mode 100644
index 00000000000..0bd576d3876
--- /dev/null
+++ b/vllm_ascend/ops/triton/batch_memcpy.py
@@ -0,0 +1,31 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+
+    src_ptr = tl.load(src_ptrs + pid)
+    dst_ptr = tl.load(dst_ptrs + pid)
+    size = tl.load(sizes + pid)
+
+    # The pointer-type cast must stay outside the loop;
+    # casting inside the loop can trigger subtle bugs.
+    src_ptr = src_ptr.to(tl.pointer_type(tl.uint8))
+    dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8))
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    for i in range(0, size, BLOCK_SIZE):
+        mask = (i + offsets) < size
+
+        curr_src_ptr = src_ptr + i + offsets
+        curr_dst_ptr = dst_ptr + i + offsets
+
+        # cache_modifier=".cg" bypasses the L1 cache for streaming data.
+        data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg")
+        tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg")
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 4eb74013467..41463c7b594 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -506,3 +506,25 @@
 #    Rotary quant is a unique feature of vllm-ascend.
 #   Future Plan:
 #    Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
+# ** 22. File: worker/patch_mamba_utils.py**
+#   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel`
+#      Why:
+#         The original batch_memcpy_kernel implemented in vLLM may hit bugs when running on
+#         Ascend hardware.
+#      How:
+#         Patch in an Ascend-compatible kernel that fixes the related bugs.
+#      Future Plan:
+#         Remove this patch when:
+#         (1) the original batch_memcpy_kernel can run on Ascend hardware,
+#         or
+#         (2) a dispatch mechanism for batch_memcpy_kernel is designed.
+#   2. `vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy`
+#      Why:
+#         vLLM uses BLOCK_SIZE 1024 for batch_memcpy_kernel, which results in suboptimal
+#         performance on Ascend hardware.
+#      How:
+#         Patch batch_memcpy to use BLOCK_SIZE 8192.
+#      Future Plan:
+#         Remove this patch when a dispatch mechanism for batch_memcpy_kernel is
+#         designed.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index f7d509a24c0..06aa5d2a7b1 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -28,6 +28,7 @@
 import vllm_ascend.patch.worker.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_minimax_m2  # noqa
 import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn  # noqa
+import vllm_ascend.patch.worker.patch_mamba_utils  # noqa
 import vllm_ascend.patch.worker.patch_multimodal_merge  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py
new file mode 100644
index 00000000000..063789bfd4c
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_mamba_utils.py
@@ -0,0 +1,21 @@
+# mypy: ignore-errors
+
+
+from vllm.v1.worker import mamba_utils
+
+from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel
+
+
+def batch_memcpy(src_ptrs, dst_ptrs, sizes):
+    batch = src_ptrs.shape[0]
+    assert dst_ptrs.shape[0] == batch
+    assert sizes.shape[0] == batch
+
+    grid = (batch,)
+    # Use a larger block size to accelerate the copy.
+    BLOCK_SIZE = 8192
+    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
+
+
+mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel
+mamba_utils.batch_memcpy = batch_memcpy
diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py
index 4ffc7df6310..3c812aa4432 100644
--- a/vllm_ascend/worker/block_table.py
+++ b/vllm_ascend/worker/block_table.py
@@ -3,6 +3,7 @@
 from vllm.distributed import get_dcp_group, get_pcp_group
 from vllm.utils.math_utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.cp_utils import get_total_cp_world_size
 
 
 class BlockTable:
@@ -239,21 +240,10 @@ def __init__(
         device: torch.device,
         block_sizes: list[int],
         num_speculative_tokens: int = 0,
+        max_num_blocks: list[int] | None = None,
         kernel_sizes: list[list[int]] | None = None,
         cp_kv_cache_interleave_size: int = 1,
     ) -> None:
-        # Note(hc): each dcp rank only store
-        # (max_model_len//dcp_world_size) tokens in kvcache,
-        # so the block_size which used for calc max_num_blocks_per_req
-        # must be multiplied by dcp_world_size.
-        try:
-            dcp_world_size = get_dcp_group().world_size
-            pcp_world_size = get_pcp_group().world_size
-        except AssertionError:
-            # DCP might not be initialized in testing
-            dcp_world_size = 1
-            pcp_world_size = 1
-
         if kernel_sizes is None:
             kernel_sizes = [[0]] * len(block_sizes)
         # Ensure kernel_sizes matches block_sizes length
@@ -264,12 +254,25 @@ def __init__(
                 f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
             )
 
+        if max_num_blocks is None:
+            # Note(hc): each DCP rank only stores
+            # (max_model_len // dcp_world_size) tokens in its kv cache,
+            # so the block_size used to compute max_num_blocks_per_req
+            # must be multiplied by the total context-parallel world size.
+            total_cp_world_size = get_total_cp_world_size()
+            max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes]
+
+        if len(max_num_blocks) != len(block_sizes):
+            raise ValueError(
+                f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})"
+            )
+
         # Use zip to pair block_sizes with kernel_sizes one-to-one
         self.block_tables = [
             BlockTable(
                 block_size,
                 max_num_reqs,
-                max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens),
+                max_num_blocks_per_req,
                 max_num_batched_tokens,
                 pin_memory,
                 device,
@@ -277,7 +280,7 @@
                 cp_kv_cache_interleave_size,
                 num_speculative_tokens,
             )
-            for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
+            for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks)
         ]
 
     def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 896e90e032f..3c6d48bce2f 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -77,6 +77,10 @@
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
+from vllm.v1.worker import mamba_utils
+from vllm.v1.worker.cp_utils import (
+    get_total_cp_world_size,
+)
 from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
@@ -416,6 +420,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
         else:
            self.cudagraph_batch_sizes = []
+        self.mamba_state_idx: dict[str, int] = {}
+        self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
 
     @property
     def use_cp(self) -> bool:
@@ -1250,6 +1256,23 @@ def execute_model(
 
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL
 
+        # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877,
+        # there should be a corresponding 'postprocess_mamba'. However, it is called inside
+        # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend,
+        # so we simply reuse the implementation from vLLM.
+        if self.cache_config.mamba_cache_mode == "align":
+            mamba_utils.preprocess_mamba(
+                scheduler_output,
+                self.kv_cache_config,
+                self.cache_config,
+                self.mamba_state_idx,
+                self.input_batch,
+                self.requests,
+                self.compilation_config.static_forward_context,
+                self.model.get_mamba_state_copy_func(),
+                self._get_mamba_copy_bufs(),
+            )
+
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
 
         ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -2542,6 +2565,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self._mamba_copy_bufs = None
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         # NOTE(cmq): initialize_attn_backend must before using self.attn_groups
@@ -2983,6 +3007,21 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
             # of mamba block. In this case, BlockTable.block_size will never equal
             # to kernel_block_sizes[0]
             self.kernel_block_sizes.append([0])
+
+        max_num_blocks = []
+        max_model_len = max(self.max_model_len, self.max_encoder_len)
+        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+                continue
+            max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size())
+            if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
+                mamba_blocks_per_req = (
+                    max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1
+                ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
+
+                max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req)
+            max_num_blocks.append(max_num_blocks_per_req)
+
         if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
             assert self.cache_config.cpu_offload_gb == 0, (
                 "Cannot re-initialize the input batch when CPU weight "
@@ -2991,7 +3030,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
             )
             self.input_batch = NPUInputBatch(
                 max_num_reqs=self.max_num_reqs,
-                max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
+                max_model_len=max_model_len,
                 max_num_batched_tokens=self.max_num_tokens,
                 device=self.device,
                 pin_memory=self.pin_memory,
@@ -3006,6 +3045,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
                     else 0
                 ),
                 kernel_block_sizes=self.kernel_block_sizes,
+                max_num_blocks_per_req=max_num_blocks,
             )
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@@ -3171,8 +3211,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 mamba_layers[layer_name] = attn_module
 
         if len(mamba_layers) > 0:
-            if self.vllm_config.cache_config.enable_prefix_caching:
-                raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
             mamba_page_size_padded = 0
             for layer_name, mamba_module in mamba_layers.items():
                 if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index 2d7a7c8b062..b5f21a6a986 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -40,6 +40,7 @@ def __init__(
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
         kernel_block_sizes: list[list[int]],
+        max_num_blocks_per_req: list[int] | None = None,
         logitsprocs: LogitsProcessors | None = None,
         logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
@@ -97,6 +98,7 @@ def __init__(
             pin_memory=pin_memory,
             device=device,
             block_sizes=block_sizes,
+            max_num_blocks=max_num_blocks_per_req,
             num_speculative_tokens=num_speculative_tokens,
             kernel_sizes=kernel_block_sizes,
             cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
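
A minimal usage sketch (not part of the diff above): once `vllm_ascend.patch.worker` is imported it applies `patch_mamba_utils`, so `vllm.v1.worker.mamba_utils.batch_memcpy` resolves to the Ascend path with BLOCK_SIZE 8192. The tensor sizes, the `npu:0` device string, and the verification loop are illustrative assumptions mirroring the e2e test above, not additional API guarantees.

    import torch

    import vllm_ascend.patch.worker  # noqa: F401 -- importing the package applies the worker patches
    from vllm.v1.worker import mamba_utils

    device = "npu:0"  # assumed Ascend device, as in the e2e test
    src = [torch.rand(12288, dtype=torch.bfloat16, device=device) for _ in range(4)]
    dst = [torch.empty_like(t) for t in src]

    # Per-tensor byte counts and raw device pointers, mirroring the test setup.
    sizes = torch.tensor([t.numel() * t.element_size() for t in src], dtype=torch.int32, device=device)
    src_ptrs = torch.tensor([t.data_ptr() for t in src], dtype=torch.int64, device=device)
    dst_ptrs = torch.tensor([t.data_ptr() for t in dst], dtype=torch.int64, device=device)

    # After the patch, this launches the Ascend Triton kernel with BLOCK_SIZE=8192.
    mamba_utils.batch_memcpy(src_ptrs, dst_ptrs, sizes)

    for s, d in zip(src, dst):
        torch.testing.assert_close(s, d, rtol=0, atol=0)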