Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
a5bfa27
[Async][spec decode] Zero-bubble async scheduling +spec decoding
Mar 25, 2026
e599872
[Async][spec decode] Zero-bubble async scheduling +spec decoding
Mar 25, 2026
33a6d13
[Async][spec decode] Zero-bubble async scheduling +spec decoding
Mar 25, 2026
0a60c26
optimize
Mar 27, 2026
9248184
update to 0324
22dimensions Mar 25, 2026
1dfa935
fix: add vllm_is_batch_invariant compatibility wrapper
claude Mar 25, 2026
29eb701
Merge branch 'main' into zero_bubble_async_spec
HF-001 Mar 31, 2026
c82dad6
fix
Mar 31, 2026
62c38ee
fix format
Mar 31, 2026
74e2526
Merge branch 'main' into zero_bubble_async_spec
HF-001 Mar 31, 2026
dfe1b6d
fix
Mar 31, 2026
5f157ba
fix
Mar 31, 2026
9b0ff73
fix
Mar 31, 2026
c39d214
fix
Mar 31, 2026
013dcbe
fix
Mar 31, 2026
f28aacd
Merge branch 'main' into zero_bubble_async_spec
HF-001 Mar 31, 2026
8ca4fcd
Merge branch 'main' into zero_bubble_async_spec
HF-001 Mar 31, 2026
69bf146
fix
HF-001 Mar 31, 2026
fce6a54
Merge branch 'main' into zero_bubble_async_spec
HF-001 Mar 31, 2026
c45a066
fix ut test
Apr 1, 2026
2acc473
fix ut test
Apr 1, 2026
946107d
fix ut test
Apr 1, 2026
80a604b
fix ut test
Apr 1, 2026
5c88ee5
fix
Apr 1, 2026
c1e05db
fix
Apr 1, 2026
f36c75c
Merge branch 'main' into zero_bubble_async_spec
HF-001 Apr 1, 2026
eba0a00
Merge branch 'main' into zero_bubble_async_spec
HF-001 Apr 2, 2026
687e8c1
fix kvcache
Apr 2, 2026
a9bac6f
fix ci
Apr 2, 2026
e931e98
fix
Apr 2, 2026
4363d13
fix
Apr 2, 2026
5609a66
fix
Apr 2, 2026
7f441b0
fix
HF-001 Apr 2, 2026
e01874d
Merge branch 'main' into zero_bubble_async_spec
HF-001 Apr 2, 2026
bcbee07
fix
Apr 3, 2026
c83ea55
fix
Apr 3, 2026
bd21d43
Merge branch 'main' into zero_bubble_async_spec
HF-001 Apr 3, 2026
c0239cc
fix
Apr 3, 2026
b4b1dfb
Merge branch 'main' into zero_bubble_async_spec
HF-001 Apr 4, 2026
39f2498
fix
HF-001 Apr 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/_e2e_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ on:
continue_on_error:
required: false
type: boolean
default: false
default: true
# The following inputs are used by comment-triggered E2E tests (/e2e <tests>).
# They carry space-separated pytest paths, categorized by runner type.
# Leave empty (default) when running label-triggered full/light suites.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_test_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ jobs:
strategy:
fail-fast: false
matrix:
vllm_version: [35141a7eeda941a60ad5a4956670c60fd5a77029]
vllm_version: [35141a7eeda941a60ad5a4956670c60fd5a77029, v0.18.0]
needs: [parse-trigger]
if: ${{ needs.parse-trigger.outputs.allowed == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml
Expand Down
100 changes: 55 additions & 45 deletions tests/ut/worker/test_block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@

import numpy as np
import torch

# import vllm.utils.cpu_triton_utils as cpu_tl
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase


class TestBlockTableComputeSlotMapping(TestBase):
"""Test suite for BlockTable.compute_slot_mapping() method

This test suite covers different configurations of DCP (Decode Context Parallelism),
PCP (Prefill Context Parallelism), and cp_kv_cache_interleave_size to ensure
correct slot_mapping calculation on different ranks.
Expand All @@ -41,13 +43,13 @@ def setUp(self):
self.device = torch.device("cpu")
self.kernel_sizes = [128]

def create_block_table(self, dcp_world_size, dcp_rank, pcp_world_size,
pcp_rank, cp_kv_cache_interleave_size):
def create_block_table(self, dcp_world_size, dcp_rank, pcp_world_size, pcp_rank, cp_kv_cache_interleave_size):
"""Helper method to create BlockTable with mocked distributed groups"""

with patch('vllm_ascend.worker.block_table.get_dcp_group') as mock_get_dcp_group, \
patch('vllm_ascend.worker.block_table.get_pcp_group') as mock_get_pcp_group:

with (
patch("vllm_ascend.worker.block_table.get_dcp_group") as mock_get_dcp_group,
patch("vllm_ascend.worker.block_table.get_pcp_group") as mock_get_pcp_group,
):
# Mock DCP group
mock_dcp_group = MagicMock(spec=GroupCoordinator)
mock_dcp_group.world_size = dcp_world_size
Expand All @@ -71,23 +73,21 @@ def create_block_table(self, dcp_world_size, dcp_rank, pcp_world_size,
device=self.device,
kernel_sizes=self.kernel_sizes,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
num_speculative_tokens=0)
num_speculative_tokens=0,
)

return block_table

def setup_block_table_data(self, block_table, num_reqs=2):
"""Helper method to populate block table with test data"""
# Add block IDs for each request
for i in range(num_reqs):
block_ids = list(range(i * 4,
(i + 1) * 4)) # [0,1,2,3], [4,5,6,7], etc.
block_ids = list(range(i * 4, (i + 1) * 4)) # [0,1,2,3], [4,5,6,7], etc.
block_table.add_row(block_ids, i)

def _test_slot_mapping_for_ranks(self, dcp_world_size, pcp_world_size,
cp_kv_cache_interleave_size,
test_configs):
def _test_slot_mapping_for_ranks(self, dcp_world_size, pcp_world_size, cp_kv_cache_interleave_size, test_configs):
"""Helper method to test slot_mapping across multiple ranks

Args:
dcp_world_size: Number of DCP ranks
pcp_world_size: Number of PCP ranks
Expand All @@ -97,31 +97,46 @@ def _test_slot_mapping_for_ranks(self, dcp_world_size, pcp_world_size,
for dcp_rank, pcp_rank, req_indices, positions, expected_result in test_configs:
with self.subTest(dcp_rank=dcp_rank, pcp_rank=pcp_rank):
block_table = self.create_block_table(
dcp_world_size, dcp_rank, pcp_world_size, pcp_rank,
cp_kv_cache_interleave_size)
dcp_world_size, dcp_rank, pcp_world_size, pcp_rank, cp_kv_cache_interleave_size
)

num_reqs = max(req_indices) + 1 if len(req_indices) > 0 else 1
self.setup_block_table_data(block_table, num_reqs=num_reqs)

block_table.compute_slot_mapping(req_indices, positions)
# Build query_start_loc [num_reqs + 1] from req_indices.
# query_start_loc holds the cumulative token count per request,
# e.g. req_indices=[0,0,1,1] -> query_start_loc=[0,2,4].
num_tokens = len(positions)
counts = np.bincount(req_indices, minlength=num_reqs)
query_start_loc_np = np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)
query_start_loc = torch.from_numpy(query_start_loc_np)

# positions must be a torch int64 tensor to match the
# _compute_slot_mapping_kernel's positions_ptr type.
positions_tensor = torch.from_numpy(positions.astype(np.int64))
# block_table._compute_slot_mapping_kernel = cpu_tl.compute_slot_mapping_kernel
block_table.compute_slot_mapping(num_reqs, query_start_loc, positions_tensor)

actual_result = block_table.slot_mapping.np[:num_tokens]

actual_result = block_table.slot_mapping.np[:len(positions)]
np.testing.assert_array_equal(
actual_result, expected_result,
actual_result,
expected_result,
f"DCP={dcp_world_size}, PCP={pcp_world_size}, "
f"interleave={cp_kv_cache_interleave_size}, "
f"dcp_rank={dcp_rank}, pcp_rank={pcp_rank}")
f"dcp_rank={dcp_rank}, pcp_rank={pcp_rank}",
)

