Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,9 +1500,28 @@ def _update_states_after_model_execute(

is_align = self.cache_config.mamba_cache_mode == "align"
if is_align:
for i, num_tokens in enumerate(
self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
# PR #42574: skip the postprocess_mamba call entirely when no
# request can cross a mamba block boundary this step. In that
# regime we only need a non-blocking copy of num_accepted_tokens;
# the event.synchronize() in `_prepare_inputs` next iter absorbs
# the deferred wait.
copy_bufs = self._get_mamba_copy_bufs()
if mamba_utils.can_skip_mamba_postprocess(
scheduler_output,
self.input_batch,
self.requests,
copy_bufs.mamba_spec.block_size,
num_reqs,
):
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
)
assert self.num_accepted_tokens_event is not None
self.num_accepted_tokens_event.record()
return
# Fallthrough: blocking sync, then upstream's per-request populate
np_arr = self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
for i, num_tokens in enumerate(np_arr):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
else:
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
Expand Down
35 changes: 35 additions & 0 deletions vllm/v1/worker/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,41 @@ def preprocess_mamba(
do_mamba_copy_block(copy_bufs)


def can_skip_mamba_postprocess(
scheduler_output: SchedulerOutput,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_block_size: int,
num_reqs: int,
) -> bool:
"""Return True iff `postprocess_mamba` is provably a no-op this step.

Bounded by ``n_draft + 1`` accepted tokens, we can decide on CPU
whether any request can cross a mamba block boundary. If not, the
caller can defer the device-to-host sync of ``num_accepted_tokens``.

Must stay in lockstep with the inner conditional in
:func:`postprocess_mamba` below.
"""
if not mamba_block_size or mamba_block_size <= 0:
return False
num_scheduled = scheduler_output.num_scheduled_tokens
spec_decode = scheduler_output.scheduled_spec_decode_tokens
req_ids = input_batch.req_ids
for i in range(num_reqs):
req_id = req_ids[i]
n_draft = len(spec_decode.get(req_id, ()))
n_running = (
requests[req_id].num_computed_tokens
+ num_scheduled[req_id]
- n_draft
)
max_new = n_running + n_draft
if (max_new // mamba_block_size) * mamba_block_size >= n_running:
return False
return True


def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
Expand Down