Merged

112 commits
b141669
[Perf][Async] Implement zero-bubble async speculative decoding
izhuhaoran Jan 15, 2026
679054f
skip compute_slot_mapping for async_spec_zero_bubble_mode
izhuhaoran Jan 16, 2026
527b55d
remove seq_lens_cpu & num_computed_tokens_cpu for async_spec_zero_bub…
izhuhaoran Jan 16, 2026
8c2ce4c
Get rid of async_spec_zero_bubble_mode option
MatthewBonanni Jan 21, 2026
1554ab7
Fully async version
MatthewBonanni Jan 22, 2026
315e7e6
Increase max_concurrent_batches
MatthewBonanni Jan 22, 2026
15d8ee9
Handle reordering
MatthewBonanni Jan 23, 2026
abb8ba7
Merge branch 'main' into async-eagle-mod
MatthewBonanni Jan 23, 2026
dd8b7c3
Fix
MatthewBonanni Jan 23, 2026
1573015
Merge branch 'main' into async-eagle-mod
MatthewBonanni Jan 23, 2026
317b452
Cleanup
MatthewBonanni Jan 23, 2026
e7e39ce
Cleanup
MatthewBonanni Jan 23, 2026
6fb3042
Fix
MatthewBonanni Jan 23, 2026
7ff5674
Cleanup
MatthewBonanni Jan 26, 2026
c491421
Fix hybrid
MatthewBonanni Jan 26, 2026
a2e861e
Disable mamba cache mode align
MatthewBonanni Jan 27, 2026
0c45e7e
Treat num_accepted_tokens like num_computed_tokens
MatthewBonanni Jan 27, 2026
1eccc1d
Treat num_accepted_tokens like num_computed_tokens
MatthewBonanni Jan 27, 2026
7005b52
Treat num_accepted_tokens like num_computed_tokens
MatthewBonanni Jan 27, 2026
4e4d8d5
Cleanup
MatthewBonanni Jan 27, 2026
c134792
Eliminate compute_slot_mapping
MatthewBonanni Jan 27, 2026
ea7f670
Rename compute_slot_mapping_gpu to compute_slot_mapping
MatthewBonanni Jan 27, 2026
a113c3f
Restore comments
MatthewBonanni Jan 27, 2026
1153da9
Restore comments
MatthewBonanni Jan 27, 2026
e04f1f2
Add TODO comment
MatthewBonanni Jan 27, 2026
a224144
Fix (num_accepted_tokens shouldn't be int64)
MatthewBonanni Jan 27, 2026
c1f550b
Cleanup
MatthewBonanni Jan 27, 2026
ab2d19a
Fix
MatthewBonanni Jan 27, 2026
3c732eb
Fix
MatthewBonanni Jan 29, 2026
843a151
Fix
MatthewBonanni Jan 29, 2026
4c522a2
Fix
MatthewBonanni Jan 29, 2026
c329c93
Cleanup
MatthewBonanni Jan 30, 2026
98c146b
Use CpuGpuBuffer for arange
MatthewBonanni Jan 30, 2026
5f1d06f
Make seq_lens GPU-only and introduce optimistic_seq_lens
MatthewBonanni Jan 30, 2026
050959c
Restructure if block
MatthewBonanni Jan 30, 2026
718087f
Fix TypeError
MatthewBonanni Feb 5, 2026
31c8674
Fix order of operations error
MatthewBonanni Feb 5, 2026
4c6bd9d
Fix positions and seq_lens calculation
MatthewBonanni Feb 5, 2026
03adb2e
Improve _get_cumsum_and_arange
MatthewBonanni Feb 9, 2026
e76d32e
Rename
MatthewBonanni Feb 9, 2026
526fea3
Use query_pos
MatthewBonanni Feb 9, 2026
48df0ab
Eliminate CPU-side num_computed_tokens from GPUModelRunner. Update op…
MatthewBonanni Feb 10, 2026
f3a2684
Fix M-RoPE and XD-RoPE
MatthewBonanni Feb 10, 2026
9f20781
Factor out _compute_batch_index_mapping
MatthewBonanni Feb 11, 2026
779ee44
Fix optimistic_seq_lens_cpu update
MatthewBonanni Feb 11, 2026
2fab694
Fix acceptance length
MatthewBonanni Feb 11, 2026
65f7b3d
Fix
MatthewBonanni Feb 13, 2026
8cc914c
Use CpuGpuBuffer
MatthewBonanni Feb 13, 2026
2dacc94
Merge branch 'main' into async-eagle-mod
MatthewBonanni Feb 13, 2026
fa849f8
Fix
MatthewBonanni Feb 13, 2026
ccc9575
Use preallocated buffer
MatthewBonanni Feb 13, 2026
bc13824
Use triton kernel instead of pytorch ops
MatthewBonanni Feb 16, 2026
3f0be40
Use buffers to avoid sync
MatthewBonanni Feb 16, 2026
664eaf1
Always use placeholders
MatthewBonanni Feb 17, 2026
66ae34b
Comment
MatthewBonanni Feb 17, 2026
c645e05
Bugfix: add arange scratch buffer
MatthewBonanni Feb 18, 2026
04e023b
Use buffers to prevent allocation on the fly
MatthewBonanni Feb 18, 2026
60ca5ec
Re-add indices_match fast path
MatthewBonanni Feb 18, 2026
82da483
Clean up input batch
MatthewBonanni Feb 18, 2026
f7d028e
Merge branch 'main' into async-eagle-mod
MatthewBonanni Feb 23, 2026
0cdf7de
Simplify BatchIndexMapping and num_computed_tokens tracking
LucasWilkinson Feb 27, 2026
d0b43ec
Merge pull request #2 from LucasWilkinson/simplify-batch-index-mapping
MatthewBonanni Mar 4, 2026
7fde95a
Fix acceptance length
MatthewBonanni Mar 4, 2026
e20565d
Deduplicate
MatthewBonanni Mar 4, 2026
fa02bb8
Make positions GPU-only
MatthewBonanni Mar 4, 2026
a74cd15
Eliminate has_prev_draft_tokens
MatthewBonanni Mar 4, 2026
6aa50c2
Add CPU correction
MatthewBonanni Mar 5, 2026
3cee2b2
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 5, 2026
d4304e4
Skip unnecessary copy
MatthewBonanni Mar 5, 2026
56bd8f8
Rename
MatthewBonanni Mar 5, 2026
a694e92
Docstring
MatthewBonanni Mar 5, 2026
836fed2
Use optimistic_seq_lens_cpu and clean up runner
MatthewBonanni Mar 5, 2026
02e6ee2
Fix arange_size
MatthewBonanni Mar 5, 2026
39807a8
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 5, 2026
09ffa74
Undo disable
MatthewBonanni Mar 5, 2026
1023895
Rename
MatthewBonanni Mar 9, 2026
168ecc2
Rename
MatthewBonanni Mar 9, 2026
79a2127
Docstring
MatthewBonanni Mar 9, 2026
d84e371
Add comment
MatthewBonanni Mar 9, 2026
504a7e1
Move to utils
MatthewBonanni Mar 9, 2026
83c52fd
Comment
MatthewBonanni Mar 9, 2026
a1cd3c8
Clean up unnecessary change
MatthewBonanni Mar 9, 2026
f0f9978
Add comment
MatthewBonanni Mar 9, 2026
1c7d2a7
Restore fast path
MatthewBonanni Mar 9, 2026
74fb29a
Remove unrelated fast path
MatthewBonanni Mar 9, 2026
5f22dce
Clean up
MatthewBonanni Mar 9, 2026
65f18f1
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 10, 2026
741f181
Update comment
MatthewBonanni Mar 10, 2026
1a31810
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 11, 2026
865f2c1
Clean up commit_slot_mapping (dead code)
MatthewBonanni Mar 11, 2026
70c12ed
Accumulate rejections and only issue copies when previous has been co…
MatthewBonanni Mar 11, 2026
279b867
Fix
MatthewBonanni Mar 11, 2026
790a8ab
Revert "Accumulate rejections and only issue copies when previous has…
MatthewBonanni Mar 12, 2026
abc93ec
Update CPU side with _finalize_async_spec_cpu_state
MatthewBonanni Mar 12, 2026
e2d5dbc
Clean up
MatthewBonanni Mar 12, 2026
8980abb
Clean up
MatthewBonanni Mar 12, 2026
7a3412a
Comments
MatthewBonanni Mar 12, 2026
1cac626
Clean up
MatthewBonanni Mar 12, 2026
4192b65
Comment
MatthewBonanni Mar 12, 2026
1c2499d
Use deferred_spec_decode_corrections per Lucas's suggestion
MatthewBonanni Mar 18, 2026
30dae23
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 18, 2026
f852fcc
Fix type
MatthewBonanni Mar 20, 2026
5152276
Clean up
MatthewBonanni Mar 20, 2026
39b8b38
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 20, 2026
377d609
Fix mamba align
MatthewBonanni Mar 20, 2026
51eb871
Fix ngram during correction
MatthewBonanni Mar 20, 2026
6811111
Fix mamba
MatthewBonanni Mar 20, 2026
26741ad
Fix ngram
MatthewBonanni Mar 20, 2026
dfd0896
Clean up
MatthewBonanni Mar 20, 2026
ed3fb14
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 20, 2026
6bed7ac
Fix
MatthewBonanni Mar 23, 2026
7e29a71
Merge branch 'main' into async-eagle-mod
MatthewBonanni Mar 23, 2026
2 changes: 1 addition & 1 deletion tests/v1/spec_decode/test_eagle.py
@@ -177,7 +177,7 @@ def test_prepare_next_token_ids():

     next_token_ids_from_padded, valid_sampled_tokens_count = (
         proposer.prepare_next_token_ids_padded(
-            common_attn_metadata,
+            common_attn_metadata.seq_lens_cpu,
             sampled_token_ids_tensor,
             mock_requests,
             mock_input_batch,
2 changes: 1 addition & 1 deletion tests/v1/spec_decode/test_extract_hidden_states.py
@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded():
     )

     next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
-        common_attn_metadata,
+        common_attn_metadata.seq_lens_cpu,
         sampled_token_ids,
         mock_requests,
         mock_input_batch,
13 changes: 13 additions & 0 deletions vllm/config/vllm.py
@@ -766,6 +766,19 @@ def __post_init__(self):
         else:
             self.parallel_config.disable_nccl_for_dp_synchronization = False

+        if (
+            self.speculative_config is not None
+            and self.scheduler_config.async_scheduling
+            and self.model_config is not None
+            and not self.model_config.disable_cascade_attn
+        ):
+            logger.warning_once(
+                "Disabling cascade attention (not yet compatible with "
+                "async speculative decoding).",
+                scope="local",
+            )
+            self.model_config.disable_cascade_attn = True
+
         if (
             self.model_config is not None
             and self.model_config.multimodal_config is not None
12 changes: 4 additions & 8 deletions vllm/v1/spec_decode/eagle.py
@@ -71,7 +71,6 @@ def __init__(
         self.method = self.speculative_config.method
         self.pass_hidden_states_to_model = pass_hidden_states_to_model

-        self.runner = runner
         self.device = device
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -424,8 +423,6 @@ def propose(
             )
         )

-        assert self.runner is not None
-
         per_layer_attn_metadata: dict[str, object] = {}
         for attn_group in self.draft_attn_groups:
             attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
@@ -821,7 +818,7 @@ def prepare_next_token_ids_cpu(

     def prepare_next_token_ids_padded(
         self,
-        common_attn_metadata: CommonAttentionMetadata,
+        seq_lens_cpu: torch.Tensor,
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
@@ -836,11 +833,10 @@
         """
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
+        seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
         self.backup_next_token_ids.np[:num_reqs] = np.array(
             [
-                requests[gpu_input_batch.req_ids[i]].get_token_id(
-                    common_attn_metadata.seq_lens_cpu[i].item()
-                )
+                requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
                 for i in range(num_reqs)
             ],
             dtype=np.int32,
@@ -925,7 +921,7 @@ def prepare_inputs_padded(
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
-            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
+            max_seq_len=common_attn_metadata.max_seq_len,
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
             causal=True,
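A recurring change in this PR replaces per-element `.item()` reads of `seq_lens_cpu` with a single `.tolist()` over the needed slice. A minimal NumPy sketch of the pattern (torch CPU tensors behave analogously; the values are illustrative):

```python
import numpy as np

seq_lens_cpu = np.array([5, 9, 3, 7], dtype=np.int32)
num_reqs = 3

# Per-element .item() forces one scalar conversion per request:
per_item = [seq_lens_cpu[i].item() for i in range(num_reqs)]

# A single .tolist() converts the whole slice in one call (and, for a
# torch CPU tensor, one conversion instead of one per element):
batched = seq_lens_cpu[:num_reqs].tolist()

print(per_item == batched == [5, 9, 3])  # -> True
```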
7 changes: 3 additions & 4 deletions vllm/v1/spec_decode/extract_hidden_states.py
@@ -286,7 +286,7 @@ def _build_attn_metadata_builder(

     def prepare_next_token_ids_padded(
         self,
-        common_attn_metadata: CommonAttentionMetadata,
+        seq_lens: torch.Tensor,
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
@@ -303,11 +303,10 @@
         device = sampled_token_ids.device

         # Compute backup tokens for discarded / invalid requests
+        seq_lens_list = seq_lens[:num_reqs].tolist()
         backup_tokens_gpu = torch.tensor(
             [
-                requests[gpu_input_batch.req_ids[i]].get_token_id(
-                    common_attn_metadata.seq_lens_cpu[i].item()
-                )
+                requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
                 for i in range(num_reqs)
             ],
             dtype=torch.int32,
34 changes: 34 additions & 0 deletions vllm/v1/spec_decode/utils.py
@@ -3,6 +3,7 @@
 import torch

 from vllm.config import VllmConfig, replace
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel(
         out_idx,
         mask=is_new_token_region & in_bounds,
     )
+
+
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+def update_num_computed_tokens_for_batch_change(
+    num_computed_tokens: torch.Tensor,
+    num_accepted_tokens: torch.Tensor,
+    prev_positions: torch.Tensor,
+    valid_sampled_token_count: torch.Tensor,
+    prev_num_draft_tokens: torch.Tensor,
+    cpu_num_computed_tokens: torch.Tensor,
+) -> None:
+    """Correct num_computed_tokens for async spec decode drift.
+
+    Requests that had drafts: corrected = prev_gpu + valid_count.
+    New requests or non-draft (e.g. prefills): use CPU value directly.
+    """
+    # Clamp because prev_positions can be -1 for new requests
+    gather_indices = prev_positions.clamp(min=0)
+
+    valid_counts = valid_sampled_token_count[gather_indices]
+    prev_computed = num_computed_tokens[gather_indices]
+    prev_drafts = prev_num_draft_tokens[gather_indices]
+
+    participating = (prev_positions >= 0) & (prev_drafts > 0)
+    corrected = prev_computed + valid_counts.int()
+
+    n = prev_positions.shape[0]
+    num_computed_tokens[:n].copy_(
+        torch.where(participating, corrected, cpu_num_computed_tokens)
+    )
+    num_accepted_tokens.copy_(
+        torch.where(participating, valid_counts, num_accepted_tokens)
+    )
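The correction above gathers previous-iteration state by each request's old batch slot and falls back to the scheduler's CPU-side counts for new or draft-less requests. A toy NumPy sketch of the same selection logic (names, shapes, and values are illustrative, not from the PR):

```python
import numpy as np

def correct_num_computed_tokens(
    prev_positions,             # batch slot each request held last step; -1 if new
    valid_sampled_token_count,  # tokens accepted per previous batch slot
    prev_num_computed,          # GPU view of num_computed_tokens, previous order
    prev_num_draft_tokens,      # drafts proposed per previous batch slot
    cpu_num_computed_tokens,    # scheduler's conservative CPU-side counts
):
    # Clamp -1 (new requests) so the gather stays in bounds.
    gather = np.clip(prev_positions, 0, None)
    participating = (prev_positions >= 0) & (prev_num_draft_tokens[gather] > 0)
    corrected = prev_num_computed[gather] + valid_sampled_token_count[gather]
    return np.where(participating, corrected, cpu_num_computed_tokens)

# Request 0 had drafts and 2 accepted tokens; request 1 is new;
# request 2 ran without drafts (e.g. a prefill).
result = correct_num_computed_tokens(
    prev_positions=np.array([0, -1, 1]),
    valid_sampled_token_count=np.array([2, 1]),
    prev_num_computed=np.array([10, 7]),
    prev_num_draft_tokens=np.array([3, 0]),
    cpu_num_computed_tokens=np.array([10, 5, 7]),
)
print(result)  # -> [12  5  7]
```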
149 changes: 85 additions & 64 deletions vllm/v1/worker/block_table.py
@@ -6,7 +6,9 @@

 from vllm.distributed import get_dcp_group, get_pcp_group
 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.cp_utils import get_total_cp_world_size

@@ -131,71 +133,33 @@ def swap_row(self, src: int, tgt: int) -> None:
         self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

     def compute_slot_mapping(
-        self, req_indices: np.ndarray, positions: np.ndarray
+        self,
+        num_reqs: int,
+        query_start_loc: torch.Tensor,
+        positions: torch.Tensor,
     ) -> None:
-        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
-        # where K is the max_num_blocks_per_req and the block size is 2.
-        # NOTE(woosuk): We can't simply use `token_indices // block_size`
-        # here because M (max_model_len) is not necessarily divisible by
-        # block_size.
+        num_tokens = positions.shape[0]
         total_cp_world_size = self.pcp_world_size * self.dcp_world_size
         total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
-        if total_cp_world_size > 1:
-            # Note(hc): The DCP implement store kvcache with an interleave
-            # style, the kvcache for the token whose token_idx is i is
-            # always stored on the GPU whose dcp_rank equals i % cp_world_size:
-
-            # Use a "virtual block" which equals to world_size * block_size
-            # for block_table_indices calculation.
-            virtual_block_size = self.block_size * total_cp_world_size
-            block_table_indices = (
-                req_indices * self.max_num_blocks_per_req
-                + positions // virtual_block_size
-            )
-
-            block_numbers = self.block_table.np.ravel()[block_table_indices]
-            # Use virtual_block_size for mask calculation, which marks local
-            # tokens.
-            virtual_block_offsets = positions % virtual_block_size
-            mask = (
-                virtual_block_offsets
-                // self.cp_kv_cache_interleave_size
-                % total_cp_world_size
-                == total_cp_rank
-            )
-            # Calculate local block_offsets
-            block_offsets = (
-                virtual_block_offsets
-                // (total_cp_world_size * self.cp_kv_cache_interleave_size)
-                * self.cp_kv_cache_interleave_size
-                + virtual_block_offsets % self.cp_kv_cache_interleave_size
-            )
-            # Calculate slot_mapping
-            slot_mapping = block_numbers * self.block_size + block_offsets
-            # Write final slots, use -1 for not-local
-            self.slot_mapping.np[: req_indices.shape[0]] = np.where(
-                mask, slot_mapping, -1
-            )
-        else:
-            block_table_indices = (
-                req_indices * self.max_num_blocks_per_req + positions // self.block_size
-            )
-
-            block_numbers = self.block_table.np.ravel()[block_table_indices]
-            block_offsets = positions % self.block_size
-            np.add(
-                block_numbers * self.block_size,
-                block_offsets,
-                out=self.slot_mapping.np[: req_indices.shape[0]],
-            )
+        _compute_slot_mapping_kernel[(num_reqs + 1,)](
+            num_tokens,
+            self.max_num_batched_tokens,
+            query_start_loc,
+            positions,
+            self.block_table.gpu,
+            self.block_table.gpu.stride(0),
+            self.block_size,
+            self.slot_mapping.gpu,
+            TOTAL_CP_WORLD_SIZE=total_cp_world_size,
+            TOTAL_CP_RANK=total_cp_rank,
+            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
+            PAD_ID=PAD_SLOT_ID,
+            BLOCK_SIZE=1024,
+        )

     def commit_block_table(self, num_reqs: int) -> None:
         self.block_table.copy_to_gpu(num_reqs)

-    def commit_slot_mapping(self, num_tokens: int) -> None:
-        self.slot_mapping.copy_to_gpu(num_tokens)
-
     def clear(self) -> None:
         self.block_table.gpu.fill_(0)
         self.block_table.cpu.fill_(0)
@@ -320,23 +284,80 @@ def swap_row(self, src: int, tgt: int) -> None:
             block_table.swap_row(src, tgt)

     def compute_slot_mapping(
-        self, req_indices: np.ndarray, positions: np.ndarray
+        self,
+        num_reqs: int,
+        query_start_loc: torch.Tensor,
+        positions: torch.Tensor,
     ) -> None:
         for block_table in self.block_tables:
-            block_table.compute_slot_mapping(req_indices, positions)
+            block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)

     def commit_block_table(self, num_reqs: int) -> None:
         for block_table in self.block_tables:
             block_table.commit_block_table(num_reqs)

-    def commit_slot_mapping(self, num_tokens: int) -> None:
-        for block_table in self.block_tables:
-            block_table.commit_slot_mapping(num_tokens)
-
     def clear(self) -> None:
         for block_table in self.block_tables:
             block_table.clear()

     def __getitem__(self, idx: int) -> "BlockTable":
         """Returns the BlockTable for the i-th KV cache group."""
         return self.block_tables[idx]
+
+
+@triton.jit
+def _compute_slot_mapping_kernel(
+    num_tokens,
+    max_num_tokens,
+    query_start_loc_ptr,  # [num_reqs + 1], int32
+    positions_ptr,  # [num_tokens], int64
+    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
+    block_table_stride,  # max_num_blocks_per_req
+    block_size,
+    slot_mapping_ptr,  # [max_num_tokens], int64
+    TOTAL_CP_WORLD_SIZE: tl.constexpr,
+    TOTAL_CP_RANK: tl.constexpr,
+    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
+    PAD_ID: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+
+    if req_idx == tl.num_programs(0) - 1:
+        # Pad remaining slots for CUDA graph compatibility.
+        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
+            offsets = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(
+                slot_mapping_ptr + offsets,
+                PAD_ID,
+                mask=offsets < max_num_tokens,
+            )
+        return
+
+    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
+    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)
+
+    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
+    row_offset = req_idx * block_table_stride
+    for i in range(start_idx, end_idx, BLOCK_SIZE):
+        offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < end_idx
+        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
+        block_indices = pos // virtual_block_size
+        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
+            tl.int64
+        )
+
+        virtual_block_offsets = pos - block_indices * virtual_block_size
+        is_local = (
+            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
+        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
+        local_block_offsets = (
+            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
+        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
+            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
+        )

+        slot_ids = block_numbers * block_size + local_block_offsets
+        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
+        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)
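In the common single-rank case (`TOTAL_CP_WORLD_SIZE == 1`, interleave size 1), the slot-mapping arithmetic the kernel performs reduces to `block_table[req][pos // block_size] * block_size + pos % block_size`. A NumPy sketch with a toy block table (the real code indexes a flattened `CpuGpuBuffer` and pads unused slots with `PAD_SLOT_ID`):

```python
import numpy as np

def compute_slot_mapping(block_table, req_indices, positions, block_size):
    """Map each token's (request, position) to a flat KV-cache slot."""
    block_numbers = block_table[req_indices, positions // block_size]
    return block_numbers * block_size + positions % block_size

block_size = 2
# Two requests; request 0 owns physical blocks [5, 6], request 1 owns [9, 3].
block_table = np.array([[5, 6], [9, 3]])
req_indices = np.array([0, 0, 0, 1, 1])
positions   = np.array([0, 1, 2, 0, 1])
print(compute_slot_mapping(block_table, req_indices, positions, block_size))
# -> [10 11 12 18 19]
```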
10 changes: 6 additions & 4 deletions vllm/v1/worker/gpu_input_batch.py
@@ -219,7 +219,7 @@ def __init__(

         # Speculative decoding
         self.num_accepted_tokens_cpu_tensor = torch.ones(
-            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
+            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
         )
         self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

@@ -989,13 +989,15 @@ def update_async_output_token_ids(self) -> None:
                 continue
             num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
             # Also account for case where there may be a smaller number of
-            # output placeholders (tokens can be discarded after a kv-load failure).
+            # output placeholders (tokens can be discarded after kv-load
+            # failure) or a larger number (async spec decode adds optimistic
+            # placeholders that may exceed the actual acceptance count).
             first_placeholder = req_output_token_ids.index(-1)
             num_placeholders = len(req_output_token_ids) - first_placeholder
             num_to_replace = min(num_sampled_ids, num_placeholders)
             del new_ids[num_to_replace:]
-            end_index = first_placeholder + num_to_replace
-            req_output_token_ids[first_placeholder:end_index] = new_ids
+            req_output_token_ids[first_placeholder:] = new_ids
+            # ^ Implicitly resizes to (first_placeholder + num_to_replace)

     def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
         """
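The placeholder replacement in `update_async_output_token_ids` can be illustrated standalone: optimistic `-1` placeholders are overwritten with the actually sampled ids, and any surplus placeholders are dropped by the open-ended slice assignment. A sketch with toy token lists (the helper name is illustrative):

```python
def replace_placeholders(req_output_token_ids, new_ids):
    """Fill -1 placeholders with sampled ids; surplus placeholders are dropped."""
    # new_ids may itself be padded with -1; count only valid sampled ids.
    num_sampled = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
    first_placeholder = req_output_token_ids.index(-1)
    num_placeholders = len(req_output_token_ids) - first_placeholder
    num_to_replace = min(num_sampled, num_placeholders)
    del new_ids[num_to_replace:]
    # Open-ended slice assignment implicitly truncates extra placeholders.
    req_output_token_ids[first_placeholder:] = new_ids
    return req_output_token_ids

# 3 optimistic placeholders, but only 2 tokens were actually accepted:
print(replace_placeholders([7, 8, -1, -1, -1], [21, 22, -1]))
# -> [7, 8, 21, 22]
```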