def test_compute_slot_mapping_dcp1_pcp1_interleave1(self):
"""Test compute_slot_mapping with DCP=1, PCP=1, interleave_size=1

With no parallelism (DCP=1, PCP=1), all tokens are local to the single rank.

Setup:
- Block size: 16
- Request 0 has blocks: [0, 1, 2, 3]
- Request 1 has blocks: [4, 5, 6, 7]

Test positions for each request:
- Request 0, position 0: block_id=0, offset=0 → slot = 0*128+0 = 0
- Request 0, position 1: block_id=0, offset=1 → slot = 0*128+1 = 1
Expand All @@ -137,14 +152,13 @@ def test_compute_slot_mapping_dcp1_pcp1_interleave1(self):
(0, 0, req_indices, positions, expected_result),
]

self._test_slot_mapping_for_ranks(dcp_world_size=1,
pcp_world_size=1,
cp_kv_cache_interleave_size=1,
test_configs=test_configs)
self._test_slot_mapping_for_ranks(
dcp_world_size=1, pcp_world_size=1, cp_kv_cache_interleave_size=1, test_configs=test_configs
)

def test_compute_slot_mapping_dcp4_pcp2_interleave1(self):
"""Test compute_slot_mapping with DCP=4, PCP=2, interleave_size=1

