@@ -0,0 +1,38 @@
import pytest
import torch

from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_batch_memcpy(dtype):
element_size = 2 if dtype == torch.bfloat16 else 4
device = "npu:0"
# Typical size pattern seen when copying Mamba states.
sizes = torch.tensor([24576, 262144, 24576, 262144], device=device, dtype=torch.int32)

src_tensors_list = []
src_addr_list = []
dst_tensors_list = []
dst_addr_list = []
for i in range(len(sizes)):
src_tensors_list.append(
torch.rand(sizes[i].item() // element_size, dtype=dtype, device=device)
)
src_addr_list.append(src_tensors_list[-1].data_ptr())
dst_tensors_list.append(
torch.empty(sizes[i].item() // element_size, dtype=dtype, device=device)
)
dst_addr_list.append(dst_tensors_list[-1].data_ptr())

src_addr_list = torch.tensor(src_addr_list, dtype=torch.int64, device=device)
dst_addr_list = torch.tensor(dst_addr_list, dtype=torch.int64, device=device)

batch = sizes.shape[0]

grid = (batch,)
# Use a larger block size to speed up the copy.
BLOCK_SIZE = 8192
batch_memcpy_kernel[grid](src_addr_list, dst_addr_list, sizes, BLOCK_SIZE=BLOCK_SIZE)

for i in range(len(sizes)):
torch.testing.assert_close(src_tensors_list[i], dst_tensors_list[i], rtol=0, atol=0)
31 changes: 31 additions & 0 deletions vllm_ascend/ops/triton/batch_memcpy.py
@@ -0,0 +1,31 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from vllm.triton_utils import tl, triton


@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
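# One Triton program per tensor: program `pid` streams sizes[pid] bytes from
# src_ptrs[pid] to dst_ptrs[pid] in BLOCK_SIZE-byte chunks.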
pid = tl.program_id(0)

src_ptr = tl.load(src_ptrs + pid)
dst_ptr = tl.load(dst_ptrs + pid)
size = tl.load(sizes + pid)

# The pointer_type cast must stay outside the loop;
# casting inside the loop can cause subtle bugs.
src_ptr = src_ptr.to(tl.pointer_type(tl.uint8))
dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8))

offsets = tl.arange(0, BLOCK_SIZE)
for i in range(0, size, BLOCK_SIZE):
mask = (i + offsets) < size

curr_src_ptr = src_ptr + i + offsets
curr_dst_ptr = dst_ptr + i + offsets

# cache_modifier=".cg" bypasses L1 cache for streaming data.
data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg")
tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg")
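For reference, the kernel's contract is a plain per-tensor byte copy. A minimal PyTorch sketch of the same semantics, operating on tensor lists rather than the raw device pointers and byte sizes the kernel receives:

import torch

def batch_memcpy_reference(srcs: list[torch.Tensor], dsts: list[torch.Tensor]) -> None:
    # Equivalent semantics: each destination receives a full copy of its
    # paired source; the Triton kernel does this from flat address arrays.
    for src, dst in zip(srcs, dsts):
        dst.copy_(src)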
22 changes: 22 additions & 0 deletions vllm_ascend/patch/__init__.py
@@ -506,3 +506,25 @@
# Rotary quant is a unique feature of vllm-ascend.
# Future Plan:
# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
# ** 22. File: worker/patch_mamba_utils.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel`
# Why:
# The original batch_memcpy_kernel implemented in vLLM may encounter bugs when
# running on Ascend hardware.
# How:
# Patch in a reimplemented kernel that avoids these bugs.
# Future Plan:
# Remove this patch when:
# (1) the original batch_memcpy_kernel runs correctly on Ascend hardware, or
# (2) a dispatch mechanism is designed for batch_memcpy_kernel.
# 2. `vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy`
# Why:
# vLLM uses BLOCK_SIZE=1024 for batch_memcpy_kernel, which yields suboptimal
# performance on Ascend hardware.
# How:
# Patch batch_memcpy to launch the kernel with BLOCK_SIZE=8192.
# Future Plan:
# Remove this patch once a dispatch mechanism is designed for
# batch_memcpy_kernel.
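# The dispatch mechanism mentioned above does not exist yet; a minimal
# sketch of one possible shape, assuming a registry keyed by platform name
# (all names below are illustrative, not existing vLLM APIs):
#
#     _BATCH_MEMCPY_KERNELS: dict[str, object] = {}
#
#     def register_batch_memcpy_kernel(platform: str):
#         def decorator(kernel):
#             _BATCH_MEMCPY_KERNELS[platform] = kernel
#             return kernel
#         return decorator
#
#     def get_batch_memcpy_kernel(platform: str):
#         # Fall back to the default "cuda" kernel when no override exists.
#         return _BATCH_MEMCPY_KERNELS.get(platform, _BATCH_MEMCPY_KERNELS["cuda"])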
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
@@ -28,6 +28,7 @@
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_minimax_m2 # noqa
import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa
import vllm_ascend.patch.worker.patch_mamba_utils # noqa
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
21 changes: 21 additions & 0 deletions vllm_ascend/patch/worker/patch_mamba_utils.py
@@ -0,0 +1,21 @@
# mypy: ignore-errors


from vllm.v1.worker import mamba_utils

from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
batch = src_ptrs.shape[0]
assert dst_ptrs.shape[0] == batch
assert sizes.shape[0] == batch

grid = (batch,)
# Use a larger block size to speed up the copy.
BLOCK_SIZE = 8192
batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)


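# Module-level monkey patches: importing this module (see
# vllm_ascend/patch/worker/__init__.py) swaps in the Ascend kernel and the
# BLOCK_SIZE=8192 wrapper.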
mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel
mamba_utils.batch_memcpy = batch_memcpy
31 changes: 17 additions & 14 deletions vllm_ascend/worker/block_table.py
@@ -3,6 +3,7 @@
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size


class BlockTable:
@@ -239,21 +240,10 @@ def __init__(
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
max_num_blocks: list[int] | None = None,
kernel_sizes: list[list[int]] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
pcp_world_size = 1

if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
# Ensure kernel_sizes matches block_sizes length
@@ -264,20 +254,33 @@ def __init__(
f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
)

if max_num_blocks is None:
# Note(hc): each dcp rank only stores
# (max_model_len // dcp_world_size) tokens in the kv cache,
# so the block_size used to compute max_num_blocks_per_req
# must be multiplied by dcp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes]

if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})"
)

# Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens),
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_size_list,
cp_kv_cache_interleave_size,
num_speculative_tokens,
)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks)
]

def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
44 changes: 41 additions & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -77,6 +77,10 @@
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import (
get_total_cp_world_size,
)
from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
@@ -416,6 +420,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
else:
self.cudagraph_batch_sizes = []
self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None

@property
def use_cp(self) -> bool:
@@ -1250,6 +1256,23 @@ def execute_model(

pad_attn = cudagraph_mode == CUDAGraphMode.FULL

# NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877,
# there should be a corresponding 'postprocess_mamba'. It is called inside
# '_update_states_after_model_execute', which vLLM-Ascend does not override,
# so we simply reuse vLLM's implementation.
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.cache_config,
self.mamba_state_idx,
self.input_batch,
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

@@ -2542,6 +2565,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
@@ -2983,6 +3007,21 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
# of mamba block. In this case, BlockTable.block_size will never be equal
# to kernel_block_sizes[0].
self.kernel_block_sizes.append([0])

max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size())
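# Without prefix caching, a Mamba layer needs only a single state block per
# request; with prefix caching it needs the full per-request block count.
# Speculative blocks are added on top in either case.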
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks

max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req)
max_num_blocks.append(max_num_blocks_per_req)

if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
@@ -2991,7 +3030,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
)
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
Expand All @@ -3006,6 +3045,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
else 0
),
kernel_block_sizes=self.kernel_block_sizes,
max_num_blocks_per_req=max_num_blocks,
)

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@@ -3171,8 +3211,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
mamba_layers[layer_name] = attn_module

if len(mamba_layers) > 0:
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
mamba_page_size_padded = 0
for layer_name, mamba_module in mamba_layers.items():
if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
2 changes: 2 additions & 0 deletions vllm_ascend/worker/npu_input_batch.py
@@ -40,6 +40,7 @@ def __init__(
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[list[int]],
max_num_blocks_per_req: list[int] | None = None,
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
@@ -97,6 +98,7 @@ def __init__(
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
max_num_blocks=max_num_blocks_per_req,
num_speculative_tokens=num_speculative_tokens,
kernel_sizes=kernel_block_sizes,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,