With interleave_size=1, tokens are distributed round-robin across all 8 ranks:
- Position 0 → Rank 0
- Position 1 → Rank 1
Expand Down Expand Up @@ -183,28 +197,25 @@ def test_compute_slot_mapping_dcp4_pcp2_interleave1(self):
for pcp_rank in range(2):
for dcp_rank in range(4):
current_rank = 4 * pcp_rank + dcp_rank
expected_result = np.array(rank_expectations[current_rank],
dtype=np.int32)
test_configs.append((dcp_rank, pcp_rank, req_indices,
positions, expected_result))
expected_result = np.array(rank_expectations[current_rank], dtype=np.int32)
test_configs.append((dcp_rank, pcp_rank, req_indices, positions, expected_result))

self._test_slot_mapping_for_ranks(dcp_world_size=4,
pcp_world_size=2,
cp_kv_cache_interleave_size=1,
test_configs=test_configs)
self._test_slot_mapping_for_ranks(
dcp_world_size=4, pcp_world_size=2, cp_kv_cache_interleave_size=1, test_configs=test_configs
)

def test_compute_slot_mapping_dcp4_pcp2_interleave128(self):
"""Test compute_slot_mapping with DCP=4, PCP=2, interleave_size=128

With interleave_size=128, tokens are distributed in chunks of 128 across ranks.
Virtual block size = 16 * 4 * 2 = 128

Token distribution with interleave_size=128:
- Positions 0-127 belong to rank 0 (first chunk of 128)
- Positions 128-255 belong to rank 1 (second chunk of 128)
- Positions 256-383 belong to rank 2 (third chunk of 128)
- And so on...

Using 130 positions ensures we test both rank 0 (positions 0-127) and rank 1 (positions 128-129).
"""
num_positions = 130
Expand Down Expand Up @@ -245,14 +256,13 @@ def test_compute_slot_mapping_dcp4_pcp2_interleave128(self):
expected_result = [-1] * 130

test_configs.append(
(dcp_rank, pcp_rank, req_indices, positions,
np.array(expected_result, dtype=np.int32)))
(dcp_rank, pcp_rank, req_indices, positions, np.array(expected_result, dtype=np.int32))
)

self._test_slot_mapping_for_ranks(dcp_world_size=4,
pcp_world_size=2,
cp_kv_cache_interleave_size=128,
test_configs=test_configs)
self._test_slot_mapping_for_ranks(
dcp_world_size=4, pcp_world_size=2, cp_kv_cache_interleave_size=128, test_configs=test_configs
)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
44 changes: 40 additions & 4 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def build(
)

block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
seq_lens = common_attn_metadata.seq_lens[:num_reqs].to("cpu")

slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
# this slot_mapping override doesn't work since vllm will override it again. We should fix it vllm.
Expand Down Expand Up @@ -688,7 +688,26 @@ def full_graph_pa(
graph_params.handles[num_tokens].append(handle)
return output

def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, kv_cache=None):
# PrefillNoCache doesn't need key_cache, but other modes do
# Only initialize/require cache for modes that actually use it
if attn_metadata.attn_state != AscendAttentionState.PrefillNoCache:
# Initialize cache from kv_cache if not already set (for DecodeOnly mode)
if self.key_cache is None and kv_cache is not None:
if (
isinstance(kv_cache, torch.Tensor)
and kv_cache.dim() > 0
and kv_cache.shape[0] == 2
or isinstance(kv_cache, (list, tuple))
and len(kv_cache) >= 2
):
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]

if self.key_cache is None:
raise RuntimeError(
f"key_cache is None in _get_fia_params for mode {attn_metadata.attn_state}. kv_cache={kv_cache}"
)

if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
Expand Down Expand Up @@ -766,6 +785,7 @@ def forward_fused_infer_attention(
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
kv_cache=None,
):
# we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context,
Expand All @@ -781,7 +801,9 @@ def forward_fused_infer_attention(
and self.sinks is None
):
return self._forward_fia_slidingwindow(query, attn_metadata, output)
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(
key, value, attn_metadata, kv_cache
)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
if (
Expand Down Expand Up @@ -927,7 +949,7 @@ def forward_impl(
):
output = self.forward_paged_attention(query, attn_metadata, output)
else:
output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output, kv_cache)

return output

Expand Down Expand Up @@ -963,6 +985,20 @@ def forward(
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)

# Initialize key_cache and value_cache from kv_cache if not already set.
# This is needed for DecodeOnly mode where key/value are None but we still
# need access to the cache for attention computation.
if self.key_cache is None and kv_cache is not None:
if (
isinstance(kv_cache, torch.Tensor)
and kv_cache.dim() > 0
and kv_cache.shape[0] == 2
or isinstance(kv_cache, (list, tuple))
and len(kv_cache) >= 2
):
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]

output_padded = None
if key is not None and value is not None:
output_padded = output
Expand Down
6 changes: 5 additions & 1 deletion vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,11 @@ def build(

query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.query_lens = query_seq_lens_cpu[:num_reqs]
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
self.seq_lens = None
if common_attn_metadata.seq_lens_cpu is not None:
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
else:
self.seq_lens = common_attn_metadata.seq_lens[:num_reqs].to("cpu")

self.graph_pad_size = common_attn_metadata.graph_pad_size
block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
Expand Down
7 changes: 6 additions & 1 deletion vllm_ascend/attention/sfa_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ def build(

cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]

seq_lens_cpu = None
if common_attn_metadata.seq_lens_cpu is not None:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]
else:
seq_lens_cpu = common_attn_metadata.seq_lens[:num_reqs].to("cpu")

cos, sin = get_cos_and_sin_mla(input_positions, True)

Expand Down
6 changes: 4 additions & 2 deletions vllm_ascend/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommo
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs] if self.seq_lens_cpu is not None else None,
num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs]
if self.num_computed_tokens_cpu is not None
else None,
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
Expand Down
14 changes: 14 additions & 0 deletions vllm_ascend/batch_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@
torch_sum = torch.sum


def vllm_is_batch_invariant() -> bool:
    """Return ``True`` when batch-invariant mode is enabled.

    Compatibility wrapper for the helper that was removed from upstream
    vLLM during a recent refactor: prefer the value exposed on the
    ``envs`` module when present, otherwise read the raw
    ``VLLM_BATCH_INVARIANT`` environment variable directly.
    """
    if not hasattr(envs, "VLLM_BATCH_INVARIANT"):
        # Older vLLM builds don't expose the flag on envs; parse the
        # environment variable the same way upstream does ("0"/"1").
        return bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))
    return bool(envs.VLLM_BATCH_INVARIANT)


if HAS_TRITON:
from vllm_ascend.ops.triton.batch_invariant.matmul import (
addmm_batch_invariant,
Expand Down
Loading
Loading